commit 803d40b88e (parent c8a6cfead1)
Author: lz_db
Date: 2025-12-03 14:40:14 +08:00
10 changed files with 2408 additions and 97 deletions

.env

@@ -38,4 +38,19 @@ COMPUTER_NAMES=lz_c01,lz_c02,lz_c03
COMPUTER_NAME_PATTERN=^lz_c\d{2}$
# Concurrency settings
MAX_CONCURRENT=10
# Order sync settings
ORDER_SYNC_RECENT_DAYS=3
ORDER_BATCH_SIZE=1000
ORDER_REDIS_SCAN_COUNT=1000
# Position sync settings
POSITION_BATCH_SIZE=500
# Account sync settings
ACCOUNT_SYNC_RECENT_DAYS=3
# Concurrency control
MAX_CONCURRENT_ACCOUNTS=50
REDIS_BATCH_SIZE=20
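
These keys are consumed by the sync code below through SYNC_CONFIG (e.g. SYNC_CONFIG['recent_days'], SYNC_CONFIG['interval']). For reference, a minimal sketch of how config/settings.py might map them; the SYNC_INTERVAL and ENABLE_* variable names are assumptions, not part of this commit:

# config/settings.py (hypothetical loader sketch)
import os

COMPUTER_NAMES = os.getenv('COMPUTER_NAMES', '')
COMPUTER_NAME_PATTERN = os.getenv('COMPUTER_NAME_PATTERN', r'^lz_c\d{2}$')

SYNC_CONFIG = {
    'interval': int(os.getenv('SYNC_INTERVAL', '60')),  # seconds between rounds (assumed name)
    'recent_days': int(os.getenv('ACCOUNT_SYNC_RECENT_DAYS', '3')),
    'enable_position_sync': os.getenv('ENABLE_POSITION_SYNC', '1') == '1',  # assumed name
    'enable_order_sync': os.getenv('ENABLE_ORDER_SYNC', '1') == '1',        # assumed name
    'enable_account_sync': os.getenv('ENABLE_ACCOUNT_SYNC', '1') == '1',    # assumed name
}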

sync/account_sync_batch.py (new file)

@@ -0,0 +1,366 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set
import json
import asyncio  # required by asyncio.gather() below; missing from the original imports
import time
from datetime import datetime, timedelta
from sqlalchemy import text, and_
from models.orm_models import StrategyKX
class AccountSyncBatch(BaseSync):
    """Batch synchronizer for account (balance) information."""
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Sync account information for all accounts in one batch."""
        try:
            logger.info(f"Starting batch account-info sync for {len(accounts)} accounts")
            # Collect data for every account
            all_account_data = await self._collect_all_account_data(accounts)
            if not all_account_data:
                logger.info("No account-info data to sync")
                return
            # Batch-write to the database
            success = await self._sync_account_info_batch_to_db(all_account_data)
            if success:
                logger.info(f"Batch account-info sync finished: {len(all_account_data)} records processed")
            else:
                logger.error("Batch account-info sync failed")
        except Exception as e:
            logger.error(f"Batch account-info sync failed: {e}")
    async def _collect_all_account_data(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect account-info data for all accounts."""
        all_account_data = []
        try:
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)
            # Collect each exchange's data concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._collect_exchange_account_data(exchange_id, account_list)
                tasks.append(task)
            # Wait for all tasks and merge the results
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    all_account_data.extend(result)
            logger.info(f"Collected {len(all_account_data)} account-info records")
        except Exception as e:
            logger.error(f"Failed to collect account-info data: {e}")
        return all_account_data
    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange."""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups
    async def _collect_exchange_account_data(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect account-info data for one exchange."""
        account_data_list = []
        try:
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                # Fetch account-info data from Redis
                account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id)
                account_data_list.extend(account_data)
            logger.debug(f"Exchange {exchange_id}: collected {len(account_data_list)} account-info records")
        except Exception as e:
            logger.error(f"Failed to collect account info for exchange {exchange_id}: {e}")
        return account_data_list
    async def _get_account_info_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Fetch account-info data from Redis (batch-optimized version)."""
        try:
            redis_key = f"{exchange_id}:balance:{k_id}"
            redis_funds = self.redis_client.client.hgetall(redis_key)
            if not redis_funds:
                return []
            # Aggregate the data per day
            from config.settings import SYNC_CONFIG
            recent_days = SYNC_CONFIG['recent_days']
            today = datetime.now()
            date_stats = {}
            # Collect entries for every date
            for fund_key, fund_json in redis_funds.items():
                try:
                    fund_data = json.loads(fund_json)
                    date_str = fund_data.get('lz_time', '')
                    lz_type = fund_data.get('lz_type', '')
                    if not date_str or lz_type not in ['lz_balance', 'deposit', 'withdrawal']:
                        continue
                    # Only process the most recent N days
                    date_obj = datetime.strptime(date_str, '%Y-%m-%d')
                    if (today - date_obj).days > recent_days:
                        continue
                    if date_str not in date_stats:
                        date_stats[date_str] = {
                            'balance': 0.0,
                            'deposit': 0.0,
                            'withdrawal': 0.0,
                            'has_balance': False
                        }
                    lz_amount = float(fund_data.get('lz_amount', 0))
                    if lz_type == 'lz_balance':
                        date_stats[date_str]['balance'] = lz_amount
                        date_stats[date_str]['has_balance'] = True
                    elif lz_type == 'deposit':
                        date_stats[date_str]['deposit'] += lz_amount
                    elif lz_type == 'withdrawal':
                        date_stats[date_str]['withdrawal'] += lz_amount
                except (json.JSONDecodeError, ValueError) as e:
                    logger.debug(f"Failed to parse Redis entry: {fund_key}, error={e}")
                    continue
            # Convert the per-day stats into account-info records
            account_data_list = []
            sorted_dates = sorted(date_stats.keys())
            # Previous-day balances, used for the profit calculation
            prev_balance_map = self._get_previous_balances(redis_funds, sorted_dates)
            for date_str in sorted_dates:
                stats = date_stats[date_str]
                # Skip dates with neither a balance snapshot nor any deposit/withdrawal activity
                if not stats['has_balance'] and stats['deposit'] == 0 and stats['withdrawal'] == 0:
                    continue
                balance = stats['balance']
                deposit = stats['deposit']
                withdrawal = stats['withdrawal']
                # Compute profit
                prev_balance = prev_balance_map.get(date_str, 0.0)
                profit = balance - deposit - withdrawal - prev_balance
                # Convert the date to a timestamp
                date_obj = datetime.strptime(date_str, '%Y-%m-%d')
                time_timestamp = int(date_obj.timestamp())
                account_data = {
                    'st_id': st_id,
                    'k_id': k_id,
                    'balance': balance,
                    'withdrawal': withdrawal,
                    'deposit': deposit,
                    'other': 0.0,  # not tracked yet
                    'profit': profit,
                    'time': time_timestamp
                }
                account_data_list.append(account_data)
            return account_data_list
        except Exception as e:
            logger.error(f"Failed to fetch account info from Redis: k_id={k_id}, error={e}")
            return []
    def _get_previous_balances(self, redis_funds: Dict, sorted_dates: List[str]) -> Dict[str, float]:
        """Look up the previous day's balance for each date."""
        prev_balance_map = {}
        prev_date = None
        for date_str in sorted_dates:
            # Find the balance entry of the preceding date
            if prev_date:
                for fund_key, fund_json in redis_funds.items():
                    try:
                        fund_data = json.loads(fund_json)
                        if (fund_data.get('lz_time') == prev_date and
                                fund_data.get('lz_type') == 'lz_balance'):
                            prev_balance_map[date_str] = float(fund_data.get('lz_amount', 0))
                            break
                    except (json.JSONDecodeError, ValueError, TypeError):
                        continue
            else:
                prev_balance_map[date_str] = 0.0
            prev_date = date_str
        return prev_balance_map
    async def _sync_account_info_batch_to_db(self, account_data_list: List[Dict]) -> bool:
        """Batch-sync account info to the database (fastest path first)."""
        session = self.db_manager.get_session()
        try:
            if not account_data_list:
                return True
            with session.begin():
                # Approach 1: raw-SQL bulk insert/update (best performance)
                success = self._batch_upsert_account_info(session, account_data_list)
                if not success:
                    # Approach 2: fall back to ORM bulk operations
                    success = self._batch_orm_upsert_account_info(session, account_data_list)
            return success
        except Exception as e:
            logger.error(f"Failed to batch-sync account info to the database: {e}")
            return False
        finally:
            session.close()
    def _batch_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool:
        """Bulk insert/update account info with raw SQL."""
        try:
            # Build the batch VALUES list
            values_list = []
            for data in account_data_list:
                values = (
                    f"({data['st_id']}, {data['k_id']}, 'USDT', "
                    f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, "
                    f"{data['other']}, {data['profit']}, {data['time']})"
                )
                values_list.append(values)
            if not values_list:
                return True
            values_str = ", ".join(values_list)
            # INSERT ... ON DUPLICATE KEY UPDATE
            sql = f"""
                INSERT INTO deh_strategy_kx_new
                (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
                VALUES {values_str}
                ON DUPLICATE KEY UPDATE
                    balance = VALUES(balance),
                    withdrawal = VALUES(withdrawal),
                    deposit = VALUES(deposit),
                    other = VALUES(other),
                    profit = VALUES(profit),
                    up_time = NOW()
            """
            session.execute(text(sql))
            logger.info(f"Raw-SQL bulk upsert of account info: {len(account_data_list)} records")
            return True
        except Exception as e:
            logger.error(f"Raw-SQL bulk upsert of account info failed: {e}")
            return False
    def _batch_orm_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool:
        """Bulk insert/update account info through the ORM."""
        try:
            # Deduplicate by key for efficiency
            account_data_by_key = {}
            for data in account_data_list:
                key = (data['k_id'], data['st_id'], data['time'])
                account_data_by_key[key] = data
            # Query existing records in one batch
            existing_records = self._batch_query_existing_records(session, list(account_data_by_key.keys()))
            # Update or insert in bulk
            to_insert = []
            for key, data in account_data_by_key.items():
                if key in existing_records:
                    # Update
                    record = existing_records[key]
                    record.balance = data['balance']
                    record.withdrawal = data['withdrawal']
                    record.deposit = data['deposit']
                    record.other = data['other']
                    record.profit = data['profit']
                else:
                    # Insert
                    to_insert.append(StrategyKX(**data))
            # Bulk-insert the new records
            if to_insert:
                session.add_all(to_insert)
            logger.info(f"ORM bulk upsert of account info: {len(existing_records)} updated, {len(to_insert)} inserted")
            return True
        except Exception as e:
            logger.error(f"ORM bulk upsert of account info failed: {e}")
            return False
    def _batch_query_existing_records(self, session, keys: List[tuple]) -> Dict[tuple, StrategyKX]:
        """Query existing records in one batch."""
        existing_records = {}
        try:
            if not keys:
                return existing_records
            # Build the WHERE clause
            conditions = []
            for k_id, st_id, time_val in keys:
                conditions.append(f"(k_id = {k_id} AND st_id = {st_id} AND time = {time_val})")
            if conditions:
                conditions_str = " OR ".join(conditions)
                sql = f"""
                    SELECT * FROM deh_strategy_kx_new
                    WHERE {conditions_str}
                """
                results = session.execute(text(sql)).fetchall()
                for row in results:
                    key = (row.k_id, row.st_id, row.time)
                    existing_records[key] = StrategyKX(
                        id=row.id,
                        st_id=row.st_id,
                        k_id=row.k_id,
                        asset=row.asset,
                        balance=row.balance,
                        withdrawal=row.withdrawal,
                        deposit=row.deposit,
                        other=row.other,
                        profit=row.profit,
                        time=row.time
                    )
        except Exception as e:
            logger.error(f"Batch query of existing records failed: {e}")
        return existing_records
    async def sync(self):
        """Backwards-compatible entry point."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)
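
As a usage reference, a minimal sketch of driving one of these synchronizers standalone (normally SyncManager below owns this loop; running it one-off like this is an assumption, not part of this commit):

# hypothetical one-off run; assumes Redis and the database are reachable
import asyncio
from sync.account_sync_batch import AccountSyncBatch

async def main():
    syncer = AccountSyncBatch()
    accounts = syncer.get_accounts_from_redis()  # {account_id: account_info}
    await syncer.sync_batch(accounts)

if __name__ == '__main__':
    asyncio.run(main())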

sync/base_sync.py

@@ -1,8 +1,10 @@
# sync/base_sync.py
from abc import ABC, abstractmethod
from loguru import logger
from typing import List, Dict, Any, Set, Optional
import json
import re
import time
from utils.redis_client import RedisClient
from utils.database_manager import DatabaseManager
@@ -16,28 +18,52 @@ class BaseSync(ABC):
        self.db_manager = DatabaseManager()
        self.computer_names = self._get_computer_names()
        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
        self.sync_stats = {
            'total_accounts': 0,
            'success_count': 0,
            'error_count': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0
        }
    def _get_computer_names(self) -> List[str]:
        """Get the list of computer names."""
        if ',' in COMPUTER_NAMES:
            names = [name.strip() for name in COMPUTER_NAMES.split(',')]
            logger.info(f"Using configured computer-name list: {names}")
            return names
        return [COMPUTER_NAMES.strip()]
    @abstractmethod
    async def sync(self):
        """Run the sync (backwards-compatible interface)."""
        pass
    @abstractmethod
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync the data."""
        pass
    def get_accounts_from_redis(self) -> Dict[str, Dict]:
        """Fetch the account configuration for every computer name from Redis."""
        try:
            accounts_dict = {}
            total_keys_processed = 0
            # Approach 1: use the configured computer-name list
            for computer_name in self.computer_names:
                accounts = self._get_accounts_by_computer_name(computer_name)
                total_keys_processed += 1
                accounts_dict.update(accounts)
            # Approach 2: if the configured names yielded nothing, try auto-discovery (fallback)
            if not accounts_dict:
                logger.warning("No data found for the configured computer names, trying auto-discovery...")
                accounts_dict = self._discover_all_accounts()
            self.sync_stats['total_accounts'] = len(accounts_dict)
            logger.info(f"Got {len(accounts_dict)} accounts from {len(self.computer_names)} computer names")
            return accounts_dict
        except Exception as e:
@@ -58,6 +84,8 @@ class BaseSync(ABC):
logger.debug(f"未找到 {redis_key} 的策略API配置")
return {}
logger.info(f"{redis_key} 获取到 {len(result)} 个交易所配置")
for exchange_name, accounts_json in result.items():
try:
accounts = json.loads(accounts_json)
@@ -77,8 +105,11 @@ class BaseSync(ABC):
except json.JSONDecodeError as e:
logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
continue
except Exception as e:
logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")
continue
logger.info(f"{redis_key} 获取{len(accounts_dict)} 个账号")
logger.info(f"{redis_key} 解析{len(accounts_dict)} 个账号")
except Exception as e:
logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
@@ -88,10 +119,11 @@ class BaseSync(ABC):
    def _discover_all_accounts(self) -> Dict[str, Dict]:
        """Auto-discover all matching account keys."""
        accounts_dict = {}
        discovered_keys = []
        try:
            # Scan for all keys matching the pattern
            pattern = "*_strategy_api"
            cursor = 0
            while True:
@@ -99,23 +131,265 @@ class BaseSync(ABC):
                for key in keys:
                    key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                    discovered_keys.append(key_str)
                if cursor == 0:
                    break
            logger.info(f"Auto-discovered {len(discovered_keys)} strategy-API keys")
            # Process every discovered key
            for key_str in discovered_keys:
                # Extract the computer name
                computer_name = key_str.replace('_strategy_api', '')
                # Validate the computer-name format
                if self.computer_name_pattern.match(computer_name):
                    accounts = self._get_accounts_by_computer_name(computer_name)
                    accounts_dict.update(accounts)
                else:
                    logger.warning(f"Skipping computer name with unexpected format: {computer_name}")
            logger.info(f"Auto-discovery yielded {len(accounts_dict)} accounts in total")
        except Exception as e:
            logger.error(f"Auto-discovery of accounts failed: {e}")
        return accounts_dict
    # The remaining methods are unchanged...
    def format_exchange_id(self, key: str) -> str:
        """Normalize an exchange ID."""
        key = key.lower().strip()
        # Exchange-name mapping
        exchange_mapping = {
            'metatrader': 'mt5',
            'binance_spot_test': 'binance',
            'binance_spot': 'binance',
            'binance': 'binance',
            'gate_spot': 'gate',
            'okex': 'okx',
            'okx': 'okx',
            'bybit': 'bybit',
            'bybit_spot': 'bybit',
            'bybit_test': 'bybit',
            'huobi': 'huobi',
            'huobi_spot': 'huobi',
            'gate': 'gate',
            'gateio': 'gate',
            'kucoin': 'kucoin',
            'kucoin_spot': 'kucoin',
            'mexc': 'mexc',
            'mexc_spot': 'mexc',
            'bitget': 'bitget',
            'bitget_spot': 'bitget'
        }
        normalized_key = exchange_mapping.get(key, key)
        # Log unmapped exchange names
        if normalized_key == key and key not in exchange_mapping.values():
            logger.debug(f"Unmapped exchange name: {key}")
        return normalized_key
    def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]:
        """Parse account information."""
        try:
            source_account_info = json.loads(account_info)
            # Base fields
            account_data = {
                'exchange_id': exchange_id,
                'k_id': account_id,
                'st_id': self._safe_int(source_account_info.get('st_id'), 0),
                'add_time': self._safe_int(source_account_info.get('add_time'), 0),
                'account_type': source_account_info.get('account_type', 'real'),
                'api_key': source_account_info.get('api_key', ''),
                'secret_key': source_account_info.get('secret_key', ''),
                'password': source_account_info.get('password', ''),
                'access_token': source_account_info.get('access_token', ''),
                'remark': source_account_info.get('remark', '')
            }
            # MT5 special handling
            if exchange_id == 'mt5':
                # Parse the server host and port
                server_info = source_account_info.get('secret_key', '')
                if ':' in server_info:
                    host, port = server_info.split(':', 1)
                    account_data['mt5_host'] = host
                    account_data['mt5_port'] = self._safe_int(port, 0)
            # Merge with the original fields
            result = {**source_account_info, **account_data}
            # Validate required fields
            if not result.get('st_id') or not result.get('exchange_id'):
                logger.warning(f"Account {account_id} is missing required fields: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
                return None
            return result
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON for account {account_id}: {e}, raw data: {account_info[:100]}...")
            return None
        except Exception as e:
            logger.error(f"Error while processing account {account_id}: {e}")
            return None
    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange."""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups
    def _safe_float(self, value: Any, default: float = 0.0) -> float:
        """Safely convert a value to float."""
        if value is None:
            return default
        try:
            if isinstance(value, str):
                value = value.strip()
                if value == '':
                    return default
            return float(value)
        except (ValueError, TypeError):
            return default
    def _safe_int(self, value: Any, default: int = 0) -> int:
        """Safely convert a value to int."""
        if value is None:
            return default
        try:
            if isinstance(value, str):
                value = value.strip()
                if value == '':
                    return default
            return int(float(value))
        except (ValueError, TypeError):
            return default
    def _safe_str(self, value: Any, default: str = '') -> str:
        """Safely convert a value to str."""
        if value is None:
            return default
        try:
            result = str(value).strip()
            return result if result else default
        except Exception:
            return default
    def _escape_sql_value(self, value: Any) -> str:
        """Escape a value for inlining into SQL."""
        if value is None:
            return 'NULL'
        if isinstance(value, bool):
            return '1' if value else '0'
        if isinstance(value, (int, float)):
            return str(value)
        if isinstance(value, str):
            # Escape single quotes
            escaped = value.replace("'", "''")
            return f"'{escaped}'"
        # Fall back to the string representation for other types
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"
    def _build_sql_values_list(self, data_list: List[Dict], fields_mapping: Dict[str, str] = None) -> List[str]:
        """Build a list of SQL VALUES tuples."""
        values_list = []
        for data in data_list:
            try:
                value_parts = []
                for field, value in data.items():
                    # Resolve the DB field name via the mapping (the resolved name itself is currently unused)
                    if fields_mapping and field in fields_mapping:
                        db_field = fields_mapping[field]
                    else:
                        db_field = field
                    escaped_value = self._escape_sql_value(value)
                    value_parts.append(escaped_value)
                values_str = ", ".join(value_parts)
                values_list.append(f"({values_str})")
            except Exception as e:
                logger.error(f"Failed to build SQL values: {data}, error={e}")
                continue
        return values_list
    def _get_recent_dates(self, days: int) -> List[str]:
        """Get the list of the most recent N dates."""
        from datetime import datetime, timedelta
        dates = []
        today = datetime.now()
        for i in range(days):
            date = today - timedelta(days=i)
            dates.append(date.strftime('%Y-%m-%d'))
        return dates
    def _date_to_timestamp(self, date_str: str) -> int:
        """Convert a date string to a timestamp (midnight of that day)."""
        from datetime import datetime
        try:
            dt = datetime.strptime(date_str, '%Y-%m-%d')
            return int(dt.timestamp())
        except ValueError:
            return 0
    def update_stats(self, success: bool = True, sync_time: float = 0):
        """Update the sync statistics."""
        if success:
            self.sync_stats['success_count'] += 1
        else:
            self.sync_stats['error_count'] += 1
        if sync_time > 0:
            self.sync_stats['last_sync_time'] = sync_time
            # Running average (exponential moving average)
            if self.sync_stats['avg_sync_time'] == 0:
                self.sync_stats['avg_sync_time'] = sync_time
            else:
                self.sync_stats['avg_sync_time'] = (
                    self.sync_stats['avg_sync_time'] * 0.9 + sync_time * 0.1
                )
    def print_stats(self, sync_type: str = ""):
        """Log the sync statistics."""
        stats = self.sync_stats
        prefix = f"[{sync_type}] " if sync_type else ""
        stats_str = (
            f"{prefix}stats: accounts={stats['total_accounts']}, "
            f"success={stats['success_count']}, errors={stats['error_count']}, "
            f"last={stats['last_sync_time']:.2f}s, "
            f"avg={stats['avg_sync_time']:.2f}s"
        )
        if stats['error_count'] > 0:
            logger.warning(stats_str)
        else:
            logger.info(stats_str)
    def reset_stats(self):
        """Reset the sync statistics."""
        self.sync_stats = {
            'total_accounts': 0,
            'success_count': 0,
            'error_count': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0
        }
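
The two abstract methods above (sync, sync_batch) define the contract every synchronizer must satisfy. A minimal concrete subclass, for illustration only; the body is a placeholder, not from this commit:

# hypothetical minimal BaseSync subclass
class NoopSync(BaseSync):
    async def sync(self):
        # backwards-compatible entry point: fetch accounts, then batch-sync
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)

    async def sync_batch(self, accounts):
        # a real implementation would read per-account data from Redis
        # and bulk-write it to the database here
        logger.info(f"would sync {len(accounts)} accounts")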


@@ -1,99 +1,97 @@
import asyncio
from loguru import logger
import signal
import time
from typing import Dict
from config.settings import SYNC_CONFIG
from .position_sync_batch import PositionSyncBatch
from .order_sync_batch import OrderSyncBatch  # batch version
from .account_sync_batch import AccountSyncBatch
from utils.batch_position_sync import BatchPositionSync
from utils.batch_order_sync import BatchOrderSync
from utils.batch_account_sync import BatchAccountSync
from utils.redis_batch_helper import RedisBatchHelper
class SyncManager:
    """Sync manager (full batch version)."""
    def __init__(self):
        self.is_running = True
        self.sync_interval = SYNC_CONFIG['interval']
        # Batch sync tools
        self.batch_tools = {}
        self.redis_helper = None
        # Synchronizers
        self.syncers = []
        if SYNC_CONFIG['enable_position_sync']:
            position_sync = PositionSyncBatch()
            self.syncers.append(position_sync)
            self.batch_tools['position'] = BatchPositionSync(position_sync.db_manager)
            logger.info("Position batch sync enabled")
        if SYNC_CONFIG['enable_order_sync']:
            order_sync = OrderSyncBatch()
            self.syncers.append(order_sync)
            self.batch_tools['order'] = BatchOrderSync(order_sync.db_manager)
            # Redis batch helper
            if order_sync.redis_client:
                self.redis_helper = RedisBatchHelper(order_sync.redis_client.client)
            logger.info("Order batch sync enabled")
        if SYNC_CONFIG['enable_account_sync']:
            account_sync = AccountSyncBatch()
            self.syncers.append(account_sync)
            self.batch_tools['account'] = BatchAccountSync(account_sync.db_manager)
            logger.info("Account-info batch sync enabled")
        # Performance stats
        self.stats = {
            'total_syncs': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0,
            'position': {'accounts': 0, 'positions': 0, 'time': 0},
            'order': {'accounts': 0, 'orders': 0, 'time': 0},
            'account': {'accounts': 0, 'records': 0, 'time': 0}
        }
        # Register signal handlers
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)
    async def start(self):
        """Start the sync service."""
        logger.info(f"Sync service started, interval {self.sync_interval} s")
        while self.is_running:
            try:
                self.stats['total_syncs'] += 1
                sync_start = time.time()
                # Fetch all accounts (only once per round)
                accounts = await self._get_all_accounts()
                if not accounts:
                    logger.warning("No accounts found, waiting for the next sync round")
                    await asyncio.sleep(self.sync_interval)
                    continue
                logger.info(f"Sync round #{self.stats['total_syncs']} starting, {len(accounts)} accounts")
                # Run all synchronizers concurrently
                await self._execute_all_syncers_concurrent(accounts)
                # Update the stats
                sync_time = time.time() - sync_start
                self._update_stats(sync_time)
                logger.info(f"Sync round finished in {sync_time:.2f} s, waiting {self.sync_interval} s")
                await asyncio.sleep(self.sync_interval)
            except asyncio.CancelledError:
@@ -101,41 +99,182 @@ class SyncManager:
                break
            except Exception as e:
                logger.error(f"Sync task error: {e}")
                await asyncio.sleep(30)  # back off for 30 s after an error
    async def _get_all_accounts(self) -> Dict[str, Dict]:
        """Fetch all accounts."""
        if not self.syncers:
            return {}
        # Use the first synchronizer to fetch the accounts
        return self.syncers[0].get_accounts_from_redis()
    async def _execute_all_syncers_concurrent(self, accounts: Dict[str, Dict]):
        """Run all synchronizers concurrently."""
        tasks = []
        # Position batch sync
        if 'position' in self.batch_tools:
            task = self._sync_positions_batch(accounts)
            tasks.append(task)
        # Order batch sync
        if 'order' in self.batch_tools:
            task = self._sync_orders_batch(accounts)
            tasks.append(task)
        # Account-info batch sync
        if 'account' in self.batch_tools:
            task = self._sync_accounts_batch(accounts)
            tasks.append(task)
        # Run all tasks concurrently
        if tasks:
            results = await asyncio.gather(*tasks, return_exceptions=True)
            # Check the results
            for i, result in enumerate(results):
                if isinstance(result, Exception):
                    logger.error(f"Sync task {i} failed: {result}")
    async def _sync_positions_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync position data."""
        try:
            start_time = time.time()
            # Collect all position data
            position_sync = next((s for s in self.syncers if isinstance(s, PositionSyncBatch)), None)
            if not position_sync:
                return
            all_positions = await position_sync._collect_all_positions(accounts)
            if not all_positions:
                self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0}
                return
            # Sync through the batch tool
            batch_tool = self.batch_tools['position']
            success, stats = batch_tool.sync_positions_batch(all_positions)
            if success:
                elapsed = time.time() - start_time
                self.stats['position'] = {
                    'accounts': len(accounts),
                    'positions': stats['total'],
                    'time': elapsed
                }
        except Exception as e:
            logger.error(f"Batch position sync failed: {e}")
            self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0}
    async def _sync_orders_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync order data."""
        try:
            start_time = time.time()
            # Collect all order data
            order_sync = next((s for s in self.syncers if isinstance(s, OrderSyncBatch)), None)
            if not order_sync:
                return
            all_orders = await order_sync._collect_all_orders(accounts)
            if not all_orders:
                self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0}
                return
            # Sync through the batch tool
            batch_tool = self.batch_tools['order']
            success, processed_count = batch_tool.sync_orders_batch(all_orders)
            if success:
                elapsed = time.time() - start_time
                self.stats['order'] = {
                    'accounts': len(accounts),
                    'orders': processed_count,
                    'time': elapsed
                }
        except Exception as e:
            logger.error(f"Batch order sync failed: {e}")
            self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0}
    async def _sync_accounts_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync account-info data."""
        try:
            start_time = time.time()
            # Collect all account data
            account_sync = next((s for s in self.syncers if isinstance(s, AccountSyncBatch)), None)
            if not account_sync:
                return
            all_account_data = await account_sync._collect_all_account_data(accounts)
            if not all_account_data:
                self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0}
                return
            # Sync through the batch tool
            batch_tool = self.batch_tools['account']
            updated, inserted = batch_tool.sync_accounts_batch(all_account_data)
            elapsed = time.time() - start_time
            self.stats['account'] = {
                'accounts': len(accounts),
                'records': len(all_account_data),
                'time': elapsed
            }
        except Exception as e:
            logger.error(f"Batch account-info sync failed: {e}")
            self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0}
    def _update_stats(self, sync_time: float):
        """Update and log the statistics."""
        self.stats['last_sync_time'] = sync_time
        self.stats['avg_sync_time'] = (self.stats['avg_sync_time'] * 0.9 + sync_time * 0.1)
        # Log a detailed breakdown
        stats_lines = [
            f"=== Sync round #{self.stats['total_syncs']} stats ===",
            f"Total: {sync_time:.2f}s | average: {self.stats['avg_sync_time']:.2f}s"
        ]
        if self.stats['position']['accounts'] > 0:
            stats_lines.append(
                f"Positions: {self.stats['position']['accounts']} accounts / {self.stats['position']['positions']} rows"
                f" / {self.stats['position']['time']:.2f}s"
            )
        if self.stats['order']['accounts'] > 0:
            stats_lines.append(
                f"Orders: {self.stats['order']['accounts']} accounts / {self.stats['order']['orders']} rows"
                f" / {self.stats['order']['time']:.2f}s"
            )
        if self.stats['account']['accounts'] > 0:
            stats_lines.append(
                f"Accounts: {self.stats['account']['accounts']} accounts / {self.stats['account']['records']} records"
                f" / {self.stats['account']['time']:.2f}s"
            )
        logger.info("\n".join(stats_lines))
    def signal_handler(self, signum, frame):
        """Signal handler."""
        logger.info(f"Received signal {signum}, shutting down...")
        self.is_running = False
    async def stop(self):
        """Stop the sync service."""
        self.is_running = False
        # Close all database connections
        for syncer in self.syncers:
            if hasattr(syncer, 'db_manager'):
                syncer.db_manager.close()
        logger.info("Sync service stopped")

sync/order_sync_batch.py (new file)

@@ -0,0 +1,269 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Tuple
import json
import asyncio
import time
from datetime import datetime, timedelta
from sqlalchemy import text
import redis
class OrderSyncBatch(BaseSync):
    """Batch synchronizer for order data."""
    def __init__(self):
        super().__init__()
        self.batch_size = 1000  # records per batch
        self.recent_days = 3  # how many recent days to sync
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Sync order data for all accounts in one batch."""
        try:
            logger.info(f"Starting batch order sync for {len(accounts)} accounts")
            start_time = time.time()
            # 1. Collect order data for every account
            all_orders = await self._collect_all_orders(accounts)
            if not all_orders:
                logger.info("No order data to sync")
                return
            logger.info(f"Collected {len(all_orders)} orders")
            # 2. Batch-write to the database
            success, processed_count = await self._sync_orders_batch_to_db(all_orders)
            elapsed = time.time() - start_time
            if success:
                logger.info(f"Batch order sync finished: {processed_count} orders in {elapsed:.2f} s")
            else:
                logger.error("Batch order sync failed")
        except Exception as e:
            logger.error(f"Batch order sync failed: {e}")
    async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect order data for all accounts."""
        all_orders = []
        try:
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)
            # Collect each exchange's data concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._collect_exchange_orders(exchange_id, account_list)
                tasks.append(task)
            # Wait for all tasks and merge the results
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    all_orders.extend(result)
        except Exception as e:
            logger.error(f"Failed to collect order data: {e}")
        return all_orders
    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange."""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups
    async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect order data for one exchange."""
        orders_list = []
        try:
            # Fetch every account's data concurrently
            tasks = []
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                task = self._get_recent_orders_from_redis(k_id, st_id, exchange_id)
                tasks.append(task)
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    orders_list.extend(result)
            logger.debug(f"Exchange {exchange_id}: collected {len(orders_list)} orders")
        except Exception as e:
            logger.error(f"Failed to collect orders for exchange {exchange_id}: {e}")
        return orders_list
    async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Fetch the most recent N days of orders from Redis."""
        try:
            redis_key = f"{exchange_id}:orders:{k_id}"
            # Compute the date strings of the most recent N days
            today = datetime.now()
            recent_dates = []
            for i in range(self.recent_days):
                date = today - timedelta(days=i)
                date_format = date.strftime('%Y-%m-%d')
                recent_dates.append(date_format)
            # Scan the hash for matching fields
            cursor = 0
            recent_keys = []
            while True:
                cursor, keys = self.redis_client.client.hscan(redis_key, cursor, count=1000)
                for key, _ in keys.items():
                    key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                    if key_str == 'positions':
                        continue
                    # Keep fields whose name starts with one of the recent dates
                    for date_format in recent_dates:
                        if key_str.startswith(date_format + '_'):
                            recent_keys.append(key_str)
                            break
                if cursor == 0:
                    break
            if not recent_keys:
                return []
            # Fetch the order payloads in batches
            orders_list = []
            # Chunk the reads so a single hmget does not get too large
            chunk_size = 500
            for i in range(0, len(recent_keys), chunk_size):
                chunk_keys = recent_keys[i:i + chunk_size]
                # Bulk fetch with hmget
                chunk_values = self.redis_client.client.hmget(redis_key, chunk_keys)
                for key, order_json in zip(chunk_keys, chunk_values):
                    if not order_json:
                        continue
                    try:
                        order = json.loads(order_json)
                        # Validate the timestamp
                        order_time = order.get('time', 0)
                        if order_time >= int(time.time()) - self.recent_days * 24 * 3600:
                            # Attach the account info
                            order['k_id'] = k_id
                            order['st_id'] = st_id
                            order['exchange_id'] = exchange_id
                            orders_list.append(order)
                    except json.JSONDecodeError as e:
                        logger.debug(f"Failed to parse order JSON: key={key}, error={e}")
                        continue
            return orders_list
        except Exception as e:
            logger.error(f"Failed to fetch orders from Redis: k_id={k_id}, error={e}")
            return []
    async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]:
        """Batch-sync order data to the database."""
        try:
            if not all_orders:
                return True, 0
            # Convert the data
            converted_orders = []
            for order in all_orders:
                try:
                    order_dict = self._convert_order_data(order)
                    # Completeness check
                    required_fields = ['order_id', 'symbol', 'side', 'time']
                    if not all(order_dict.get(field) for field in required_fields):
                        continue
                    converted_orders.append(order_dict)
                except Exception as e:
                    logger.error(f"Failed to convert order data: {order}, error={e}")
                    continue
            if not converted_orders:
                return True, 0
            # Sync through the batch tool
            from utils.batch_order_sync import BatchOrderSync
            batch_tool = BatchOrderSync(self.db_manager, self.batch_size)
            success, processed_count = batch_tool.sync_orders_batch(converted_orders)
            return success, processed_count
        except Exception as e:
            logger.error(f"Failed to batch-sync orders to the database: {e}")
            return False, 0
    def _convert_order_data(self, data: Dict) -> Dict:
        """Convert an order record to the database format."""
        try:
            # Safe conversion helpers; a default parameter is needed because
            # the st_id/k_id calls below pass one
            def safe_float(value, default=None):
                if value is None:
                    return default
                try:
                    return float(value)
                except (ValueError, TypeError):
                    return default
            def safe_int(value, default=None):
                if value is None:
                    return default
                try:
                    return int(float(value))
                except (ValueError, TypeError):
                    return default
            def safe_str(value):
                if value is None:
                    return ''
                return str(value)
            return {
                'st_id': safe_int(data.get('st_id'), 0),
                'k_id': safe_int(data.get('k_id'), 0),
                'asset': 'USDT',
                'order_id': safe_str(data.get('order_id')),
                'symbol': safe_str(data.get('symbol')),
                'side': safe_str(data.get('side')),
                'price': safe_float(data.get('price')),
                'time': safe_int(data.get('time')),
                'order_qty': safe_float(data.get('order_qty')),
                'last_qty': safe_float(data.get('last_qty')),
                'avg_price': safe_float(data.get('avg_price')),
                'exchange_id': None  # this field is ignored
            }
        except Exception as e:
            logger.error(f"Error while converting order data: {data}, error={e}")
            return {}
    async def sync(self):
        """Backwards-compatible entry point."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)
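
The Redis layout this reader expects can be inferred from the code above: one hash per account at {exchange_id}:orders:{k_id}, whose fields are named "YYYY-MM-DD_<something>" (plus a reserved "positions" field that is skipped) and whose values are JSON order objects. A hedged sketch of a writer that would satisfy this contract, using redis-py; the field naming is inferred, not shown in this commit:

# hypothetical producer side of the orders hash
import json, time, redis

r = redis.Redis()
order = {'order_id': 'abc123', 'symbol': 'BTCUSDT', 'side': 'buy',
         'price': 65000.0, 'time': int(time.time()),
         'order_qty': 0.01, 'last_qty': 0.01, 'avg_price': 65000.0}
field = f"{time.strftime('%Y-%m-%d')}_{order['order_id']}"
r.hset('binance:orders:42', field, json.dumps(order))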

sync/position_sync_batch.py (new file)

@@ -0,0 +1,378 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import time
class PositionSyncBatch(BaseSync):
    """Batch synchronizer for position data."""
    def __init__(self):
        super().__init__()
        self.batch_size = 500  # records per batch
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Sync position data for all accounts in one batch."""
        try:
            logger.info(f"Starting batch position sync for {len(accounts)} accounts")
            start_time = time.time()
            # 1. Collect position data for every account
            all_positions = await self._collect_all_positions(accounts)
            if not all_positions:
                logger.info("No position data to sync")
                return
            logger.info(f"Collected {len(all_positions)} positions")
            # 2. Batch-write to the database
            success, stats = await self._sync_positions_batch_to_db(all_positions)
            elapsed = time.time() - start_time
            if success:
                logger.info(f"Batch position sync finished: {stats['total']} processed, {stats['updated']} updated, "
                            f"{stats['inserted']} inserted, {stats['deleted']} deleted, in {elapsed:.2f} s")
            else:
                logger.error("Batch position sync failed")
        except Exception as e:
            logger.error(f"Batch position sync failed: {e}")
    async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect position data for all accounts."""
        all_positions = []
        try:
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)
            # Collect each exchange's data concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._collect_exchange_positions(exchange_id, account_list)
                tasks.append(task)
            # Wait for all tasks and merge the results
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    all_positions.extend(result)
        except Exception as e:
            logger.error(f"Failed to collect position data: {e}")
        return all_positions
    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange."""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups
    async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect position data for one exchange."""
        positions_list = []
        try:
            tasks = []
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                task = self._get_positions_from_redis(k_id, st_id, exchange_id)
                tasks.append(task)
            # Fetch concurrently
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    positions_list.extend(result)
        except Exception as e:
            logger.error(f"Failed to collect positions for exchange {exchange_id}: {e}")
        return positions_list
    async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Fetch position data from Redis."""
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            redis_data = self.redis_client.client.hget(redis_key, 'positions')
            if not redis_data:
                return []
            positions = json.loads(redis_data)
            # Attach the account info
            for position in positions:
                position['k_id'] = k_id
                position['st_id'] = st_id
                position['exchange_id'] = exchange_id
            return positions
        except Exception as e:
            logger.error(f"Failed to fetch positions from Redis: k_id={k_id}, error={e}")
            return []
    async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """Batch-sync position data to the database."""
        try:
            if not all_positions:
                return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
            # Group by account
            positions_by_account = {}
            for position in all_positions:
                k_id = position['k_id']
                if k_id not in positions_by_account:
                    positions_by_account[k_id] = []
                positions_by_account[k_id].append(position)
            logger.info(f"Batch-processing positions for {len(positions_by_account)} accounts")
            # Process each account in turn
            total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
            for k_id, positions in positions_by_account.items():
                st_id = positions[0]['st_id'] if positions else 0
                # Batch-sync a single account
                success, stats = await self._sync_single_account_batch(k_id, st_id, positions)
                if success:
                    total_stats['total'] += stats['total']
                    total_stats['updated'] += stats['updated']
                    total_stats['inserted'] += stats['inserted']
                    total_stats['deleted'] += stats['deleted']
            return True, total_stats
        except Exception as e:
            logger.error(f"Failed to batch-sync positions to the database: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
    async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
        """Batch-sync the positions of a single account."""
        session = self.db_manager.get_session()
        try:
            # Prepare the data
            insert_data = []
            new_positions_map = {}  # (symbol, side) -> position id (used for the delete step)
            for position_data in positions:
                try:
                    position_dict = self._convert_position_data(position_data)
                    if not all([position_dict.get('symbol'), position_dict.get('side')]):
                        continue
                    symbol = position_dict['symbol']
                    side = position_dict['side']
                    key = (symbol, side)
                    # Rename qty to sum
                    if 'qty' in position_dict:
                        position_dict['sum'] = position_dict.pop('qty')
                    insert_data.append(position_dict)
                    new_positions_map[key] = position_dict.get('id')  # if an id is present
                except Exception as e:
                    logger.error(f"Failed to convert position data: {position_data}, error={e}")
                    continue
            with session.begin():
                if not insert_data:
                    # No live positions: clear everything for this account
                    result = session.execute(
                        delete(StrategyPosition).where(
                            and_(
                                StrategyPosition.k_id == k_id,
                                StrategyPosition.st_id == st_id
                            )
                        )
                    )
                    deleted_count = result.rowcount
                    return True, {
                        'total': 0,
                        'updated': 0,
                        'inserted': 0,
                        'deleted': deleted_count
                    }
                # 1. Bulk insert/update the positions
                processed_count = self._batch_upsert_positions(session, insert_data)
                # 2. Bulk-delete positions that no longer exist
                deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map)
                # Note: inserts and updates cannot be told apart here; processed_count is the total
                inserted_count = processed_count  # simplification
                updated_count = 0  # distinguishing them would need extra logic
                stats = {
                    'total': len(insert_data),
                    'updated': updated_count,
                    'inserted': inserted_count,
                    'deleted': deleted_count
                }
                return True, stats
        except Exception as e:
            logger.error(f"Batch position sync for account {k_id} failed: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
        finally:
            session.close()
    def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int:
        """Bulk insert/update position rows."""
        try:
            # Process in chunks
            chunk_size = self.batch_size
            total_processed = 0
            for i in range(0, len(insert_data), chunk_size):
                chunk = insert_data[i:i + chunk_size]
                values_list = []
                for data in chunk:
                    # Escape quotes outside the f-string: backslashes inside
                    # f-string expressions are a syntax error before Python 3.12
                    safe_symbol = data['symbol'].replace("'", "''")
                    values = (
                        f"({data['st_id']}, {data['k_id']}, '{data.get('asset', 'USDT')}', "
                        f"'{safe_symbol}', '{data['side']}', "
                        f"{data.get('price') or 'NULL'}, {data.get('sum') or 'NULL'}, "
                        f"{data.get('asset_num') or 'NULL'}, {data.get('asset_profit') or 'NULL'}, "
                        f"{data.get('leverage') or 'NULL'}, {data.get('uptime') or 'NULL'}, "
                        f"{data.get('profit_price') or 'NULL'}, {data.get('stop_price') or 'NULL'}, "
                        f"{data.get('liquidation_price') or 'NULL'})"
                    )
                    values_list.append(values)
                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO deh_strategy_position_new
                        (st_id, k_id, asset, symbol, side, price, `sum`,
                         asset_num, asset_profit, leverage, uptime,
                         profit_price, stop_price, liquidation_price)
                        VALUES {values_str}
                        ON DUPLICATE KEY UPDATE
                            price = VALUES(price),
                            `sum` = VALUES(`sum`),
                            asset_num = VALUES(asset_num),
                            asset_profit = VALUES(asset_profit),
                            leverage = VALUES(leverage),
                            uptime = VALUES(uptime),
                            profit_price = VALUES(profit_price),
                            stop_price = VALUES(stop_price),
                            liquidation_price = VALUES(liquidation_price)
                    """
                    session.execute(text(sql))
                    total_processed += len(chunk)
            return total_processed
        except Exception as e:
            logger.error(f"Bulk insert/update of positions failed: {e}")
            raise
    def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int:
        """Bulk-delete positions that are no longer present."""
        try:
            if not new_positions_map:
                # Delete every position for this account
                result = session.execute(
                    delete(StrategyPosition).where(
                        and_(
                            StrategyPosition.k_id == k_id,
                            StrategyPosition.st_id == st_id
                        )
                    )
                )
                return result.rowcount
            # Build the keep conditions
            conditions = []
            for (symbol, side) in new_positions_map.keys():
                safe_symbol = symbol.replace("'", "''") if symbol else ''
                safe_side = side.replace("'", "''") if side else ''
                conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")
            if conditions:
                conditions_str = " OR ".join(conditions)
                sql = f"""
                    DELETE FROM deh_strategy_position_new
                    WHERE k_id = {k_id} AND st_id = {st_id}
                    AND NOT ({conditions_str})
                """
                result = session.execute(text(sql))
                return result.rowcount
            return 0
        except Exception as e:
            logger.error(f"Bulk delete of positions failed: k_id={k_id}, error={e}")
            return 0
    def _convert_position_data(self, data: Dict) -> Dict:
        """Convert a position record to the database format."""
        try:
            # Safe conversion helpers
            def safe_float(value, default=None):
                if value is None:
                    return default
                try:
                    return float(value)
                except (ValueError, TypeError):
                    return default
            def safe_int(value, default=None):
                if value is None:
                    return default
                try:
                    return int(float(value))
                except (ValueError, TypeError):
                    return default
            return {
                'st_id': safe_int(data.get('st_id'), 0),
                'k_id': safe_int(data.get('k_id'), 0),
                'asset': data.get('asset', 'USDT'),
                'symbol': data.get('symbol', ''),
                'side': data.get('side', ''),
                'price': safe_float(data.get('price')),
                'qty': safe_float(data.get('qty')),  # renamed to sum later
                'asset_num': safe_float(data.get('asset_num')),
                'asset_profit': safe_float(data.get('asset_profit')),
                'leverage': safe_int(data.get('leverage')),
                'uptime': safe_int(data.get('uptime')),
                'profit_price': safe_float(data.get('profit_price')),
                'stop_price': safe_float(data.get('stop_price')),
                'liquidation_price': safe_float(data.get('liquidation_price'))
            }
        except Exception as e:
            logger.error(f"Error while converting position data: {data}, error={e}")
            return {}
    async def sync(self):
        """Backwards-compatible entry point."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)
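
For positions the expected Redis contract is simpler than for orders: a single JSON array under the 'positions' field of the {exchange_id}:positions:{k_id} hash. A hedged sketch of the producer side (field names inferred from _convert_position_data, not shown in this commit):

# hypothetical producer side of the positions hash
import json, redis

r = redis.Redis()
positions = [{'symbol': 'BTCUSDT', 'side': 'long', 'price': 65000.0,
              'qty': 0.5, 'leverage': 3, 'uptime': 1733200000}]
r.hset('binance:positions:42', 'positions', json.dumps(positions))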

utils/batch_account_sync.py (new file)

@@ -0,0 +1,174 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
import time
class BatchAccountSync:
    """Batch sync tool for account information."""
    def __init__(self, db_manager):
        self.db_manager = db_manager
    def sync_accounts_batch(self, all_account_data: List[Dict]) -> Tuple[int, int]:
        """Batch-sync account information (fastest path)."""
        if not all_account_data:
            return 0, 0
        session = self.db_manager.get_session()
        try:
            start_time = time.time()
            # Approach 1: bulk operations through a temporary table (best performance)
            updated_count, inserted_count = self._sync_using_temp_table(session, all_account_data)
            elapsed = time.time() - start_time
            logger.info(f"Batch account-info sync finished: {updated_count} updated, {inserted_count} inserted, in {elapsed:.2f} s")
            return updated_count, inserted_count
        except Exception as e:
            logger.error(f"Batch account-info sync failed: {e}")
            return 0, 0
        finally:
            session.close()
    def _sync_using_temp_table(self, session, all_account_data: List[Dict]) -> Tuple[int, int]:
        """Batch-sync through a temporary table."""
        try:
            # 1. Create the temporary table
            session.execute(text("""
                CREATE TEMPORARY TABLE IF NOT EXISTS temp_account_info (
                    st_id INT,
                    k_id INT,
                    asset VARCHAR(32),
                    balance DECIMAL(20, 8),
                    withdrawal DECIMAL(20, 8),
                    deposit DECIMAL(20, 8),
                    other DECIMAL(20, 8),
                    profit DECIMAL(20, 8),
                    time INT,
                    PRIMARY KEY (k_id, st_id, time)
                )
            """))
            # 2. Clear the temporary table
            session.execute(text("TRUNCATE TABLE temp_account_info"))
            # 3. Bulk-insert the data into the temporary table
            chunk_size = 1000
            total_inserted = 0
            for i in range(0, len(all_account_data), chunk_size):
                chunk = all_account_data[i:i + chunk_size]
                values_list = []
                for data in chunk:
                    values = (
                        f"({data['st_id']}, {data['k_id']}, 'USDT', "
                        f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, "
                        f"{data['other']}, {data['profit']}, {data['time']})"
                    )
                    values_list.append(values)
                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO temp_account_info
                        (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
                        VALUES {values_str}
                    """
                    session.execute(text(sql))
                    total_inserted += len(chunk)
            # 4. Update the main table from the temporary table
            # Update records that already exist
            update_result = session.execute(text("""
                UPDATE deh_strategy_kx_new main
                INNER JOIN temp_account_info temp
                    ON main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.time = temp.time
                SET main.balance = temp.balance,
                    main.withdrawal = temp.withdrawal,
                    main.deposit = temp.deposit,
                    main.other = temp.other,
                    main.profit = temp.profit,
                    main.up_time = NOW()
            """))
            updated_count = update_result.rowcount
            # Insert new records
            insert_result = session.execute(text("""
                INSERT INTO deh_strategy_kx_new
                (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, up_time)
                SELECT
                    st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, NOW()
                FROM temp_account_info temp
                WHERE NOT EXISTS (
                    SELECT 1 FROM deh_strategy_kx_new main
                    WHERE main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.time = temp.time
                )
            """))
            inserted_count = insert_result.rowcount
            # 5. Drop the temporary table
            session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_account_info"))
            session.commit()
            return updated_count, inserted_count
        except Exception as e:
            session.rollback()
            logger.error(f"Temp-table sync failed: {e}")
            raise
    def _sync_using_on_duplicate(self, session, all_account_data: List[Dict]) -> Tuple[int, int]:
        """Batch-sync with ON DUPLICATE KEY UPDATE (simplified variant)."""
        try:
            # Execute in chunks so the SQL statement does not get too long
            chunk_size = 1000
            total_processed = 0
            for i in range(0, len(all_account_data), chunk_size):
                chunk = all_account_data[i:i + chunk_size]
                values_list = []
                for data in chunk:
                    values = (
                        f"({data['st_id']}, {data['k_id']}, 'USDT', "
                        f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, "
                        f"{data['other']}, {data['profit']}, {data['time']})"
                    )
                    values_list.append(values)
                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO deh_strategy_kx_new
                        (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
                        VALUES {values_str}
                        ON DUPLICATE KEY UPDATE
                            balance = VALUES(balance),
                            withdrawal = VALUES(withdrawal),
                            deposit = VALUES(deposit),
                            other = VALUES(other),
                            profit = VALUES(profit),
                            up_time = NOW()
                    """
                    result = session.execute(text(sql))
                    total_processed += len(chunk)
            session.commit()
            # Note: updates and inserts cannot be told apart here
            return total_processed, 0
        except Exception as e:
            session.rollback()
            logger.error(f"ON DUPLICATE sync failed: {e}")
            raise
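
Both write paths assume deh_strategy_kx_new carries a unique key on (k_id, st_id, time): ON DUPLICATE KEY UPDATE needs it to detect collisions, and the temp-table variant mirrors it as the temporary table's primary key. The main table's DDL is not in this commit; a sketch of the assumed index, in the codebase's own session idiom:

# assumed unique index on the target table (not shown in this commit)
session.execute(text("""
    ALTER TABLE deh_strategy_kx_new
        ADD UNIQUE KEY idx_unique_kx (k_id, st_id, time)
"""))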

utils/batch_order_sync.py (new file)

@@ -0,0 +1,313 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
import time
class BatchOrderSync:
    """Batch sync tool for order data (highest throughput)."""
    def __init__(self, db_manager, batch_size: int = 1000):
        self.db_manager = db_manager
        self.batch_size = batch_size
    def sync_orders_batch(self, all_orders: List[Dict]) -> Tuple[bool, int]:
        """Batch-sync order data."""
        if not all_orders:
            return True, 0
        session = self.db_manager.get_session()
        try:
            start_time = time.time()
            # Approach 1: temporary table (best performance)
            processed_count = self._sync_using_temp_table(session, all_orders)
            elapsed = time.time() - start_time
            logger.info(f"Batch order sync finished: {processed_count} orders in {elapsed:.2f} s")
            return True, processed_count
        except Exception as e:
            logger.error(f"Batch order sync failed: {e}")
            return False, 0
        finally:
            session.close()
    def _sync_using_temp_table(self, session, all_orders: List[Dict]) -> int:
        """Batch-sync orders through a temporary table."""
        try:
            # 1. Create the temporary table
            session.execute(text("""
                CREATE TEMPORARY TABLE IF NOT EXISTS temp_orders (
                    st_id INT,
                    k_id INT,
                    asset VARCHAR(32),
                    order_id VARCHAR(765),
                    symbol VARCHAR(120),
                    side VARCHAR(120),
                    price FLOAT,
                    time INT,
                    order_qty FLOAT,
                    last_qty FLOAT,
                    avg_price FLOAT,
                    exchange_id INT,
                    UNIQUE KEY idx_unique_order (order_id, symbol, k_id, side)
                )
            """))
            # 2. Clear the temporary table
            session.execute(text("TRUNCATE TABLE temp_orders"))
            # 3. Bulk-insert the data into the temporary table (chunked)
            inserted_count = self._batch_insert_to_temp_table(session, all_orders)
            if inserted_count == 0:
                session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders"))
                return 0
            # 4. Update the main table from the temporary table
            # Update existing records (only when a compared field differs)
            update_result = session.execute(text("""
                UPDATE deh_strategy_order_new main
                INNER JOIN temp_orders temp
                    ON main.order_id = temp.order_id
                    AND main.symbol = temp.symbol
                    AND main.k_id = temp.k_id
                    AND main.side = temp.side
                SET main.side = temp.side,
                    main.price = temp.price,
                    main.time = temp.time,
                    main.order_qty = temp.order_qty,
                    main.last_qty = temp.last_qty,
                    main.avg_price = temp.avg_price
                WHERE main.side != temp.side
                    OR main.price != temp.price
                    OR main.time != temp.time
                    OR main.order_qty != temp.order_qty
                    OR main.last_qty != temp.last_qty
                    OR main.avg_price != temp.avg_price
            """))
            updated_count = update_result.rowcount
            # Insert new records
            insert_result = session.execute(text("""
                INSERT INTO deh_strategy_order_new
                (st_id, k_id, asset, order_id, symbol, side, price, time,
                 order_qty, last_qty, avg_price, exchange_id)
                SELECT
                    st_id, k_id, asset, order_id, symbol, side, price, time,
                    order_qty, last_qty, avg_price, exchange_id
                FROM temp_orders temp
                WHERE NOT EXISTS (
                    SELECT 1 FROM deh_strategy_order_new main
                    WHERE main.order_id = temp.order_id
                    AND main.symbol = temp.symbol
                    AND main.k_id = temp.k_id
                    AND main.side = temp.side
                )
            """))
            inserted_count = insert_result.rowcount
            # 5. Drop the temporary table
            session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders"))
            session.commit()
            total_processed = updated_count + inserted_count
            logger.info(f"Batch order sync: {updated_count} updated, {inserted_count} inserted")
            return total_processed
        except Exception as e:
            session.rollback()
            logger.error(f"Temp-table order sync failed: {e}")
            raise
    def _batch_insert_to_temp_table(self, session, all_orders: List[Dict]) -> int:
        """Bulk-insert the orders into the temporary table."""
        total_inserted = 0
        try:
            # Process in chunks
            for i in range(0, len(all_orders), self.batch_size):
                chunk = all_orders[i:i + self.batch_size]
                values_list = []
                for order in chunk:
                    try:
                        # Handle NULL values
                        price = order.get('price')
                        time_val = order.get('time')
                        order_qty = order.get('order_qty')
                        last_qty = order.get('last_qty')
                        avg_price = order.get('avg_price')
                        # Escape single quotes
                        symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else ''
                        order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else ''
                        values = (
                            f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', "
                            f"'{order_id}', "
                            f"'{symbol}', "
                            f"'{order['side']}', "
                            f"{price if price is not None else 'NULL'}, "
                            f"{time_val if time_val is not None else 'NULL'}, "
                            f"{order_qty if order_qty is not None else 'NULL'}, "
                            f"{last_qty if last_qty is not None else 'NULL'}, "
                            f"{avg_price if avg_price is not None else 'NULL'}, "
                            "NULL)"
                        )
                        values_list.append(values)
                    except Exception as e:
                        logger.error(f"Failed to build order values: {order}, error={e}")
                        continue
                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO temp_orders
                        (st_id, k_id, asset, order_id, symbol, side, price, time,
                         order_qty, last_qty, avg_price, exchange_id)
                        VALUES {values_str}
                    """
                    result = session.execute(text(sql))
                    total_inserted += len(chunk)
            return total_inserted
        except Exception as e:
            logger.error(f"Bulk insert into the temporary table failed: {e}")
            raise
    def _batch_insert_to_temp_table1(self, session, all_orders: List[Dict]) -> int:
        """Bulk-insert into the temporary table (parameterized-query variant)."""
        total_inserted = 0
        try:
            # Process in chunks
            for i in range(0, len(all_orders), self.batch_size):
                chunk = all_orders[i:i + self.batch_size]
                # Prepare the parameterized rows
                insert_data = []
                for order in chunk:
                    try:
                        insert_data.append({
                            'st_id': order['st_id'],
                            'k_id': order['k_id'],
                            'asset': order.get('asset', 'USDT'),
                            'order_id': order['order_id'],
                            'symbol': order['symbol'],
                            'side': order['side'],
                            'price': order.get('price'),
                            'time': order.get('time'),
                            'order_qty': order.get('order_qty'),
                            'last_qty': order.get('last_qty'),
                            'avg_price': order.get('avg_price')
                            # exchange_id is left out, defaulting to NULL
                        })
                    except KeyError as e:
                        logger.error(f"Order data is missing a required field: {order}, missing={e}")
                        continue
                    except Exception as e:
                        logger.error(f"Failed to process order data: {order}, error={e}")
                        continue
                if insert_data:
                    # The original referenced self.temp_table_name, which is never
                    # defined anywhere; the table is always temp_orders
                    sql = text("""
                        INSERT INTO temp_orders
                        (st_id, k_id, asset, order_id, symbol, side, price, time,
                         order_qty, last_qty, avg_price)
                        VALUES
                        (:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time,
                         :order_qty, :last_qty, :avg_price)
                    """)
                    try:
                        session.execute(sql, insert_data)
                        session.commit()
                        total_inserted += len(insert_data)
                        logger.debug(f"Inserted {len(insert_data)} rows into the temporary table")
                    except Exception as e:
                        session.rollback()
                        logger.error(f"Bulk insert execution failed: {e}")
                        raise
            logger.info(f"Inserted {total_inserted} rows into the temporary table in total")
            return total_inserted
        except Exception as e:
            logger.error(f"Bulk insert into the temporary table failed: {e}")
            session.rollback()
            raise
    def _sync_using_on_duplicate(self, session, all_orders: List[Dict]) -> int:
        """Batch-sync orders with ON DUPLICATE KEY UPDATE (simplified variant)."""
        try:
            total_processed = 0
            # Execute in chunks
            for i in range(0, len(all_orders), self.batch_size):
                chunk = all_orders[i:i + self.batch_size]
                values_list = []
                for order in chunk:
                    try:
                        # Handle NULL values
                        price = order.get('price')
                        time_val = order.get('time')
                        order_qty = order.get('order_qty')
                        last_qty = order.get('last_qty')
                        avg_price = order.get('avg_price')
                        symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else ''
                        order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else ''
                        values = (
                            f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', "
                            f"'{order_id}', "
                            f"'{symbol}', "
                            f"'{order['side']}', "
                            f"{price if price is not None else 'NULL'}, "
                            f"{time_val if time_val is not None else 'NULL'}, "
                            f"{order_qty if order_qty is not None else 'NULL'}, "
                            f"{last_qty if last_qty is not None else 'NULL'}, "
                            f"{avg_price if avg_price is not None else 'NULL'}, "
                            "NULL)"
                        )
                        values_list.append(values)
                    except Exception as e:
                        logger.error(f"Failed to build order values: {order}, error={e}")
                        continue
                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO deh_strategy_order_new
                        (st_id, k_id, asset, order_id, symbol, side, price, time,
                         order_qty, last_qty, avg_price, exchange_id)
                        VALUES {values_str}
                        ON DUPLICATE KEY UPDATE
                            side = VALUES(side),
                            price = VALUES(price),
                            time = VALUES(time),
                            order_qty = VALUES(order_qty),
                            last_qty = VALUES(last_qty),
                            avg_price = VALUES(avg_price)
                    """
                    session.execute(text(sql))
                    total_processed += len(chunk)
            session.commit()
            return total_processed
        except Exception as e:
            session.rollback()
            logger.error(f"ON DUPLICATE order sync failed: {e}")
            raise
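
As with the account table, the ON DUPLICATE path only works if deh_strategy_order_new carries a unique key matching the idx_unique_order key declared on temp_orders. The main table's DDL is not in this commit; the assumed index, in the codebase's own session idiom:

# assumed unique index on the main orders table (not shown in this commit)
session.execute(text("""
    ALTER TABLE deh_strategy_order_new
        ADD UNIQUE KEY idx_unique_order (order_id, symbol, k_id, side)
"""))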

utils/batch_position_sync.py (new file)

@@ -0,0 +1,254 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
import time
class BatchPositionSync:
    """Batch sync tool for position data (temporary table, highest throughput)."""
    def __init__(self, db_manager, batch_size: int = 500):
        self.db_manager = db_manager
        self.batch_size = batch_size
    def sync_positions_batch(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """Batch-sync position data (fastest path)."""
        if not all_positions:
            return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
        session = self.db_manager.get_session()
        try:
            start_time = time.time()
            # Group by account
            positions_by_account = self._group_positions_by_account(all_positions)
            total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
            with session.begin():
                # Process each account
                for (k_id, st_id), positions in positions_by_account.items():
                    success, stats = self._sync_account_using_temp_table(
                        session, k_id, st_id, positions
                    )
                    if success:
                        total_stats['total'] += stats['total']
                        total_stats['updated'] += stats['updated']
                        total_stats['inserted'] += stats['inserted']
                        total_stats['deleted'] += stats['deleted']
            elapsed = time.time() - start_time
            logger.info(f"Batch position sync finished: {len(positions_by_account)} accounts, "
                        f"{total_stats['total']} positions in total, in {elapsed:.2f} s")
            return True, total_stats
        except Exception as e:
            logger.error(f"Batch position sync failed: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
        finally:
            session.close()
    def _group_positions_by_account(self, all_positions: List[Dict]) -> Dict[Tuple[int, int], List[Dict]]:
        """Group position data by account."""
        groups = {}
        for position in all_positions:
            k_id = position.get('k_id')
            st_id = position.get('st_id', 0)
            key = (k_id, st_id)
            if key not in groups:
                groups[key] = []
            groups[key].append(position)
        return groups

    def _sync_account_using_temp_table(self, session, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
        """Sync a single account's positions through a temporary table."""
        try:
            account_params = {'k_id': k_id, 'st_id': st_id}
            # 1. Create the temp table (scoped to this connection)
            session.execute(text("""
                CREATE TEMPORARY TABLE IF NOT EXISTS temp_positions (
                    st_id INT,
                    k_id INT,
                    asset VARCHAR(32),
                    symbol VARCHAR(50),
                    side VARCHAR(10),
                    price FLOAT,
                    `sum` FLOAT,
                    asset_num DECIMAL(20, 8),
                    asset_profit DECIMAL(20, 8),
                    leverage INT,
                    uptime INT,
                    profit_price DECIMAL(20, 8),
                    stop_price DECIMAL(20, 8),
                    liquidation_price DECIMAL(20, 8),
                    PRIMARY KEY (k_id, st_id, symbol, side)
                )
            """))
            # 2. Clear any rows left over from the previous account
            session.execute(text("TRUNCATE TABLE temp_positions"))
            # 3. Bulk-insert this account's snapshot into the temp table
            self._batch_insert_to_temp_table(session, positions)
            # 4. Reconcile the main table against the snapshot.
            # Refresh rows that already exist:
            update_result = session.execute(text("""
                UPDATE deh_strategy_position_new main
                INNER JOIN temp_positions temp
                    ON main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.symbol = temp.symbol
                    AND main.side = temp.side
                SET main.price = temp.price,
                    main.`sum` = temp.`sum`,
                    main.asset_num = temp.asset_num,
                    main.asset_profit = temp.asset_profit,
                    main.leverage = temp.leverage,
                    main.uptime = temp.uptime,
                    main.profit_price = temp.profit_price,
                    main.stop_price = temp.stop_price,
                    main.liquidation_price = temp.liquidation_price
                WHERE main.k_id = :k_id AND main.st_id = :st_id
            """), account_params)
            updated_count = update_result.rowcount
            # Insert rows that are new in the snapshot:
            insert_result = session.execute(text("""
                INSERT INTO deh_strategy_position_new
                    (st_id, k_id, asset, symbol, side, price, `sum`,
                     asset_num, asset_profit, leverage, uptime,
                     profit_price, stop_price, liquidation_price)
                SELECT
                    st_id, k_id, asset, symbol, side, price, `sum`,
                    asset_num, asset_profit, leverage, uptime,
                    profit_price, stop_price, liquidation_price
                FROM temp_positions temp
                WHERE NOT EXISTS (
                    SELECT 1 FROM deh_strategy_position_new main
                    WHERE main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.symbol = temp.symbol
                    AND main.side = temp.side
                )
                AND temp.k_id = :k_id AND temp.st_id = :st_id
            """), account_params)
            inserted_count = insert_result.rowcount
            # 5. Delete stale positions (in the main table but absent from the snapshot)
            delete_result = session.execute(text("""
                DELETE main
                FROM deh_strategy_position_new main
                LEFT JOIN temp_positions temp
                    ON main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.symbol = temp.symbol
                    AND main.side = temp.side
                WHERE main.k_id = :k_id AND main.st_id = :st_id
                AND temp.symbol IS NULL
            """), account_params)
            deleted_count = delete_result.rowcount
            # 6. Drop the temp table
            session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_positions"))
            stats = {
                'total': len(positions),
                'updated': updated_count,
                'inserted': inserted_count,
                'deleted': deleted_count
            }
            logger.debug(f"Account ({k_id},{st_id}) position sync: "
                         f"updated {updated_count}, inserted {inserted_count}, deleted {deleted_count}")
            return True, stats
        except Exception as e:
            logger.error(f"Temp-table position sync failed for account ({k_id},{st_id}): {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
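
    # Net effect of steps 4-5: after the three statements run, the main
    # table's rows for (k_id, st_id) exactly mirror the temp_positions
    # snapshot. Matching (symbol, side) rows are refreshed, new ones are
    # inserted, and rows absent from the snapshot are removed, replacing
    # per-row upserts with three set-based statements per account.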

    def _batch_insert_to_temp_table(self, session, positions: List[Dict]):
        """Bulk-insert rows into the temp table (parameterized query)."""
        if not positions:
            return
        # Process in chunks of batch_size
        for i in range(0, len(positions), self.batch_size):
            chunk = positions[i:i + self.batch_size]
            # Build the parameter dicts for this chunk
            insert_data = []
            for position in chunk:
                try:
                    data = self._convert_position_for_temp(position)
                    if not all([data.get('symbol'), data.get('side')]):
                        continue
                    insert_data.append({
                        'st_id': data['st_id'],
                        'k_id': data['k_id'],
                        'asset': data.get('asset', 'USDT'),
                        'symbol': data['symbol'],
                        'side': data['side'],
                        'price': data.get('price'),
                        'sum_val': data.get('sum'),  # `sum` is a reserved word, so the bind name differs
                        'asset_num': data.get('asset_num'),
                        'asset_profit': data.get('asset_profit'),
                        'leverage': data.get('leverage'),
                        'uptime': data.get('uptime'),
                        'profit_price': data.get('profit_price'),
                        'stop_price': data.get('stop_price'),
                        'liquidation_price': data.get('liquidation_price')
                    })
                except Exception as e:
                    logger.error(f"Failed to convert position record: {position}, error={e}")
                    continue
            if insert_data:
                sql = """
                    INSERT INTO temp_positions
                        (st_id, k_id, asset, symbol, side, price, `sum`,
                         asset_num, asset_profit, leverage, uptime,
                         profit_price, stop_price, liquidation_price)
                    VALUES
                        (:st_id, :k_id, :asset, :symbol, :side, :price, :sum_val,
                         :asset_num, :asset_profit, :leverage, :uptime,
                         :profit_price, :stop_price, :liquidation_price)
                """
                session.execute(text(sql), insert_data)

    def _convert_position_for_temp(self, data: Dict) -> Dict:
        """Normalize a raw position dict into the temp-table schema."""
        # Tolerant conversions: bad or missing values become None
        def safe_float(value):
            try:
                return float(value) if value is not None else None
            except (TypeError, ValueError):
                return None

        def safe_int(value):
            try:
                return int(value) if value is not None else None
            except (TypeError, ValueError):
                return None

        return {
            'st_id': safe_int(data.get('st_id')) or 0,
            'k_id': safe_int(data.get('k_id')) or 0,
            'asset': data.get('asset', 'USDT'),
            'symbol': str(data.get('symbol', '')),
            'side': str(data.get('side', '')),
            'price': safe_float(data.get('price')),
            'sum': safe_float(data.get('qty')),  # source field is qty; it maps to the `sum` column
            'asset_num': safe_float(data.get('asset_num')),
            'asset_profit': safe_float(data.get('asset_profit')),
            'leverage': safe_int(data.get('leverage')),
            'uptime': safe_int(data.get('uptime')),
            'profit_price': safe_float(data.get('profit_price')),
            'stop_price': safe_float(data.get('stop_price')),
            'liquidation_price': safe_float(data.get('liquidation_price'))
        }
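
A usage sketch for the class above, assuming a db_manager that exposes get_session() as elsewhere in this commit; the field names follow _convert_position_for_temp (the quantity arrives as qty and lands in the `sum` column), and the literal values are illustrative:

positions = [
    {'k_id': 101, 'st_id': 0, 'asset': 'USDT', 'symbol': 'BTCUSDT',
     'side': 'long', 'price': 65000.0, 'qty': 0.5, 'leverage': 10,
     'uptime': 1733200000},
]
syncer = BatchPositionSync(db_manager, batch_size=500)
ok, stats = syncer.sync_positions_batch(positions)
if ok:
    logger.info(f"positions synced: updated={stats['updated']} "
                f"inserted={stats['inserted']} deleted={stats['deleted']}")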

129
utils/redis_batch_helper.py Normal file
View File

@@ -0,0 +1,129 @@
import redis
from loguru import logger
from typing import List, Dict, Tuple
import json
import time
from datetime import datetime, timedelta

class RedisBatchHelper:
    """Helper for fetching order data from Redis in batches."""

    def __init__(self, redis_client):
        self.redis_client = redis_client

    def get_recent_orders_batch(self, exchange_id: str, account_list: List[Tuple[int, int]],
                                recent_days: int = 3) -> List[Dict]:
        """Fetch recent orders for many accounts in batches (memory-friendly)."""
        all_orders = []
        try:
            # Walk the accounts in small batches to bound memory usage
            batch_size = 20  # 20 accounts per batch
            for i in range(0, len(account_list), batch_size):
                batch_accounts = account_list[i:i + batch_size]
                # Fetch this batch of accounts via grouped HMGET calls
                # (sequential, not truly concurrent)
                batch_orders = self._get_batch_accounts_orders(exchange_id, batch_accounts, recent_days)
                all_orders.extend(batch_orders)
                # Pause briefly between batches to keep Redis load smooth
                if i + batch_size < len(account_list):
                    time.sleep(0.05)
            logger.info(f"Batch order fetch done: {len(account_list)} accounts, {len(all_orders)} orders")
        except Exception as e:
            logger.error(f"Batch order fetch failed: {e}")
        return all_orders

    def _get_batch_accounts_orders(self, exchange_id: str, account_list: List[Tuple[int, int]],
                                   recent_days: int) -> List[Dict]:
        """Fetch order data for one batch of accounts."""
        batch_orders = []
        try:
            # Build the window of recent dates
            today = datetime.now()
            recent_dates = []
            for i in range(recent_days):
                date = today - timedelta(days=i)
                recent_dates.append(date.strftime('%Y-%m-%d'))
            # Collect the hash fields to read for each account
            all_keys = []
            key_to_account = {}
            for k_id, st_id in account_list:
                redis_key = f"{exchange_id}:orders:{k_id}"
                # List every field in this account's hash
                try:
                    account_keys = self.redis_client.hkeys(redis_key)
                    for key in account_keys:
                        key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                        if key_str == 'positions':
                            continue
                        # Keep only fields stamped with a recent date
                        for date_format in recent_dates:
                            if key_str.startswith(date_format + '_'):
                                all_keys.append((redis_key, key_str))
                                key_to_account[(redis_key, key_str)] = (k_id, st_id)
                                break
                except Exception as e:
                    logger.error(f"Failed to list hash fields for account {k_id}: {e}")
                    continue
            if not all_keys:
                return batch_orders
            # Read the order payloads in chunks
            chunk_size = 500
            for i in range(0, len(all_keys), chunk_size):
                chunk = all_keys[i:i + chunk_size]
                # Group fields by hash so each hash needs a single HMGET
                keys_by_redis_key = {}
                for redis_key, key_str in chunk:
                    if redis_key not in keys_by_redis_key:
                        keys_by_redis_key[redis_key] = []
                    keys_by_redis_key[redis_key].append(key_str)
                # One HMGET per hash in this chunk
                for redis_key, key_list in keys_by_redis_key.items():
                    try:
                        values = self.redis_client.hmget(redis_key, key_list)
                        for key_str, order_json in zip(key_list, values):
                            if not order_json:
                                continue
                            try:
                                order = json.loads(order_json)
                                # Double-check the timestamp against the window
                                order_time = order.get('time', 0)
                                if order_time >= int(time.time()) - recent_days * 24 * 3600:
                                    # Attach account identifiers
                                    k_id, st_id = key_to_account.get((redis_key, key_str), (0, 0))
                                    order['k_id'] = k_id
                                    order['st_id'] = st_id
                                    order['exchange_id'] = exchange_id
                                    batch_orders.append(order)
                            except json.JSONDecodeError as e:
                                logger.debug(f"Failed to parse order JSON: key={key_str}, error={e}")
                                continue
                    except Exception as e:
                        logger.error(f"Bulk read from Redis failed: {redis_key}, error={e}")
                        continue
        except Exception as e:
            logger.error(f"Batch account order fetch failed: {e}")
        return batch_orders
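
A usage sketch: the helper accepts whatever client the caller constructs, so a plain redis.Redis is assumed here, and the connection details, exchange_id, and (k_id, st_id) pairs are illustrative:

import redis

client = redis.Redis(host='localhost', port=6379, db=0)  # assumed connection details
helper = RedisBatchHelper(client)
accounts = [(101, 0), (102, 0), (103, 7)]  # (k_id, st_id) pairs
orders = helper.get_recent_orders_batch('binance', accounts, recent_days=3)
logger.info(f"fetched {len(orders)} recent orders")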