This commit is contained in:
lz_db
2025-12-03 14:40:14 +08:00
parent c8a6cfead1
commit 803d40b88e
10 changed files with 2408 additions and 97 deletions

View File

@@ -1,99 +1,97 @@
import asyncio
from loguru import logger
from typing import List, Dict, Optional
import signal
import sys
from concurrent.futures import ThreadPoolExecutor
import time
from asyncio import Semaphore
from typing import Dict
from config.settings import SYNC_CONFIG
from .position_sync import PositionSync
from .order_sync import OrderSync
from .account_sync import AccountSync
from .position_sync_batch import PositionSyncBatch
from .order_sync_batch import OrderSyncBatch # 使用批量版本
from .account_sync_batch import AccountSyncBatch
from utils.batch_position_sync import BatchPositionSync
from utils.batch_order_sync import BatchOrderSync
from utils.batch_account_sync import BatchAccountSync
from utils.redis_batch_helper import RedisBatchHelper
class SyncManager:
"""同步管理器(支持批量并发处理"""
"""同步管理器(完整批量版本"""
def __init__(self):
self.is_running = True
self.sync_interval = SYNC_CONFIG['interval']
self.max_concurrent = int(os.getenv('MAX_CONCURRENT', '10')) # 最大并发数
# 初始化批量同步工具
self.batch_tools = {}
self.redis_helper = None
# 初始化同步器
self.syncers = []
self.executor = ThreadPoolExecutor(max_workers=self.max_concurrent)
self.semaphore = Semaphore(self.max_concurrent) # 控制并发数
if SYNC_CONFIG['enable_position_sync']:
self.syncers.append(PositionSync())
logger.info("启用持仓同步")
position_sync = PositionSyncBatch()
self.syncers.append(position_sync)
self.batch_tools['position'] = BatchPositionSync(position_sync.db_manager)
logger.info("启用持仓批量同步")
if SYNC_CONFIG['enable_order_sync']:
self.syncers.append(OrderSync())
logger.info("启用订单同步")
order_sync = OrderSyncBatch()
self.syncers.append(order_sync)
self.batch_tools['order'] = BatchOrderSync(order_sync.db_manager)
# 初始化Redis批量助手
if order_sync.redis_client:
self.redis_helper = RedisBatchHelper(order_sync.redis_client.client)
logger.info("启用订单批量同步")
if SYNC_CONFIG['enable_account_sync']:
self.syncers.append(AccountSync())
logger.info("启用账户信息同步")
account_sync = AccountSyncBatch()
self.syncers.append(account_sync)
self.batch_tools['account'] = BatchAccountSync(account_sync.db_manager)
logger.info("启用账户信息批量同步")
# 性能统计
self.stats = {
'total_accounts': 0,
'success_count': 0,
'error_count': 0,
'total_syncs': 0,
'last_sync_time': 0,
'avg_sync_time': 0
'avg_sync_time': 0,
'position': {'accounts': 0, 'positions': 0, 'time': 0},
'order': {'accounts': 0, 'orders': 0, 'time': 0},
'account': {'accounts': 0, 'records': 0, 'time': 0}
}
# 注册信号处理器
signal.signal(signal.SIGINT, self.signal_handler)
signal.signal(signal.SIGTERM, self.signal_handler)
async def _run_syncer_with_limit(self, syncer):
"""带并发限制的运行"""
async with self.semaphore:
return await self._run_syncer(syncer)
def signal_handler(self, signum, frame):
"""信号处理器"""
logger.info(f"接收到信号 {signum},正在关闭...")
self.is_running = False
def batch_process_accounts(self, accounts: Dict[str, Dict], batch_size: int = 100):
"""分批处理账号"""
account_items = list(accounts.items())
for i in range(0, len(account_items), batch_size):
batch = dict(account_items[i:i + batch_size])
# 处理这批账号
self._process_account_batch(batch)
# 批次间休息,避免数据库压力过大
time.sleep(0.1)
async def start(self):
"""启动同步服务"""
logger.info(f"同步服务启动,间隔 {self.sync_interval},最大并发 {self.max_concurrent}")
logger.info(f"同步服务启动,间隔 {self.sync_interval}")
while self.is_running:
try:
start_time = time.time()
self.stats['total_syncs'] += 1
sync_start = time.time()
# 执行所有同步器
tasks = [self._run_syncer(syncer) for syncer in self.syncers]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 获取所有账号(只获取一次)
accounts = await self._get_all_accounts()
if not accounts:
logger.warning("未获取到任何账号,等待下次同步")
await asyncio.sleep(self.sync_interval)
continue
logger.info(f"{self.stats['total_syncs']}次同步开始,共 {len(accounts)} 个账号")
# 并发执行所有同步
await self._execute_all_syncers_concurrent(accounts)
# 更新统计
sync_time = time.time() - start_time
self.stats['last_sync_time'] = sync_time
self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1)
sync_time = time.time() - sync_start
self._update_stats(sync_time)
# 打印统计信息
self._print_stats()
logger.debug(f"同步完成,耗时 {sync_time:.2f} 秒,等待 {self.sync_interval}")
logger.info(f"同步完成,总耗时 {sync_time:.2f} 秒,等待 {self.sync_interval}")
await asyncio.sleep(self.sync_interval)
except asyncio.CancelledError:
@@ -101,41 +99,182 @@ class SyncManager:
break
except Exception as e:
logger.error(f"同步任务异常: {e}")
self.stats['error_count'] += 1
await asyncio.sleep(30) # 出错后等待30秒
await asyncio.sleep(30)
async def _run_syncer(self, syncer):
"""运行单个同步器"""
try:
# 获取所有账号
accounts = syncer.get_accounts_from_redis()
self.stats['total_accounts'] = len(accounts)
async def _get_all_accounts(self) -> Dict[str, Dict]:
"""获取所有账号"""
if not self.syncers:
return {}
# 使用第一个同步器获取账号
return self.syncers[0].get_accounts_from_redis()
async def _execute_all_syncers_concurrent(self, accounts: Dict[str, Dict]):
"""并发执行所有同步器"""
tasks = []
# 持仓批量同步
if 'position' in self.batch_tools:
task = self._sync_positions_batch(accounts)
tasks.append(task)
# 订单批量同步
if 'order' in self.batch_tools:
task = self._sync_orders_batch(accounts)
tasks.append(task)
# 账户信息批量同步
if 'account' in self.batch_tools:
task = self._sync_accounts_batch(accounts)
tasks.append(task)
# 并发执行所有任务
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
if not accounts:
logger.warning("未获取到任何账号")
# 检查结果
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"同步任务 {i} 失败: {result}")
async def _sync_positions_batch(self, accounts: Dict[str, Dict]):
"""批量同步持仓数据"""
try:
start_time = time.time()
# 收集所有持仓数据
position_sync = next((s for s in self.syncers if isinstance(s, PositionSyncBatch)), None)
if not position_sync:
return
# 批量处理账号
await syncer.sync_batch(accounts)
self.stats['success_count'] += 1
all_positions = await position_sync._collect_all_positions(accounts)
if not all_positions:
self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0}
return
# 使用批量工具同步
batch_tool = self.batch_tools['position']
success, stats = batch_tool.sync_positions_batch(all_positions)
if success:
elapsed = time.time() - start_time
self.stats['position'] = {
'accounts': len(accounts),
'positions': stats['total'],
'time': elapsed
}
except Exception as e:
logger.error(f"批量同步持仓失败: {e}")
self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0}
async def _sync_orders_batch(self, accounts: Dict[str, Dict]):
"""批量同步订单数据"""
try:
start_time = time.time()
# 收集所有订单数据
order_sync = next((s for s in self.syncers if isinstance(s, OrderSyncBatch)), None)
if not order_sync:
return
all_orders = await order_sync._collect_all_orders(accounts)
if not all_orders:
self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0}
return
# 使用批量工具同步
batch_tool = self.batch_tools['order']
success, processed_count = batch_tool.sync_orders_batch(all_orders)
if success:
elapsed = time.time() - start_time
self.stats['order'] = {
'accounts': len(accounts),
'orders': processed_count,
'time': elapsed
}
except Exception as e:
logger.error(f"批量同步订单失败: {e}")
self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0}
async def _sync_accounts_batch(self, accounts: Dict[str, Dict]):
"""批量同步账户信息数据"""
try:
start_time = time.time()
# 收集所有账户数据
account_sync = next((s for s in self.syncers if isinstance(s, AccountSyncBatch)), None)
if not account_sync:
return
all_account_data = await account_sync._collect_all_account_data(accounts)
if not all_account_data:
self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0}
return
# 使用批量工具同步
batch_tool = self.batch_tools['account']
updated, inserted = batch_tool.sync_accounts_batch(all_account_data)
elapsed = time.time() - start_time
self.stats['account'] = {
'accounts': len(accounts),
'records': len(all_account_data),
'time': elapsed
}
except Exception as e:
logger.error(f"同步器 {syncer.__class__.__name__} 执行失败: {e}")
self.stats['error_count'] += 1
logger.error(f"批量同步账户信息失败: {e}")
self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0}
def _print_stats(self):
"""打印统计信息"""
stats_str = (
f"统计: 账号数={self.stats['total_accounts']}, "
f"成功={self.stats['success_count']}, "
f"失败={self.stats['error_count']}, "
f"本次耗时={self.stats['last_sync_time']:.2f}s, "
f"平均耗时={self.stats['avg_sync_time']:.2f}s"
)
logger.info(stats_str)
def _update_stats(self, sync_time: float):
"""更新统计信息"""
self.stats['last_sync_time'] = sync_time
self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1)
# 打印详细统计
stats_lines = [
f"=== 第{self.stats['total_syncs']}次同步统计 ===",
f"总耗时: {sync_time:.2f}秒 | 平均耗时: {self.stats['avg_sync_time']:.2f}"
]
if self.stats['position']['accounts'] > 0:
stats_lines.append(
f"持仓: {self.stats['position']['accounts']}账号/{self.stats['position']['positions']}"
f"/{self.stats['position']['time']:.2f}"
)
if self.stats['order']['accounts'] > 0:
stats_lines.append(
f"订单: {self.stats['order']['accounts']}账号/{self.stats['order']['orders']}"
f"/{self.stats['order']['time']:.2f}"
)
if self.stats['account']['accounts'] > 0:
stats_lines.append(
f"账户: {self.stats['account']['accounts']}账号/{self.stats['account']['records']}"
f"/{self.stats['account']['time']:.2f}"
)
logger.info("\n".join(stats_lines))
def signal_handler(self, signum, frame):
"""信号处理器"""
logger.info(f"接收到信号 {signum},正在关闭...")
self.is_running = False
async def stop(self):
"""停止同步服务"""
self.is_running = False
self.executor.shutdown(wait=True)
# 关闭所有数据库连接
for syncer in self.syncers:
if hasattr(syncer, 'db_manager'):
syncer.db_manager.close()
logger.info("同步服务停止")