Files
exchange_monitor_sync/sync/position_sync_batch.py
lz_db 6729ea935d 1
2025-12-03 15:08:42 +08:00

379 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import time
class PositionSyncBatch(BaseSync):
    """Batch synchronizer for strategy position data.

    Reads each account's position snapshot from Redis (hash key
    ``{exchange_id}:positions:{k_id}``, field ``positions``, a JSON list) and
    mirrors it into the ``deh_strategy_position_new`` table: rows present in the
    snapshot are upserted, and rows of the same ``(k_id, st_id)`` whose
    ``(symbol, side)`` is absent from the snapshot are deleted.
    """

    # Upsert statement executed with a list of parameter dicts (executemany).
    # Bound parameters replace the former hand-escaped string building, so text
    # values need no manual quoting and numeric 0 is stored as 0, not NULL.
    _UPSERT_POSITION_SQL = """
        INSERT INTO deh_strategy_position_new
            (st_id, k_id, asset, symbol, side, price, `sum`,
             asset_num, asset_profit, leverage, uptime,
             profit_price, stop_price, liquidation_price)
        VALUES
            (:st_id, :k_id, :asset, :symbol, :side, :price, :qty_sum,
             :asset_num, :asset_profit, :leverage, :uptime,
             :profit_price, :stop_price, :liquidation_price)
        ON DUPLICATE KEY UPDATE
            price = VALUES(price),
            `sum` = VALUES(`sum`),
            asset_num = VALUES(asset_num),
            asset_profit = VALUES(asset_profit),
            leverage = VALUES(leverage),
            uptime = VALUES(uptime),
            profit_price = VALUES(profit_price),
            stop_price = VALUES(stop_price),
            liquidation_price = VALUES(liquidation_price)
    """

    def __init__(self):
        super().__init__()
        # Maximum number of rows sent to the database per upsert batch.
        self.batch_size = 500

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Synchronize the positions of all given accounts into the database.

        Args:
            accounts: mapping of account id -> account info dict. Each info dict
                is read for ``exchange_id``, ``k_id`` and (optionally) ``st_id``.

        Errors are logged, never raised.

        NOTE(review): accounts whose Redis snapshot yields no positions are not
        visited at all (and nothing runs when every snapshot is empty), so DB
        rows for positions that have since been fully closed are never removed
        by this path. Confirm how Redis represents an account with no open
        positions before changing this.
        """
        try:
            logger.info(f"开始批量同步持仓数据,共 {len(accounts)} 个账号")
            # Monotonic clock: the measured duration must not jump with wall-clock changes.
            start_time = time.monotonic()

            # 1. Collect every account's positions from Redis.
            all_positions = await self._collect_all_positions(accounts)
            if not all_positions:
                logger.info("无持仓数据需要同步")
                return
            logger.info(f"收集到 {len(all_positions)} 条持仓数据")

            # 2. Mirror them into the database.
            success, stats = await self._sync_positions_batch_to_db(all_positions)
            elapsed = time.monotonic() - start_time
            if success:
                logger.info(f"持仓批量同步完成: 处理 {stats['total']} 条,更新 {stats['updated']} 条,"
                            f"插入 {stats['inserted']} 条,删除 {stats['deleted']} 条,耗时 {elapsed:.2f}")
            else:
                logger.error("持仓批量同步失败")
        except Exception as e:
            # Top-level boundary: keep the traceback for diagnosis.
            logger.exception(f"持仓批量同步失败: {e}")

    async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect Redis positions for all accounts, one concurrent task per exchange.

        Returns a flat list of position dicts. A failing exchange task is logged
        and its positions are omitted; it does not affect the other exchanges.
        """
        all_positions: List[Dict] = []
        try:
            account_groups = self._group_accounts_by_exchange(accounts)
            results = await asyncio.gather(
                *(self._collect_exchange_positions(exchange_id, account_list)
                  for exchange_id, account_list in account_groups.items()),
                return_exceptions=True,
            )
            # Dict iteration order matches the order the tasks were created in.
            for exchange_id, result in zip(account_groups, results):
                if isinstance(result, BaseException):
                    logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {result}")
                elif isinstance(result, list):
                    all_positions.extend(result)
        except Exception as e:
            logger.error(f"收集持仓数据失败: {e}")
        return all_positions

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group account info dicts by ``exchange_id``; accounts without one are dropped."""
        groups: Dict[str, List[Dict]] = {}
        for account_info in accounts.values():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                groups.setdefault(exchange_id, []).append(account_info)
        return groups

    async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect Redis positions for every account of one exchange.

        An account whose ``k_id`` is missing or not an integer is logged and
        skipped instead of aborting collection for the whole exchange.
        """
        positions_list: List[Dict] = []
        try:
            tasks = []
            for account_info in account_list:
                try:
                    k_id = int(account_info['k_id'])
                except (KeyError, TypeError, ValueError) as e:
                    # Log only the k_id: the account dict may carry credentials.
                    logger.error(f"账号 k_id 无效,已跳过: k_id={account_info.get('k_id')}, error={e}")
                    continue
                st_id = account_info.get('st_id', 0)
                tasks.append(self._get_positions_from_redis(k_id, st_id, exchange_id))

            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, BaseException):
                    logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {result}")
                elif isinstance(result, list):
                    positions_list.extend(result)
        except Exception as e:
            logger.error(f"收集交易所 {exchange_id} 持仓数据失败: {e}")
        return positions_list

    async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Read one account's position list from Redis and tag each entry.

        Each returned dict gets ``k_id``, ``st_id`` and ``exchange_id`` added.
        Returns [] when the key/field is missing, the payload is not a JSON list,
        or any error occurs (logged).

        NOTE(review): ``hget`` is called without ``await``, so it is a blocking
        call on the event loop; the surrounding gather() therefore does not
        overlap these reads. Consider an async Redis client or a thread offload.
        """
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            redis_data = self.redis_client.client.hget(redis_key, 'positions')
            if not redis_data:
                return []

            positions = json.loads(redis_data)
            if not isinstance(positions, list):
                logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error=payload is not a list")
                return []

            tagged = []
            for position in positions:
                if not isinstance(position, dict):
                    continue  # ignore malformed entries rather than failing the whole account
                position['k_id'] = k_id
                position['st_id'] = st_id
                position['exchange_id'] = exchange_id
                tagged.append(position)
            return tagged
        except Exception as e:
            logger.error(f"获取Redis持仓数据失败: k_id={k_id}, error={e}")
            return []

    async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """Sync collected positions to the DB, one transaction per (k_id, st_id).

        Returns (True, summed stats of the accounts that succeeded). Failures of
        individual accounts are logged by _sync_single_account_batch and left
        out of the totals; (False, zero stats) is returned only on an
        unexpected error here.
        """
        try:
            if not all_positions:
                return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

            # Group by (k_id, st_id): deletions are scoped to that pair, so positions
            # of different st_id under one k_id must not be merged into one group.
            positions_by_account: Dict[Tuple[Any, Any], List[Dict]] = {}
            for position in all_positions:
                account_key = (position['k_id'], position.get('st_id', 0))
                positions_by_account.setdefault(account_key, []).append(position)

            logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据")

            total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
            for (k_id, st_id), positions in positions_by_account.items():
                success, stats = await self._sync_single_account_batch(k_id, st_id, positions)
                if success:
                    for stat_name in total_stats:
                        total_stats[stat_name] += stats[stat_name]
            return True, total_stats
        except Exception as e:
            logger.error(f"批量同步持仓到数据库失败: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

    async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
        """Mirror one account's positions into the DB inside a single transaction.

        Rows lacking ``symbol`` or ``side`` (after conversion) are skipped. If no
        usable row remains, every DB row of (k_id, st_id) is deleted.

        Returns (success, stats). ``inserted`` counts every upserted row because
        the upsert does not distinguish inserts from updates; ``updated`` is
        always 0.
        """
        session = self.db_manager.get_session()
        try:
            insert_data: List[Dict] = []
            keep_keys: Set[Tuple[str, str]] = set()  # (symbol, side) pairs present in the snapshot

            for position_data in positions:
                try:
                    position_dict = self._convert_position_data(position_data)
                    symbol = position_dict.get('symbol')
                    side = position_dict.get('side')
                    if not symbol or not side:
                        continue
                    # The DB column is `sum`; the converted dict calls it qty.
                    if 'qty' in position_dict:
                        position_dict['sum'] = position_dict.pop('qty')
                    insert_data.append(position_dict)
                    keep_keys.add((symbol, side))
                except Exception as e:
                    logger.error(f"转换持仓数据失败: {position_data}, error={e}")
                    continue

            with session.begin():
                if not insert_data:
                    # No usable position: clear the account. The ORM model is
                    # presumably mapped to deh_strategy_position_new -- TODO confirm.
                    result = session.execute(
                        delete(StrategyPosition).where(
                            and_(
                                StrategyPosition.k_id == k_id,
                                StrategyPosition.st_id == st_id
                            )
                        )
                    )
                    stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': result.rowcount}
                else:
                    # 1. Upsert the snapshot rows; 2. delete rows no longer in it.
                    # Both helpers re-raise, so a failure rolls back the whole account.
                    processed_count = self._batch_upsert_positions(session, insert_data)
                    deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, keep_keys)
                    stats = {
                        'total': len(insert_data),
                        'updated': 0,
                        'inserted': processed_count,
                        'deleted': deleted_count
                    }
            return True, stats
        except Exception as e:
            logger.error(f"批量同步账号 {k_id} 持仓失败: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
        finally:
            session.close()

    @staticmethod
    def _to_upsert_params(data: Dict) -> Dict[str, Any]:
        """Build the bind-parameter dict for one row of _UPSERT_POSITION_SQL.

        Missing numeric values become NULL; 0 is kept as 0.
        """
        return {
            'st_id': data['st_id'],
            'k_id': data['k_id'],
            'asset': data.get('asset') or 'USDT',
            'symbol': data.get('symbol') or '',
            'side': data.get('side') or '',
            'price': data.get('price'),
            'qty_sum': data.get('sum'),
            'asset_num': data.get('asset_num'),
            'asset_profit': data.get('asset_profit'),
            'leverage': data.get('leverage'),
            'uptime': data.get('uptime'),
            'profit_price': data.get('profit_price'),
            'stop_price': data.get('stop_price'),
            'liquidation_price': data.get('liquidation_price'),
        }

    def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int:
        """Upsert rows into deh_strategy_position_new in chunks of ``self.batch_size``.

        Returns the number of rows sent. Errors are logged and re-raised so the
        caller's transaction is rolled back.
        """
        try:
            rows = [self._to_upsert_params(data) for data in insert_data]
            chunk_size = max(1, int(self.batch_size))  # guard against a 0/negative setting
            total_processed = 0
            for start in range(0, len(rows), chunk_size):
                chunk = rows[start:start + chunk_size]
                # A list of parameter dicts makes SQLAlchemy use executemany.
                session.execute(text(self._UPSERT_POSITION_SQL), chunk)
                total_processed += len(chunk)
            return total_processed
        except Exception as e:
            logger.error(f"批量插入/更新持仓失败: {e}")
            raise

    def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map) -> int:
        """Delete rows of (k_id, st_id) whose (symbol, side) is not being kept.

        Args:
            new_positions_map: iterable of (symbol, side) pairs to keep -- a set,
                or a dict keyed by those pairs. Empty means delete every row of
                the account.

        Returns the number of deleted rows. Errors are logged and re-raised so
        the enclosing transaction (including the preceding upsert) is rolled
        back instead of committing a half-applied sync.
        """
        try:
            # Deduplicate while preserving order; iterating a dict yields its keys.
            keep_keys = list(dict.fromkeys(new_positions_map))

            if not keep_keys:
                result = session.execute(
                    delete(StrategyPosition).where(
                        and_(
                            StrategyPosition.k_id == k_id,
                            StrategyPosition.st_id == st_id
                        )
                    )
                )
                return result.rowcount

            params: Dict[str, Any] = {'k_id': k_id, 'st_id': st_id}
            conditions = []
            for index, (symbol, side) in enumerate(keep_keys):
                params[f'symbol_{index}'] = symbol or ''
                params[f'side_{index}'] = side or ''
                conditions.append(f"(symbol = :symbol_{index} AND side = :side_{index})")
            conditions_str = " OR ".join(conditions)

            sql = (
                "DELETE FROM deh_strategy_position_new "
                "WHERE k_id = :k_id AND st_id = :st_id "
                f"AND NOT ({conditions_str})"
            )
            result = session.execute(text(sql), params)
            return result.rowcount
        except Exception as e:
            logger.error(f"批量删除持仓失败: k_id={k_id}, error={e}")
            raise

    @staticmethod
    def _safe_float(value, default=None):
        """Convert to a finite float, returning ``default`` if None, unparsable, NaN or inf."""
        import math

        if value is None:
            return default
        try:
            number = float(value)
        except (ValueError, TypeError, OverflowError):
            return default
        # NaN/inf cannot be stored in MySQL numeric columns.
        return number if math.isfinite(number) else default

    @staticmethod
    def _safe_int(value, default=None):
        """Convert to int via float (so "3.0" works), returning ``default`` on failure."""
        if value is None:
            return default
        try:
            return int(float(value))
        except (ValueError, TypeError, OverflowError):
            return default

    def _convert_position_data(self, data: Dict) -> Dict:
        """Normalize one raw Redis position dict into the DB-oriented shape.

        The quantity is returned under ``qty``; the caller renames it to
        ``sum``. Returns {} (logged) if ``data`` cannot be read as a dict.
        """
        try:
            return {
                'st_id': self._safe_int(data.get('st_id'), 0),
                'k_id': self._safe_int(data.get('k_id'), 0),
                'asset': data.get('asset', 'USDT'),
                'symbol': data.get('symbol', ''),
                'side': data.get('side', ''),
                'price': self._safe_float(data.get('price')),
                'qty': self._safe_float(data.get('qty')),  # renamed to `sum` by the caller
                'asset_num': self._safe_float(data.get('asset_num')),
                'asset_profit': self._safe_float(data.get('asset_profit')),
                'leverage': self._safe_int(data.get('leverage')),
                'uptime': self._safe_int(data.get('uptime')),
                'profit_price': self._safe_float(data.get('profit_price')),
                'stop_price': self._safe_float(data.get('stop_price')),
                'liquidation_price': self._safe_float(data.get('liquidation_price'))
            }
        except Exception as e:
            logger.error(f"转换持仓数据异常: {data}, error={e}")
            return {}

    async def sync(self):
        """Legacy entry point: load accounts via get_accounts_from_redis() and run sync_batch."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)