import asyncio
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Set, Tuple

from loguru import logger

from .base_sync import BaseSync

# NOTE(review): ``StrategyPosition`` is referenced below but is not imported
# anywhere in the visible source. Confirm it is brought into scope (presumably
# from the project's ORM models module); otherwise the DB sync raises NameError.


class PositionSync(BaseSync):
    """Position data synchronizer (batch version).

    Reads per-account positions (via ``_get_positions_from_redis``, which is
    defined outside this view) and upserts them into the ``StrategyPosition``
    table, removing rows that are no longer reported for the account.
    """

    # Columns overwritten with the incoming values when a row already exists.
    _UPSERT_UPDATE_COLUMNS = (
        'price',
        'sum',
        'asset_num',
        'asset_profit',
        'leverage',
        'uptime',
        'profit_price',
        'stop_price',
        'liquidation_price',
    )

    # Maximum number of primary keys sent in a single DELETE ... IN (...).
    _DELETE_CHUNK_SIZE = 100

    def __init__(self):
        super().__init__()
        # Maximum concurrency per synchronizer. Not read by any method
        # visible in this class; kept for callers / subclasses.
        self.max_concurrent = 10

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Synchronize the positions of all given accounts in one batch.

        Accounts are grouped by their ``exchange_id`` and every group is
        processed as its own coroutine; the groups are awaited together.

        Args:
            accounts: Mapping of account id to an account-info dict. Each
                info dict is read for ``exchange_id``, ``k_id`` and
                (optionally) ``st_id``.

        Returns:
            None. Failures are logged and never propagated to the caller.
        """
        try:
            logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")

            account_groups = self._group_accounts_by_exchange(accounts)

            tasks = [
                self._sync_exchange_accounts(exchange_id, account_list)
                for exchange_id, account_list in account_groups.items()
            ]

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # With return_exceptions=True a failed task yields its exception
            # object instead of a bool; surface it rather than dropping it.
            for exchange_id, result in zip(account_groups, results):
                if isinstance(result, BaseException):
                    logger.error(f"交易所 {exchange_id} 持仓同步任务异常: {result!r}")

            success_count = sum(1 for result in results if result is True)
            logger.info(f"持仓批量同步完成: 成功 {success_count}/{len(results)} 个交易所组")

        except Exception as e:
            logger.exception(f"持仓批量同步失败: {e}")

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account-info dicts by their ``exchange_id``.

        Accounts whose ``exchange_id`` is missing or falsy are left out of
        the result.

        Args:
            accounts: Mapping of account id to account-info dict.

        Returns:
            Mapping of exchange id to the list of account-info dicts that
            belong to it, in the iteration order of ``accounts``.
        """
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]):
        """Synchronize every account that belongs to one exchange.

        Positions of all accounts are collected first, tagged with the
        owning ``k_id`` / ``st_id``, and then written in one database pass.

        NOTE(review): an account whose fetch returns nothing is skipped
        entirely, so rows it already has in the database are not pruned
        here. Confirm whether an empty result should clear them.

        Args:
            exchange_id: Identifier of the exchange being processed.
            account_list: Account-info dicts of that exchange.

        Returns:
            True when there was nothing to write or the database sync
            reported success, False otherwise (including on any exception).
        """
        try:
            all_positions = []

            for account_info in account_list:
                try:
                    k_id = int(account_info['k_id'])
                except (KeyError, TypeError, ValueError) as e:
                    # One malformed account must not abort the whole exchange.
                    logger.error(f"交易所 {exchange_id} 账号 k_id 无效,已跳过: {e}")
                    continue
                st_id = account_info.get('st_id', 0)

                positions = await self._get_positions_from_redis(k_id, exchange_id)
                if not positions:
                    continue

                # Tag every position with its owning account so the DB
                # layer can regroup them.
                for position in positions:
                    position['k_id'] = k_id
                    position['st_id'] = st_id
                all_positions.extend(positions)

            if not all_positions:
                logger.debug(f"交易所 {exchange_id} 无持仓数据")
                return True

            # NOTE(review): this is a synchronous call made from a coroutine,
            # so the event loop is blocked while the database work runs.
            # Consider offloading it to an executor if that is a problem.
            success = self._sync_positions_batch_to_db(all_positions)
            if success:
                logger.info(f"交易所 {exchange_id} 持仓同步成功: {len(all_positions)} 条持仓")
            return success

        except Exception as e:
            logger.exception(f"同步交易所 {exchange_id} 持仓失败: {e}")
            return False

    def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
        """Write a batch of positions to the database, account by account.

        All accounts share one outer transaction, but each account runs
        inside its own SAVEPOINT: if an account fails, only its statements
        are rolled back and the remaining accounts are still committed.

        Args:
            all_positions: Position dicts, each carrying ``k_id`` and
                ``st_id`` of the owning account.

        Returns:
            True if at least one account was synchronized, False otherwise
            (including when the batch as a whole failed).
        """
        session = self.db_manager.get_session()
        try:
            positions_by_account: Dict[int, List[Dict]] = {}
            for position in all_positions:
                positions_by_account.setdefault(position['k_id'], []).append(position)

            success_count = 0

            with session.begin():
                for k_id, positions in positions_by_account.items():
                    try:
                        # Every group holds at least one position.
                        st_id = positions[0]['st_id']

                        insert_data, keep_keys = self._prepare_position_rows(positions)
                        if not insert_data:
                            continue

                        # SAVEPOINT: keeps upsert + prune atomic per account.
                        with session.begin_nested():
                            self._write_account_positions(
                                session, k_id, st_id, insert_data, keep_keys
                            )
                        success_count += 1

                    except Exception as e:
                        logger.exception(f"同步账号 {k_id} 持仓失败: {e}")
                        continue

            logger.info(f"批量同步完成: 成功 {success_count}/{len(positions_by_account)} 个账号")
            return success_count > 0

        except Exception as e:
            logger.exception(f"批量同步持仓到数据库失败: {e}")
            return False
        finally:
            session.close()

    def _prepare_position_rows(self, positions: List[Dict]) -> Tuple[List[Dict], Set[Tuple]]:
        """Convert raw positions into insertable rows plus their identity keys.

        Each position goes through ``_convert_position_data`` (defined
        outside this view). Rows lacking ``symbol`` or ``side`` are dropped,
        ``qty`` is renamed to ``sum``, and a position that fails to convert
        is logged and skipped.

        Returns:
            ``(rows, keep_keys)`` where ``keep_keys`` holds the
            ``(symbol, side)`` pair of every returned row.
        """
        rows: List[Dict] = []
        keep_keys: Set[Tuple] = set()

        for pos_data in positions:
            try:
                pos_dict = self._convert_position_data(pos_data)
                if not (pos_dict.get('symbol') and pos_dict.get('side')):
                    continue

                # The table stores the quantity under the name ``sum``.
                if 'qty' in pos_dict:
                    pos_dict['sum'] = pos_dict.pop('qty')

                rows.append(pos_dict)
                keep_keys.add((pos_dict['symbol'], pos_dict['side']))
            except Exception as e:
                logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
                continue

        return rows, keep_keys

    def _write_account_positions(self, session, k_id, st_id, rows: List[Dict], keep_keys: Set[Tuple]) -> None:
        """Upsert one account's rows and delete its positions not in ``keep_keys``.

        Raises whatever the database layer raises; the caller owns the
        transaction / savepoint handling.
        """
        # Local imports, matching how this module already imported ``insert``;
        # they keep the module importable without touching SQLAlchemy.
        from sqlalchemy import and_, delete, select
        from sqlalchemy.dialects.mysql import insert

        stmt = insert(StrategyPosition.__table__).values(rows)
        stmt = stmt.on_duplicate_key_update(
            **{column: stmt.inserted[column] for column in self._UPSERT_UPDATE_COLUMNS}
        )
        session.execute(stmt)

        # Only the identifying columns are needed to find stale rows.
        existing_rows = session.execute(
            select(
                StrategyPosition.id,
                StrategyPosition.symbol,
                StrategyPosition.side,
            ).where(
                and_(
                    StrategyPosition.k_id == k_id,
                    StrategyPosition.st_id == st_id,
                )
            )
        ).all()

        to_delete_ids = [
            row_id
            for row_id, symbol, side in existing_rows
            if (symbol, side) not in keep_keys
        ]

        # Delete in chunks to keep each IN (...) list bounded.
        for start in range(0, len(to_delete_ids), self._DELETE_CHUNK_SIZE):
            chunk = to_delete_ids[start:start + self._DELETE_CHUNK_SIZE]
            session.execute(
                delete(StrategyPosition).where(StrategyPosition.id.in_(chunk))
            )

    # Other methods remain unchanged...