from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
import math
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import time


class PositionSyncBatch(BaseSync):
    """Batch synchronizer that copies position data from Redis into the database.

    For every account, positions are read from the Redis hash
    ``{exchange_id}:positions:{k_id}`` (field ``positions``, a JSON list),
    upserted into ``deh_strategy_position_new``, and DB rows of that
    ``(k_id, st_id)`` that are no longer present in Redis are deleted.

    ``self.redis_client``, ``self.db_manager`` and ``get_accounts_from_redis``
    are used here but defined elsewhere (presumably in ``BaseSync``).
    """

    # Parameterized upsert: values are bound by the DB driver rather than
    # interpolated into the SQL text, so exchange-supplied strings (symbol,
    # side, asset) cannot break or inject into the statement. `sum` is a
    # reserved word in MySQL, hence the backticks.
    _UPSERT_SQL = text(
        """
        INSERT INTO deh_strategy_position_new
            (st_id, k_id, asset, symbol, side, price, `sum`, asset_num,
             asset_profit, leverage, uptime, profit_price, stop_price,
             liquidation_price)
        VALUES
            (:st_id, :k_id, :asset, :symbol, :side, :price, :sum, :asset_num,
             :asset_profit, :leverage, :uptime, :profit_price, :stop_price,
             :liquidation_price)
        ON DUPLICATE KEY UPDATE
            price = VALUES(price),
            `sum` = VALUES(`sum`),
            asset_num = VALUES(asset_num),
            asset_profit = VALUES(asset_profit),
            leverage = VALUES(leverage),
            uptime = VALUES(uptime),
            profit_price = VALUES(profit_price),
            stop_price = VALUES(stop_price),
            liquidation_price = VALUES(liquidation_price)
        """
    )

    def __init__(self):
        super().__init__()
        self.batch_size = 500  # rows per executemany chunk of the upsert

    @staticmethod
    def _empty_stats() -> Dict[str, int]:
        """Return a fresh all-zero statistics dict."""
        return {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Synchronize the positions of all given accounts in one pass.

        Args:
            accounts: mapping of account id -> account info dict. Info dicts
                without an ``exchange_id`` are ignored; ``k_id`` is required
                and ``st_id`` defaults to 0.

        Errors are logged, never raised.

        NOTE(review): accounts whose Redis data is missing or an empty list
        contribute no positions, so their existing DB rows are never deleted
        by this method — confirm that this is intended.
        """
        try:
            logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
            start_time = time.time()

            # 1. Collect positions of every account from Redis.
            all_positions = await self._collect_all_positions(accounts)
            if not all_positions:
                logger.info("无持仓数据需要同步")
                return

            logger.info(f"收集到 {len(all_positions)} 条持仓数据")

            # 2. Write them to the database.
            success, stats = await self._sync_positions_batch_to_db(all_positions)

            elapsed = time.time() - start_time
            if success:
                logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,更新 {stats['updated']} 条,"
                            f"插入 {stats['inserted']} 条,删除 {stats['deleted']} 条,耗时 {elapsed:.2f}秒")
            else:
                logger.error("持仓批量同步失败")

        except Exception as e:
            logger.error(f"持仓批量同步失败: {e}")

    async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect positions of all accounts, fanning out one task per exchange.

        Returns the concatenated position dicts; failures are logged and the
        affected exchange contributes nothing.
        """
        all_positions = []
        try:
            account_groups = self._group_accounts_by_exchange(accounts)

            tasks = [
                self._collect_exchange_positions(exchange_id, account_list)
                for exchange_id, account_list in account_groups.items()
            ]

            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                # With return_exceptions=True failures come back as values;
                # log them instead of silently dropping them.
                if isinstance(result, BaseException):
                    logger.error(f"收集持仓数据失败: {result}")
                elif isinstance(result, list):
                    all_positions.extend(result)

        except Exception as e:
            logger.error(f"收集持仓数据失败: {e}")

        return all_positions

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account info dicts by their ``exchange_id``; accounts without one are skipped."""
        groups = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect positions of every account of one exchange.

        An account with a missing or non-numeric ``k_id`` is logged and
        skipped instead of aborting the whole exchange.
        """
        positions_list = []
        try:
            tasks = []
            for account_info in account_list:
                try:
                    k_id = int(account_info['k_id'])
                except (KeyError, TypeError, ValueError) as e:
                    logger.error(f"账号k_id无效, 已跳过: {account_info}, error={e}")
                    continue
                st_id = account_info.get('st_id', 0)
                tasks.append(self._get_positions_from_redis(k_id, st_id, exchange_id))

            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, BaseException):
                    logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {result}")
                elif isinstance(result, list):
                    positions_list.extend(result)

        except Exception as e:
            logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}")

        return positions_list

    async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Read one account's positions from Redis and tag them with account info.

        Returns a list of position dicts, each with ``k_id``, ``st_id`` and
        ``exchange_id`` added. Non-dict elements are skipped; a missing key,
        invalid JSON or a non-list payload yields ``[]``.
        """
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            # NOTE(review): hget is called without await, i.e. synchronously; it
            # blocks the event loop while it runs, so the gather() fan-out does
            # not overlap these reads. Consider an async client if available.
            redis_data = self.redis_client.client.hget(redis_key, 'positions')
            if not redis_data:
                return []

            positions = json.loads(redis_data)
            if not isinstance(positions, list):
                logger.error(f"Redis持仓数据格式错误: k_id={k_id}, data={positions}")
                return []

            tagged = []
            for position in positions:
                if not isinstance(position, dict):
                    continue
                position['k_id'] = k_id
                position['st_id'] = st_id
                position['exchange_id'] = exchange_id
                tagged.append(position)
            return tagged

        except Exception as e:
            logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
            return []

    async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """Write all collected positions to the DB, one transaction per account.

        Positions are grouped by ``(k_id, st_id)`` because the stale-row
        delete is scoped by both columns; grouping by ``k_id`` alone would
        leave rows of any additional ``st_id`` uncleaned.

        Returns:
            ``(success, stats)`` where stats sums the per-account counters of
            accounts that synced successfully.
        """
        try:
            if not all_positions:
                return True, self._empty_stats()

            positions_by_account = {}
            for position in all_positions:
                account_key = (position['k_id'], position.get('st_id', 0))
                positions_by_account.setdefault(account_key, []).append(position)

            logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据")

            total_stats = self._empty_stats()
            for (k_id, st_id), positions in positions_by_account.items():
                success, stats = await self._sync_single_account_batch(k_id, st_id, positions)
                if success:
                    for name in total_stats:
                        total_stats[name] += stats[name]

            return True, total_stats

        except Exception as e:
            logger.error(f"批量同步持仓到数据库失败: {e}")
            return False, self._empty_stats()

    async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
        """Upsert one account's positions and delete its stale rows atomically.

        Positions lacking ``symbol`` or ``side`` after conversion are skipped.
        If none remain, all rows of ``(k_id, st_id)`` are deleted.

        Returns:
            ``(True, stats)`` on success; ``(False, zero stats)`` if any DB
            step failed (the transaction is rolled back in that case).
        """
        session = self.db_manager.get_session()
        try:
            insert_data = []
            # (symbol, side) -> unused value; only the keys are read when
            # deciding which rows to keep.
            new_positions_map = {}

            for position_data in positions:
                try:
                    position_dict = self._convert_position_data(position_data)
                    if not all([position_dict.get('symbol'), position_dict.get('side')]):
                        continue

                    # The DB column is named `sum`, the Redis field `qty`.
                    if 'qty' in position_dict:
                        position_dict['sum'] = position_dict.pop('qty')

                    insert_data.append(position_dict)
                    new_positions_map[(position_dict['symbol'], position_dict['side'])] = None

                except Exception as e:
                    logger.error(f"转换持仓数据失败: {position_data}, error={e}")
                    continue

            # One transaction: an error in the upsert or the delete rolls both back.
            with session.begin():
                if not insert_data:
                    # No usable positions: clear everything for this account.
                    result = session.execute(
                        delete(StrategyPosition).where(
                            and_(
                                StrategyPosition.k_id == k_id,
                                StrategyPosition.st_id == st_id
                            )
                        )
                    )
                    return True, {
                        'total': 0,
                        'updated': 0,
                        'inserted': 0,
                        'deleted': result.rowcount
                    }

                processed_count = self._batch_upsert_positions(session, insert_data)
                deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map)

            # INSERT ... ON DUPLICATE KEY UPDATE does not tell inserts and
            # updates apart here, so every processed row is reported as inserted.
            return True, {
                'total': len(insert_data),
                'updated': 0,
                'inserted': processed_count,
                'deleted': deleted_count
            }

        except Exception as e:
            logger.error(f"批量同步账号 {k_id} 持仓失败: {e}")
            return False, self._empty_stats()
        finally:
            session.close()

    @staticmethod
    def _build_upsert_row(data: Dict) -> Dict[str, Any]:
        """Map a converted position dict to the bind parameters of ``_UPSERT_SQL``.

        ``None`` values are bound as SQL NULL; zero stays zero.
        """
        return {
            'st_id': data['st_id'],
            'k_id': data['k_id'],
            'asset': data.get('asset') or 'USDT',
            'symbol': data['symbol'],
            'side': data['side'],
            'price': data.get('price'),
            'sum': data.get('sum'),
            'asset_num': data.get('asset_num'),
            'asset_profit': data.get('asset_profit'),
            'leverage': data.get('leverage'),
            'uptime': data.get('uptime'),
            'profit_price': data.get('profit_price'),
            'stop_price': data.get('stop_price'),
            'liquidation_price': data.get('liquidation_price'),
        }

    def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int:
        """Upsert position rows in chunks of ``self.batch_size`` via executemany.

        Returns:
            The number of rows sent to the database.

        Raises:
            Any database error, after logging it.
        """
        try:
            chunk_size = max(1, self.batch_size)
            total_processed = 0

            for start in range(0, len(insert_data), chunk_size):
                rows = [self._build_upsert_row(data) for data in insert_data[start:start + chunk_size]]
                if rows:
                    # A list of parameter dicts makes SQLAlchemy use executemany.
                    session.execute(self._UPSERT_SQL, rows)
                    total_processed += len(rows)

            return total_processed

        except Exception as e:
            logger.error(f"批量插入/更新持仓失败: {e}")
            raise

    def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int:
        """Delete rows of ``(k_id, st_id)`` whose (symbol, side) is not in ``new_positions_map``.

        Only the keys of ``new_positions_map`` are used. If it is empty, every
        row of the account is deleted.

        Returns:
            The number of deleted rows.

        Raises:
            Any database error, after logging it, so that the caller's
            transaction is rolled back instead of committing a partial sync.
        """
        try:
            if not new_positions_map:
                result = session.execute(
                    delete(StrategyPosition).where(
                        and_(
                            StrategyPosition.k_id == k_id,
                            StrategyPosition.st_id == st_id
                        )
                    )
                )
                return result.rowcount

            # Build the "keep" condition with bound parameters only.
            params = {'k_id': k_id, 'st_id': st_id}
            conditions = []
            for index, (symbol, side) in enumerate(new_positions_map):
                conditions.append(f"(symbol = :symbol_{index} AND side = :side_{index})")
                params[f'symbol_{index}'] = symbol or ''
                params[f'side_{index}'] = side or ''

            conditions_str = " OR ".join(conditions)
            sql = f"""
                DELETE FROM deh_strategy_position_new
                WHERE k_id = :k_id AND st_id = :st_id
                AND NOT ({conditions_str})
            """
            result = session.execute(text(sql), params)
            return result.rowcount

        except Exception as e:
            logger.error(f"批量删除持仓失败: k_id={k_id}, error={e}")
            raise

    def _convert_position_data(self, data: Dict) -> Dict:
        """Normalize one raw Redis position dict into DB column types.

        Numeric fields that are missing, unparsable or non-finite become
        ``None``; ``qty`` is kept under that name and renamed to ``sum`` by the
        caller. Returns ``{}`` if ``data`` cannot be read at all.
        """
        try:
            def safe_float(value, default=None):
                if value is None:
                    return default
                try:
                    number = float(value)
                except (ValueError, TypeError, OverflowError):
                    return default
                # NaN / +-inf cannot be stored in MySQL numeric columns.
                return number if math.isfinite(number) else default

            def safe_int(value, default=None):
                if value is None:
                    return default
                try:
                    return int(float(value))
                except (ValueError, TypeError, OverflowError):
                    # OverflowError: int(float('inf')).
                    return default

            return {
                'st_id': safe_int(data.get('st_id'), 0),
                'k_id': safe_int(data.get('k_id'), 0),
                'asset': data.get('asset', 'USDT'),
                'symbol': data.get('symbol', ''),
                'side': data.get('side', ''),
                'price': safe_float(data.get('price')),
                'qty': safe_float(data.get('qty')),  # renamed to `sum` by the caller
                'asset_num': safe_float(data.get('asset_num')),
                'asset_profit': safe_float(data.get('asset_profit')),
                'leverage': safe_int(data.get('leverage')),
                'uptime': safe_int(data.get('uptime')),
                'profit_price': safe_float(data.get('profit_price')),
                'stop_price': safe_float(data.get('stop_price')),
                'liquidation_price': safe_float(data.get('liquidation_price'))
            }

        except Exception as e:
            logger.error(f"转换持仓数据异常: {data}, error={e}")
            return {}

    async def sync(self):
        """Legacy entry point: load accounts from Redis and run ``sync_batch`` on them."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)