from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Tuple
import json
import asyncio
import time
from datetime import datetime, timedelta
from sqlalchemy import text
import redis


def _safe_float(value, default=None):
    """Convert *value* to float; return *default* for None or unparsable input."""
    if value is None:
        return default
    try:
        return float(value)
    except (ValueError, TypeError):
        return default


def _safe_int(value, default=None):
    """Convert *value* to int (via float, so "12.0" works); else return *default*."""
    if value is None:
        return default
    try:
        return int(float(value))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float('inf'))
        return default


def _safe_str(value):
    """Convert *value* to str, mapping None to the empty string."""
    if value is None:
        return ''
    return str(value)


class OrderSyncBatch(BaseSync):
    """Batch synchronizer for order data.

    Reads each account's recent orders from a Redis hash and writes them to
    the database in batches through ``utils.batch_order_sync.BatchOrderSync``.
    ``self.redis_client``, ``self.db_manager`` and ``get_accounts_from_redis``
    are presumably provided by ``BaseSync`` (not visible here).
    """

    def __init__(self):
        super().__init__()
        self.batch_size = 1000  # rows per DB batch, passed to BatchOrderSync
        self.recent_days = 3  # number of days (today included) to sync

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Collect recent orders for all accounts and sync them to the database.

        Args:
            accounts: mapping of account key -> account info dict. Each info
                dict is read for 'exchange_id', 'k_id' and optionally 'st_id'.

        Errors are logged, never raised.
        """
        try:
            logger.info(f"开始批量同步订单数据,共 {len(accounts)} 个账号")
            start_time = time.time()

            # 1. Gather the orders of every account.
            all_orders = await self._collect_all_orders(accounts)
            if not all_orders:
                logger.info("无订单数据需要同步")
                return

            logger.info(f"收集到 {len(all_orders)} 条订单数据")

            # 2. Write them to the database in batches.
            success, processed_count = await self._sync_orders_batch_to_db(all_orders)

            elapsed = time.time() - start_time
            if success:
                logger.info(f"订单批量同步完成: 处理 {processed_count} 条订单,耗时 {elapsed:.2f}秒")
            else:
                logger.error("订单批量同步失败")
        except Exception as e:
            # Top-level boundary: keep the traceback for diagnosis.
            logger.exception(f"订单批量同步失败: {e}")

    async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect orders for all accounts, one concurrent task per exchange.

        Returns the concatenated order list; failures are logged and the
        affected exchange contributes nothing.
        """
        all_orders: List[Dict] = []
        try:
            account_groups = self._group_accounts_by_exchange(accounts)

            tasks = [
                self._collect_exchange_orders(exchange_id, account_list)
                for exchange_id, account_list in account_groups.items()
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # gather preserves task order, which matches the dict's key order.
            for exchange_id, result in zip(account_groups, results):
                if isinstance(result, list):
                    all_orders.extend(result)
                elif isinstance(result, BaseException):
                    logger.error(f"收集交易所 {exchange_id} 订单数据失败: {result}")
        except Exception as e:
            logger.error(f"收集订单数据失败: {e}")

        return all_orders

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account info dicts by their 'exchange_id'.

        Accounts with a missing or falsy 'exchange_id' are dropped.
        """
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect the recent orders of every account on one exchange.

        Accounts whose 'k_id' is missing or not integer-like are skipped (and
        logged) instead of aborting the whole exchange.
        """
        orders_list: List[Dict] = []
        try:
            tasks = []
            for account_info in account_list:
                try:
                    k_id = int(account_info['k_id'])
                except (KeyError, TypeError, ValueError) as e:
                    logger.error(f"交易所 {exchange_id} 账号信息无效: {account_info}, error={e}")
                    continue
                st_id = account_info.get('st_id', 0)
                tasks.append(self._get_recent_orders_from_redis(k_id, st_id, exchange_id))

            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    orders_list.extend(result)
                elif isinstance(result, BaseException):
                    logger.error(f"收集交易所 {exchange_id} 订单数据失败: {result}")

            logger.debug(f"交易所 {exchange_id}: 收集到 {len(orders_list)} 条订单")
        except Exception as e:
            logger.error(f"收集交易所 {exchange_id} 订单数据失败: {e}")

        return orders_list

    async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Read one account's orders from the last ``recent_days`` days out of Redis.

        Orders are read from the hash ``{exchange_id}:orders:{k_id}``. Only
        fields starting with a ``YYYY-MM-DD_`` prefix for one of the recent
        (local-time) dates are considered, and the field ``positions`` is
        skipped. Each value is parsed as JSON and kept only if its ``time``
        (assumed epoch seconds — confirm with the writer side) is within the
        window.

        Returns:
            The matching orders, each tagged with ``k_id``, ``st_id`` and
            ``exchange_id``; [] if the hash has no matching fields or on any
            Redis error.

        NOTE(review): the Redis client calls are synchronous, so they block the
        event loop despite being awaited via asyncio.gather.
        """
        try:
            redis_key = f"{exchange_id}:orders:{k_id}"
            client = self.redis_client.client

            # Field prefixes for today and the previous (recent_days - 1) days.
            today = datetime.now()
            date_prefixes = tuple(
                (today - timedelta(days=i)).strftime('%Y-%m-%d') + '_'
                for i in range(self.recent_days)
            )

            # HSCAN may return the same field more than once, so collect into an
            # insertion-ordered dict to de-duplicate.
            recent_keys: Dict[str, None] = {}
            cursor = 0
            while True:
                cursor, fields = client.hscan(redis_key, cursor, count=1000)
                for field in fields:
                    key_str = field.decode('utf-8') if isinstance(field, bytes) else field
                    if key_str == 'positions':
                        continue
                    if key_str.startswith(date_prefixes):
                        recent_keys[key_str] = None
                if cursor == 0:
                    break

            if not recent_keys:
                return []

            key_list = list(recent_keys)
            cutoff = int(time.time()) - self.recent_days * 24 * 3600
            orders_list: List[Dict] = []

            # Fetch in chunks so a single HMGET reply stays bounded.
            chunk_size = 500
            for start in range(0, len(key_list), chunk_size):
                chunk_keys = key_list[start:start + chunk_size]
                chunk_values = client.hmget(redis_key, chunk_keys)

                for key, order_json in zip(chunk_keys, chunk_values):
                    if not order_json:
                        continue
                    try:
                        order = json.loads(order_json)
                        is_recent = order.get('time', 0) >= cutoff
                    except json.JSONDecodeError as e:
                        logger.debug(f"解析订单JSON失败: key={key}, error={e}")
                        continue
                    except (AttributeError, TypeError) as e:
                        # Value is not a JSON object, or 'time' is not comparable
                        # to an int: skip only this order, not the whole account.
                        logger.debug(f"订单数据格式无效: key={key}, error={e}")
                        continue

                    if not is_recent:
                        continue

                    # Tag the order with its account identity.
                    order['k_id'] = k_id
                    order['st_id'] = st_id
                    order['exchange_id'] = exchange_id
                    orders_list.append(order)

            return orders_list

        except Exception as e:
            logger.error(f"获取Redis订单数据失败: k_id={k_id}, error={e}")
            return []

    async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]:
        """Convert orders to DB rows and write them with BatchOrderSync.

        Orders missing (or with a falsy) 'order_id', 'symbol', 'side' or
        'time' after conversion are dropped.

        Returns:
            ``(success, processed_count)`` as reported by
            ``BatchOrderSync.sync_orders_batch``; ``(True, 0)`` when there is
            nothing to write; ``(False, 0)`` on an unexpected error.
        """
        try:
            if not all_orders:
                return True, 0

            required_fields = ('order_id', 'symbol', 'side', 'time')
            converted_orders = []
            for order in all_orders:
                try:
                    order_dict = self._convert_order_data(order)
                except Exception as e:
                    logger.error(f"转换订单数据失败: {order}, error={e}")
                    continue
                if all(order_dict.get(field) for field in required_fields):
                    converted_orders.append(order_dict)

            skipped = len(all_orders) - len(converted_orders)
            if skipped:
                logger.debug(f"跳过 {skipped} 条不完整的订单")

            if not converted_orders:
                return True, 0

            # Imported lazily — presumably to avoid an import cycle; confirm.
            from utils.batch_order_sync import BatchOrderSync
            batch_tool = BatchOrderSync(self.db_manager, self.batch_size)
            # NOTE(review): synchronous call; blocks the event loop while it runs.
            success, processed_count = batch_tool.sync_orders_batch(converted_orders)
            return success, processed_count

        except Exception as e:
            logger.error(f"批量同步订单到数据库失败: {e}")
            return False, 0

    def _convert_order_data(self, data: Dict) -> Dict:
        """Map a raw order dict (as read from Redis) onto the DB row shape.

        Numeric fields that are missing or unparsable become None, except
        'st_id' and 'k_id', which fall back to 0. Text fields become '' when
        missing. Returns {} if *data* cannot be read at all.
        """
        try:
            return {
                'st_id': _safe_int(data.get('st_id'), 0),
                'k_id': _safe_int(data.get('k_id'), 0),
                'asset': 'USDT',
                'order_id': _safe_str(data.get('order_id')),
                'symbol': _safe_str(data.get('symbol')),
                'side': _safe_str(data.get('side')),
                'price': _safe_float(data.get('price')),
                'time': _safe_int(data.get('time')),
                'order_qty': _safe_float(data.get('order_qty')),
                'last_qty': _safe_float(data.get('last_qty')),
                'avg_price': _safe_float(data.get('avg_price')),
                'exchange_id': None,  # intentionally not persisted
            }
        except Exception as e:
            logger.error(f"转换订单数据异常: {data}, error={e}")
            return {}

    async def sync(self):
        """Backward-compatible entry point: sync all accounts known to Redis."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)