This commit is contained in:
lz_db
2025-12-03 23:53:07 +08:00
parent 6958445ecf
commit f93f334256
12 changed files with 1036 additions and 2542 deletions

View File

@@ -1,41 +1,74 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import time
class PositionSyncBatch(BaseSync):
    """Batch synchronizer for strategy position data.

    Collects every account's positions from Redis, then writes them to the
    database in per-account batches.
    """

    def __init__(self):
        super().__init__()
        # Upper bound on concurrent tasks per synchronizer.
        self.max_concurrent = 10
        # Number of rows handled per SQL batch.
        self.batch_size = 500
async def sync_batch(self, accounts: Dict[str, Dict]):
    """Synchronize position data for all given accounts in one batched pass.

    Args:
        accounts: mapping of account key -> account info dict.

    Side effects: logs progress; errors are logged, never raised to the caller.
    """
    try:
        logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
        start_time = time.time()

        # Phase 1: gather every account's positions from Redis.
        all_positions = await self._collect_all_positions(accounts)
        if not all_positions:
            logger.info("无持仓数据需要同步")
            return
        logger.info(f"收集到 {len(all_positions)} 条持仓数据")

        # Phase 2: push everything to the database in batches.
        success, stats = await self._sync_positions_batch_to_db(all_positions)
        elapsed = time.time() - start_time
        if not success:
            logger.error("持仓批量同步失败")
            return
        logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,更新 {stats['updated']} 条,"
                    f"插入 {stats['inserted']} 条,删除 {stats['deleted']} 条,耗时 {elapsed:.2f}")
    except Exception as e:
        logger.error(f"持仓批量同步失败: {e}")
async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
    """Collect position data for every account, fetched per exchange concurrently.

    Args:
        accounts: mapping of account key -> account info dict.

    Returns:
        Flat list of position dicts from all exchanges; [] on total failure.
    """
    all_positions: List[Dict] = []
    try:
        # Group accounts by exchange so each exchange is fetched in parallel.
        account_groups = self._group_accounts_by_exchange(accounts)
        tasks = [
            self._collect_exchange_positions(exchange_id, account_list)
            for exchange_id, account_list in account_groups.items()
        ]
        # return_exceptions=True keeps one failing exchange from cancelling
        # the rest; previously exception results were silently discarded,
        # now they are logged.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for exchange_id, result in zip(account_groups.keys(), results):
            if isinstance(result, list):
                all_positions.extend(result)
            elif isinstance(result, Exception):
                logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {result}")
    except Exception as e:
        logger.error(f"收集持仓数据失败: {e}")
    return all_positions
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
"""按交易所分组账号"""
groups = {}
@@ -47,46 +80,60 @@ class PositionSync(BaseSync):
groups[exchange_id].append(account_info)
return groups
async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
    """Fetch positions for every account of one exchange from Redis, concurrently.

    Args:
        exchange_id: exchange name used in the Redis key.
        account_list: account info dicts, each with at least 'k_id'.

    Returns:
        Flat list of position dicts for this exchange; [] on failure.
    """
    positions_list: List[Dict] = []
    try:
        tasks = []
        for account_info in account_list:
            k_id = int(account_info['k_id'])
            st_id = account_info.get('st_id', 0)
            tasks.append(self._get_positions_from_redis(k_id, st_id, exchange_id))
        # Fetch all accounts concurrently; exceptions are isolated per task.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, list):
                positions_list.extend(result)
            elif isinstance(result, Exception):
                # Previously failed fetches were dropped silently; surface them.
                logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {result}")
    except Exception as e:
        logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}")
    return positions_list
def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
"""批量同步持仓数据到数据库(优化版)"""
session = self.db_manager.get_session()
async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
    """Load one account's positions from Redis and tag them with account info.

    Args:
        k_id: account id (part of the Redis key).
        st_id: strategy id attached to every returned position.
        exchange_id: exchange name (part of the Redis key).

    Returns:
        List of position dicts, each augmented with k_id/st_id/exchange_id,
        or [] when the key is absent or an error occurs.
    """
    try:
        redis_key = f"{exchange_id}:positions:{k_id}"
        # The Redis client here is synchronous; run the blocking network call
        # in a worker thread so it does not stall the event loop.
        redis_data = await asyncio.to_thread(
            self.redis_client.client.hget, redis_key, 'positions'
        )
        if not redis_data:
            return []
        positions = json.loads(redis_data)
        # Attach account identity so downstream batching can group rows.
        for position in positions:
            position['k_id'] = k_id
            position['st_id'] = st_id
            position['exchange_id'] = exchange_id
        return positions
    except Exception as e:
        logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
        return []
async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
    """Sync collected positions to the DB, one batched pass per account.

    Args:
        all_positions: position dicts, each carrying 'k_id' and 'st_id'.

    Returns:
        (success, stats) where stats has total/updated/inserted/deleted counts.
        Failed accounts simply do not contribute to the totals.
    """
    empty_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
    try:
        if not all_positions:
            return True, dict(empty_stats)

        # Group positions by account (k_id).
        positions_by_account: Dict[int, List[Dict]] = {}
        for position in all_positions:
            positions_by_account.setdefault(position['k_id'], []).append(position)

        total_stats = dict(empty_stats)
        for k_id, positions in positions_by_account.items():
            st_id = positions[0]['st_id'] if positions else 0
            # Sync a single account's batch; failures are logged downstream.
            success, stats = await self._sync_single_account_batch(k_id, st_id, positions)
            if success:
                for key in total_stats:
                    total_stats[key] += stats[key]
        return True, total_stats
    except Exception as e:
        logger.error(f"批量同步持仓到数据库失败: {e}")
        return False, dict(empty_stats)
async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
    """Upsert one account's positions and delete rows no longer present.

    Args:
        k_id: account id.
        st_id: strategy id.
        positions: raw position dicts for this account.

    Returns:
        (success, stats) with total/updated/inserted/deleted counters.
    """
    session = self.db_manager.get_session()
    try:
        rows = []
        # (symbol, side) -> optional row id; the keys drive the delete step.
        current_keys = {}
        for raw in positions:
            try:
                converted = self._convert_position_data(raw)
                symbol = converted.get('symbol')
                side = converted.get('side')
                if not symbol or not side:
                    continue
                # The DB column is `sum`; the raw payload calls it qty.
                if 'qty' in converted:
                    converted['sum'] = converted.pop('qty')
                rows.append(converted)
                current_keys[(symbol, side)] = converted.get('id')
            except Exception as e:
                logger.error(f"转换持仓数据失败: {raw}, error={e}")
                continue

        with session.begin():
            if not rows:
                # Nothing valid to keep: wipe every stored position for the account.
                result = session.execute(
                    delete(StrategyPosition).where(
                        and_(
                            StrategyPosition.k_id == k_id,
                            StrategyPosition.st_id == st_id
                        )
                    )
                )
                return True, {'total': 0, 'updated': 0,
                              'inserted': 0, 'deleted': result.rowcount}

            processed = self._batch_upsert_positions(session, rows)
            removed = self._batch_delete_extra_positions(session, k_id, st_id, current_keys)
            # The upsert cannot distinguish inserts from updates, so every
            # processed row is reported under 'inserted'.
            return True, {'total': len(rows), 'updated': 0,
                          'inserted': processed, 'deleted': removed}
    except Exception as e:
        logger.error(f"批量同步账号 {k_id} 持仓失败: {e}")
        return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
    finally:
        session.close()
# 其他方法保持不变...
def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int:
    """Bulk INSERT ... ON DUPLICATE KEY UPDATE of position rows, in chunks.

    Args:
        session: active DB session (caller owns the transaction).
        insert_data: converted position dicts; 'qty' already renamed to 'sum'.

    Returns:
        Number of rows processed.

    Raises:
        Re-raises any DB error so the caller can roll back.
    """
    def sql_str(value, default=''):
        # Escape single quotes for inline SQL string literals (previously
        # only `symbol` was escaped; `asset` and `side` were interpolated raw).
        text_value = value if value is not None else default
        return str(text_value).replace("'", "''")

    def sql_num(value):
        # FIX: the old `value or 'NULL'` turned legitimate 0/0.0 values into
        # NULL; only None should map to NULL.
        return 'NULL' if value is None else str(value)

    try:
        chunk_size = self.batch_size
        total_processed = 0
        for i in range(0, len(insert_data), chunk_size):
            chunk = insert_data[i:i + chunk_size]
            values_list = []
            for data in chunk:
                # FIX: the old version omitted `side` from the VALUES tuple,
                # producing 13 values for 14 columns (a SQL error every run).
                values_list.append(
                    f"({data['st_id']}, {data['k_id']}, '{sql_str(data.get('asset', 'USDT'), 'USDT')}', "
                    f"'{sql_str(data.get('symbol'))}', '{sql_str(data.get('side'))}', "
                    f"{sql_num(data.get('price'))}, {sql_num(data.get('sum'))}, "
                    f"{sql_num(data.get('asset_num'))}, {sql_num(data.get('asset_profit'))}, "
                    f"{sql_num(data.get('leverage'))}, {sql_num(data.get('uptime'))}, "
                    f"{sql_num(data.get('profit_price'))}, {sql_num(data.get('stop_price'))}, "
                    f"{sql_num(data.get('liquidation_price'))})"
                )
            if values_list:
                values_str = ", ".join(values_list)
                sql = f"""
                INSERT INTO deh_strategy_position_new
                (st_id, k_id, asset, symbol, side, price, `sum`,
                 asset_num, asset_profit, leverage, uptime,
                 profit_price, stop_price, liquidation_price)
                VALUES {values_str}
                ON DUPLICATE KEY UPDATE
                    price = VALUES(price),
                    `sum` = VALUES(`sum`),
                    asset_num = VALUES(asset_num),
                    asset_profit = VALUES(asset_profit),
                    leverage = VALUES(leverage),
                    uptime = VALUES(uptime),
                    profit_price = VALUES(profit_price),
                    stop_price = VALUES(stop_price),
                    liquidation_price = VALUES(liquidation_price)
                """
                session.execute(text(sql))
                total_processed += len(chunk)
        return total_processed
    except Exception as e:
        logger.error(f"批量插入/更新持仓失败: {e}")
        raise
def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int:
    """Delete stored positions for this account that are absent from the new snapshot.

    Args:
        session: active DB session.
        k_id: account id.
        st_id: strategy id.
        new_positions_map: (symbol, side) keys of positions to KEEP.

    Returns:
        Number of deleted rows (0 on error — failures are logged, not raised).
    """
    try:
        if not new_positions_map:
            # Empty snapshot: every stored row for this account is stale.
            outcome = session.execute(
                delete(StrategyPosition).where(
                    and_(
                        StrategyPosition.k_id == k_id,
                        StrategyPosition.st_id == st_id
                    )
                )
            )
            return outcome.rowcount

        def escape(value):
            return value.replace("'", "''") if value else ''

        keep_clauses = [
            f"(symbol = '{escape(symbol)}' AND side = '{escape(side)}')"
            for symbol, side in new_positions_map.keys()
        ]
        if not keep_clauses:
            return 0
        predicate = " OR ".join(keep_clauses)
        sql = f"""
        DELETE FROM deh_strategy_position_new
        WHERE k_id = {k_id} AND st_id = {st_id}
        AND NOT ({predicate})
        """
        outcome = session.execute(text(sql))
        return outcome.rowcount
    except Exception as e:
        logger.error(f"批量删除持仓失败: k_id={k_id}, error={e}")
        return 0
def _convert_position_data(self, data: Dict) -> Dict:
    """Normalize a raw position payload into DB-column-shaped values.

    Numeric fields are coerced defensively: unparsable values become the
    field's fallback (None for optionals, 0 for ids) instead of raising.

    Returns:
        The normalized dict, or {} if normalization itself fails.
    """
    def as_float(raw, fallback=None):
        try:
            return fallback if raw is None else float(raw)
        except (ValueError, TypeError):
            return fallback

    def as_int(raw, fallback=None):
        try:
            # Route through float so strings like "10.0" still parse.
            return fallback if raw is None else int(float(raw))
        except (ValueError, TypeError):
            return fallback

    try:
        return {
            'st_id': as_int(data.get('st_id'), 0),
            'k_id': as_int(data.get('k_id'), 0),
            'asset': data.get('asset', 'USDT'),
            'symbol': data.get('symbol', ''),
            'side': data.get('side', ''),
            'price': as_float(data.get('price')),
            'qty': as_float(data.get('qty')),  # renamed to `sum` downstream
            'asset_num': as_float(data.get('asset_num')),
            'asset_profit': as_float(data.get('asset_profit')),
            'leverage': as_int(data.get('leverage')),
            'uptime': as_int(data.get('uptime')),
            'profit_price': as_float(data.get('profit_price')),
            'stop_price': as_float(data.get('stop_price')),
            'liquidation_price': as_float(data.get('liquidation_price'))
        }
    except Exception as e:
        logger.error(f"转换持仓数据异常: {data}, error={e}")
        return {}
async def sync(self):
    """Legacy entry point: pull accounts from Redis and delegate to sync_batch."""
    await self.sync_batch(self.get_accounts_from_redis())