import asyncio
import json
import re
import signal
import sys
import time
from typing import Any, Dict, List, Optional, Set, Union

from loguru import logger

from utils.redis_client import RedisClient
from config.settings import SYNC_CONFIG, COMPUTER_NAMES, COMPUTER_NAME_PATTERN
from .position_sync import PositionSyncBatch
from .order_sync import OrderSyncBatch  # batch version
from .account_sync import AccountSyncBatch
from utils.redis_batch_helper import RedisBatchHelper


class SyncManager:
    """Sync manager (full batch version).

    Periodically loads account configurations from Redis hashes named
    ``<computer_name>_strategy_api`` and passes them to every enabled batch
    syncer (positions, orders, account info), which run concurrently.
    """

    # Suffix of the Redis hash keys that hold per-computer strategy API configs.
    STRATEGY_API_SUFFIX = '_strategy_api'

    # Maps raw exchange names found in Redis to normalized exchange IDs.
    # Hoisted to class level so it is built once instead of on every call.
    EXCHANGE_MAPPING: Dict[str, str] = {
        'metatrader': 'mt5',
        'binance_spot_test': 'binance',
        'binance_spot': 'binance',
        'binance': 'binance',
        'gate_spot': 'gate',
        'okex': 'okx',
        'okx': 'okx',
        'bybit': 'bybit',
        'bybit_spot': 'bybit',
        'bybit_test': 'bybit',
        'huobi': 'huobi',
        'huobi_spot': 'huobi',
        'gate': 'gate',
        'gateio': 'gate',
        'kucoin': 'kucoin',
        'kucoin_spot': 'kucoin',
        'mexc': 'mexc',
        'mexc_spot': 'mexc',
        'bitget': 'bitget',
        'bitget_spot': 'bitget',
    }

    def __init__(self):
        self.is_running = True
        self.redis_client = RedisClient()
        self.sync_interval = SYNC_CONFIG['interval']
        self.computer_names = self._get_computer_names()
        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)

        # Batch Redis helper; left unset here.
        self.redis_helper = None

        # Build the list of enabled syncers from configuration.
        self.syncers = []
        if SYNC_CONFIG['enable_position_sync']:
            self.syncers.append(PositionSyncBatch())
            logger.info("启用持仓批量同步")
        if SYNC_CONFIG['enable_order_sync']:
            self.syncers.append(OrderSyncBatch())
            logger.info("启用订单批量同步")
        if SYNC_CONFIG['enable_account_sync']:
            self.syncers.append(AccountSyncBatch())
            logger.info("启用账户信息批量同步")

        # Performance statistics.
        self.stats = {
            'total_syncs': 0,
            'total_accounts': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0,
            'position': {'accounts': 0, 'positions': 0, 'time': 0},
            'order': {'accounts': 0, 'orders': 0, 'time': 0},
            'account': {'accounts': 0, 'records': 0, 'time': 0},
        }

        # Graceful shutdown: the handlers only clear ``is_running``.
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)

    async def start(self):
        """Run the sync loop until ``is_running`` is cleared or the task is cancelled.

        Each cycle loads all accounts once, runs every syncer concurrently
        on them, updates statistics and then sleeps ``sync_interval`` seconds.
        Unexpected errors are logged and retried after 30 seconds.
        """
        logger.info(f"同步服务启动,间隔 {self.sync_interval} 秒")

        while self.is_running:
            try:
                # get_accounts_from_redis is synchronous and returns a dict;
                # awaiting it raised TypeError on every cycle.
                accounts = self.get_accounts_from_redis()
                if not accounts:
                    logger.warning("未获取到任何账号,等待下次同步")
                    await asyncio.sleep(self.sync_interval)
                    continue

                self.stats['total_syncs'] += 1
                sync_start = time.monotonic()
                logger.info(f"第{self.stats['total_syncs']}次同步开始,共 {len(accounts)} 个账号")

                # Run all syncers concurrently; a failing syncer must not
                # abort the others, but its error must not vanish either.
                results = await asyncio.gather(
                    *(syncer.sync(accounts) for syncer in self.syncers),
                    return_exceptions=True,
                )
                for syncer, result in zip(self.syncers, results):
                    if isinstance(result, BaseException):
                        logger.error(f"同步器 {type(syncer).__name__} 执行异常: {result!r}")

                sync_time = time.monotonic() - sync_start
                self._update_stats(sync_time)

                logger.info(f"同步完成,总耗时 {sync_time:.2f} 秒,等待 {self.sync_interval} 秒")
                await asyncio.sleep(self.sync_interval)

            except asyncio.CancelledError:
                logger.info("同步任务被取消")
                break
            except Exception as e:
                logger.error(f"同步任务异常: {e}")
                await asyncio.sleep(30)

    def get_accounts_from_redis(self) -> Dict[str, Dict]:
        """Load account configs for all configured computer names from Redis.

        Falls back to auto-discovery of ``*_strategy_api`` keys when the
        configured computer names yield no accounts.

        Returns:
            Mapping of account_id -> parsed account dict; ``{}`` on failure.
        """
        try:
            accounts_dict: Dict[str, Dict] = {}

            # Method 1: the configured list of computer names.
            for computer_name in self.computer_names:
                accounts_dict.update(self._get_accounts_by_computer_name(computer_name))

            # Method 2 (fallback): auto-discover keys if nothing was found.
            if not accounts_dict:
                logger.warning("配置的计算机名未找到数据,尝试自动发现...")
                accounts_dict = self._discover_all_accounts()

            # Was written to an undefined ``self.sync_stats``, which raised
            # AttributeError and made this method always return {}.
            self.stats['total_accounts'] = len(accounts_dict)
            logger.info(f"从 {len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
            return accounts_dict

        except Exception as e:
            logger.error(f"获取账户信息失败: {e}")
            return {}

    def _get_computer_names(self) -> List[str]:
        """Return the configured computer names.

        ``COMPUTER_NAMES`` may be a comma-separated string or a list/tuple.
        Names are stripped, and empty entries (e.g. from a trailing comma)
        are dropped.
        """
        if isinstance(COMPUTER_NAMES, (list, tuple)):
            raw_names = list(COMPUTER_NAMES)
        else:
            raw_names = str(COMPUTER_NAMES).split(',')

        names = [str(name).strip() for name in raw_names if str(name).strip()]
        if len(names) > 1:
            logger.info(f"使用配置的计算机名列表: {names}")
        return names

    def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
        """Load and parse all accounts stored under ``<computer_name>_strategy_api``.

        The hash maps exchange name -> JSON object of account_id -> account info.
        Each parsed account is tagged with ``computer_name``. Per-exchange
        errors are logged and skipped.
        """
        accounts_dict: Dict[str, Dict] = {}
        redis_key = f"{computer_name}{self.STRATEGY_API_SUFFIX}"
        try:
            result = self.redis_client.client.hgetall(redis_key)
            if not result:
                logger.debug(f"未找到 {redis_key} 的策略API配置")
                return {}

            logger.info(f"从 {redis_key} 获取到 {len(result)} 个交易所配置")

            for exchange_name, accounts_json in result.items():
                # Hash fields may come back as bytes; the exchange-name mapping needs str.
                exchange_name = self._to_str(exchange_name)
                try:
                    accounts = json.loads(accounts_json)
                    if not accounts:
                        continue

                    exchange_id = self.format_exchange_id(exchange_name)

                    for account_id, account_info in accounts.items():
                        parsed_account = self.parse_account(exchange_id, account_id, account_info)
                        if parsed_account:
                            # Tag the account with the computer it was loaded for.
                            parsed_account['computer_name'] = computer_name
                            accounts_dict[account_id] = parsed_account

                except json.JSONDecodeError as e:
                    logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
                    continue
                except Exception as e:
                    logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")
                    continue

            logger.info(f"从 {redis_key} 解析到 {len(accounts_dict)} 个账号")

        except Exception as e:
            logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")

        return accounts_dict

    def _discover_all_accounts(self) -> Dict[str, Dict]:
        """Discover every ``*_strategy_api`` key via SCAN and load its accounts.

        Only computer names matching ``COMPUTER_NAME_PATTERN`` are loaded.
        If the same account_id appears under several computers, the last one wins.
        """
        accounts_dict: Dict[str, Dict] = {}
        suffix = self.STRATEGY_API_SUFFIX
        try:
            pattern = f"*{suffix}"
            discovered_keys: List[str] = []
            cursor = 0
            while True:
                cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
                discovered_keys.extend(self._to_str(key) for key in keys)
                if int(cursor) == 0:
                    break

            # SCAN may return the same key more than once; de-duplicate, keep order.
            discovered_keys = list(dict.fromkeys(discovered_keys))
            logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")

            for key_str in discovered_keys:
                if not key_str.endswith(suffix):
                    continue
                # Strip only the trailing suffix (str.replace removed every occurrence).
                computer_name = key_str[:-len(suffix)]

                if self.computer_name_pattern.match(computer_name):
                    accounts_dict.update(self._get_accounts_by_computer_name(computer_name))
                else:
                    logger.warning(f"跳过不符合格式的计算机名: {computer_name}")

            logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")

        except Exception as e:
            logger.error(f"自动发现账号失败: {e}")

        return accounts_dict

    def format_exchange_id(self, key: str) -> str:
        """Normalize a raw exchange name (case/whitespace-insensitive) to an exchange ID.

        Unknown names are returned lower-cased and stripped, and logged at debug level.
        """
        key = self._to_str(key).lower().strip()
        mapping = self.EXCHANGE_MAPPING
        normalized_key = mapping.get(key, key)

        # Log names that are neither a known alias nor already a normalized ID.
        if key not in mapping and key not in mapping.values():
            logger.debug(f"未映射的交易所名称: {key}")

        return normalized_key

    def parse_account(
        self,
        exchange_id: str,
        account_id: str,
        account_info: Union[str, bytes, Dict[str, Any]],
    ) -> Optional[Dict]:
        """Build a normalized account dict from one raw account entry.

        Args:
            exchange_id: Normalized exchange ID.
            account_id: Account key; stored as ``k_id``.
            account_info: The account's config, either as a JSON string/bytes
                or an already-decoded dict.

        Returns:
            The raw fields merged with normalized fields (normalized fields
            take precedence), or ``None`` if the entry is invalid or lacks
            ``st_id``/``exchange_id``.
        """
        try:
            if isinstance(account_info, dict):
                source_account_info = account_info
            else:
                source_account_info = json.loads(account_info)

            if not isinstance(source_account_info, dict):
                logger.warning(f"账号 {account_id} 数据不是JSON对象: {type(source_account_info).__name__}")
                return None

            # Normalized base fields.
            account_data = {
                'exchange_id': exchange_id,
                'k_id': account_id,
                'st_id': self._safe_int(source_account_info.get('st_id'), 0),
                'add_time': self._safe_int(source_account_info.get('add_time'), 0),
                'account_type': source_account_info.get('account_type', 'real'),
                'api_key': source_account_info.get('api_key', ''),
                'secret_key': source_account_info.get('secret_key', ''),
                'password': source_account_info.get('password', ''),
                'access_token': source_account_info.get('access_token', ''),
                'remark': source_account_info.get('remark', ''),
            }

            # Merge with the raw fields; normalized values override raw ones.
            result = {**source_account_info, **account_data}

            # Validate required fields (st_id == 0 is treated as missing).
            if not result.get('st_id') or not result.get('exchange_id'):
                logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
                return None

            return result

        except json.JSONDecodeError as e:
            logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {str(account_info)[:100]}...")
            return None
        except Exception as e:
            logger.error(f"处理账号 {account_id} 数据异常: {e}")
            return None

    @staticmethod
    def _safe_int(value: Any, default: int = 0) -> int:
        """Convert ``value`` to int, returning ``default`` when that is impossible.

        Accepts ints, floats and numeric strings such as ``"12"`` or ``"12.0"``.
        Previously referenced by ``parse_account`` but not defined in this class.
        """
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            try:
                return int(float(value))
            except (TypeError, ValueError, OverflowError):
                return default

    @staticmethod
    def _to_str(value: Any) -> Any:
        """Decode ``bytes`` (as Redis may return) to ``str``; return other values unchanged."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account dicts by their ``exchange_id``; accounts without one are skipped."""
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    def _update_stats(self, sync_time: float):
        """Record the duration of the last sync and log a statistics summary.

        ``avg_sync_time`` is an exponential moving average (weight 0.1). The
        first sample seeds it directly so the average is not biased toward 0.
        """
        self.stats['last_sync_time'] = sync_time
        previous_avg = self.stats['avg_sync_time']
        if previous_avg <= 0:
            self.stats['avg_sync_time'] = sync_time
        else:
            self.stats['avg_sync_time'] = previous_avg * 0.9 + sync_time * 0.1

        stats_lines = [
            f"=== 第{self.stats['total_syncs']}次同步统计 ===",
            f"总耗时: {sync_time:.2f}秒 | 平均耗时: {self.stats['avg_sync_time']:.2f}秒",
        ]

        # (stats key, display label, item-count field) for each syncer category.
        sections = (
            ('position', '持仓', 'positions'),
            ('order', '订单', 'orders'),
            ('account', '账户', 'records'),
        )
        for stats_key, label, count_key in sections:
            section = self.stats[stats_key]
            if section['accounts'] > 0:
                stats_lines.append(
                    f"{label}: {section['accounts']}账号/{section[count_key]}条"
                    f"/{section['time']:.2f}秒"
                )

        logger.info("\n".join(stats_lines))

    def signal_handler(self, signum, frame):
        """Handle SIGINT/SIGTERM by asking the sync loop to stop after its current cycle."""
        logger.info(f"接收到信号 {signum},正在关闭...")
        self.is_running = False

    async def stop(self):
        """Stop the sync loop and close every syncer's ``db_manager``.

        A failure while closing one syncer is logged and does not prevent the
        others from being closed.
        """
        self.is_running = False

        for syncer in self.syncers:
            db_manager = getattr(syncer, 'db_manager', None)
            if db_manager is None:
                continue
            try:
                closing = db_manager.close()
                # NOTE(review): close() may be async depending on the db manager
                # implementation; await it if it returned a coroutine.
                if asyncio.iscoroutine(closing):
                    await closing
            except Exception as e:
                logger.error(f"关闭 {type(syncer).__name__} 数据库连接失败: {e}")

        logger.info("同步服务停止")