from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import time


class PositionSyncBatch(BaseSync):
    """Batch synchronizer for position data.

    Collects per-account positions from Redis, groups them by exchange and
    account, upserts them into the database (MySQL-style ON DUPLICATE KEY
    UPDATE) and prunes rows for positions that are no longer held.
    """

    def __init__(self):
        super().__init__()
        self.batch_size = 500  # number of rows handled per SQL batch

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync position data for all accounts."""
        try:
            logger.info(f"Starting batch position sync for {len(accounts)} accounts")
            start_time = time.time()

            # 1. Collect position data for all accounts
            all_positions = await self._collect_all_positions(accounts)

            if not all_positions:
                logger.info("No position data to sync")
                return

            logger.info(f"Collected {len(all_positions)} position records")

            # 2. Batch-sync the data to the database
            success, stats = await self._sync_positions_batch_to_db(all_positions)

            elapsed = time.time() - start_time
            if success:
                logger.info(f"Batch position sync done: {stats['total']} processed, {stats['updated']} updated, "
                            f"{stats['inserted']} inserted, {stats['deleted']} deleted, in {elapsed:.2f}s")
            else:
                logger.error("Batch position sync failed")

        except Exception as e:
            logger.error(f"Batch position sync failed: {e}")

    async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect position data for all accounts."""
        all_positions = []

        try:
            # Group the accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)

            # Collect each exchange's data concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._collect_exchange_positions(exchange_id, account_list)
                tasks.append(task)

            # Wait for all tasks to finish and merge the results
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for result in results:
                if isinstance(result, list):
                    all_positions.extend(result)

        except Exception as e:
            logger.error(f"Failed to collect position data: {e}")

        return all_positions

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange."""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups

    async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect position data for a single exchange."""
        positions_list = []

        try:
            tasks = []
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                task = self._get_positions_from_redis(k_id, st_id, exchange_id)
                tasks.append(task)

            # Fetch concurrently
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for result in results:
                if isinstance(result, list):
                    positions_list.extend(result)

        except Exception as e:
            logger.error(f"Failed to collect positions for exchange {exchange_id}: {e}")

        return positions_list

    async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Fetch position data from Redis."""
        try:
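            # Assumed Redis layout (based on the key and field read below):
            #   HGET "{exchange_id}:positions:{k_id}" "positions"
            #   -> a JSON array of position dicts, e.g.
            #   [{"symbol": "BTCUSDT", "side": "long", "price": "...", "qty": "...", ...}]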
            redis_key = f"{exchange_id}:positions:{k_id}"
            redis_data = self.redis_client.client.hget(redis_key, 'positions')

            if not redis_data:
                return []

            positions = json.loads(redis_data)

            # Attach the account information
            for position in positions:
                position['k_id'] = k_id
                position['st_id'] = st_id
                position['exchange_id'] = exchange_id

            return positions

        except Exception as e:
            logger.error(f"Failed to fetch positions from Redis: k_id={k_id}, error={e}")
            return []

    async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """Batch-sync position data to the database."""
        try:
            if not all_positions:
                return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

            # Group the positions by account
            positions_by_account = {}
            for position in all_positions:
                k_id = position['k_id']
                if k_id not in positions_by_account:
                    positions_by_account[k_id] = []
                positions_by_account[k_id].append(position)

            logger.info(f"Batch-processing position data for {len(positions_by_account)} accounts")

            # Process each account, accumulating the stats
            total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

            for k_id, positions in positions_by_account.items():
                st_id = positions[0]['st_id'] if positions else 0

                # Batch-sync this single account
                success, stats = await self._sync_single_account_batch(k_id, st_id, positions)

                if success:
                    total_stats['total'] += stats['total']
                    total_stats['updated'] += stats['updated']
                    total_stats['inserted'] += stats['inserted']
                    total_stats['deleted'] += stats['deleted']

            return True, total_stats

        except Exception as e:
            logger.error(f"Failed to batch-sync positions to the database: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

    async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
        """Batch-sync the position data of a single account."""
        session = self.db_manager.get_session()
        try:
            # Prepare the data
            insert_data = []
            new_positions_map = {}  # (symbol, side) -> position id, used when pruning stale rows

            for position_data in positions:
                try:
                    position_dict = self._convert_position_data(position_data)
                    if not all([position_dict.get('symbol'), position_dict.get('side')]):
                        continue

                    symbol = position_dict['symbol']
                    side = position_dict['side']
                    key = (symbol, side)

                    # Rename qty to sum (the column name used in the table)
                    if 'qty' in position_dict:
                        position_dict['sum'] = position_dict.pop('qty')

                    insert_data.append(position_dict)
                    new_positions_map[key] = position_dict.get('id')  # only set if an id is present

                except Exception as e:
                    logger.error(f"Failed to convert position data: {position_data}, error={e}")
                    continue

            with session.begin():
                if not insert_data:
                    # No positions left for this account: clear all of its rows
                    result = session.execute(
                        delete(StrategyPosition).where(
                            and_(
                                StrategyPosition.k_id == k_id,
                                StrategyPosition.st_id == st_id
                            )
                        )
                    )
                    deleted_count = result.rowcount

                    return True, {
                        'total': 0,
                        'updated': 0,
                        'inserted': 0,
                        'deleted': deleted_count
                    }

                # 1. Batch insert/update the position rows
                processed_count = self._batch_upsert_positions(session, insert_data)

                # 2. Batch-delete rows that are no longer held
                deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map)

                # Note: inserts and updates cannot be told apart here;
                # processed_count is the total number of rows handled.
                inserted_count = processed_count  # simplification: count everything as inserted
                updated_count = 0  # distinguishing updates would require extra logic

                stats = {
                    'total': len(insert_data),
                    'updated': updated_count,
                    'inserted': inserted_count,
                    'deleted': deleted_count
                }

                return True, stats

        except Exception as e:
            logger.error(f"Failed to batch-sync positions for account {k_id}: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
        finally:
            session.close()

    def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int:
        """Batch insert/update position rows."""
        try:
            # Render a numeric value for the SQL, keeping 0 distinct from NULL
            def sql_num(value):
                return 'NULL' if value is None else value

            # Process in chunks
            chunk_size = self.batch_size
            total_processed = 0

            for i in range(0, len(insert_data), chunk_size):
                chunk = insert_data[i:i + chunk_size]

                values_list = []
                for data in chunk:
                    safe_symbol = data['symbol'].replace("'", "''")
                    values = (
                        f"({data['st_id']}, {data['k_id']}, '{data.get('asset', 'USDT')}', "
                        f"'{safe_symbol}', '{data['side']}', "
                        f"{sql_num(data.get('price'))}, {sql_num(data.get('sum'))}, "
                        f"{sql_num(data.get('asset_num'))}, {sql_num(data.get('asset_profit'))}, "
                        f"{sql_num(data.get('leverage'))}, {sql_num(data.get('uptime'))}, "
                        f"{sql_num(data.get('profit_price'))}, {sql_num(data.get('stop_price'))}, "
                        f"{sql_num(data.get('liquidation_price'))})"
                    )
                    values_list.append(values)

                if values_list:
                    values_str = ", ".join(values_list)

                    sql = f"""
                        INSERT INTO deh_strategy_position_new
                        (st_id, k_id, asset, symbol, side, price, `sum`,
                         asset_num, asset_profit, leverage, uptime,
                         profit_price, stop_price, liquidation_price)
                        VALUES {values_str}
                        ON DUPLICATE KEY UPDATE
                            price = VALUES(price),
                            `sum` = VALUES(`sum`),
                            asset_num = VALUES(asset_num),
                            asset_profit = VALUES(asset_profit),
                            leverage = VALUES(leverage),
                            uptime = VALUES(uptime),
                            profit_price = VALUES(profit_price),
                            stop_price = VALUES(stop_price),
                            liquidation_price = VALUES(liquidation_price)
                    """

                    session.execute(text(sql))
                    total_processed += len(chunk)

            return total_processed

        except Exception as e:
            logger.error(f"Batch insert/update of positions failed: {e}")
            raise

    def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int:
        """Batch-delete positions that are no longer held."""
        try:
            if not new_positions_map:
                # Nothing is held any more: delete all rows for this account
                result = session.execute(
                    delete(StrategyPosition).where(
                        and_(
                            StrategyPosition.k_id == k_id,
                            StrategyPosition.st_id == st_id
                        )
                    )
                )
                return result.rowcount

            # Build the retention conditions: keep rows whose (symbol, side) is still held
            conditions = []
            for (symbol, side) in new_positions_map.keys():
                safe_symbol = symbol.replace("'", "''") if symbol else ''
                safe_side = side.replace("'", "''") if side else ''
                conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")

            if conditions:
                conditions_str = " OR ".join(conditions)

                sql = f"""
                    DELETE FROM deh_strategy_position_new
                    WHERE k_id = {k_id} AND st_id = {st_id}
                    AND NOT ({conditions_str})
                """

                result = session.execute(text(sql))
                return result.rowcount

            return 0

        except Exception as e:
            logger.error(f"Batch delete of positions failed: k_id={k_id}, error={e}")
            return 0

    def _convert_position_data(self, data: Dict) -> Dict:
        """Normalize a raw position record into the column layout used by the table."""
        try:
            # Safe conversion helpers
            def safe_float(value, default=None):
                if value is None:
                    return default
                try:
                    return float(value)
                except (ValueError, TypeError):
                    return default

            def safe_int(value, default=None):
                if value is None:
                    return default
                try:
                    return int(float(value))
                except (ValueError, TypeError):
                    return default

            return {
                'st_id': safe_int(data.get('st_id'), 0),
                'k_id': safe_int(data.get('k_id'), 0),
                'asset': data.get('asset', 'USDT'),
                'symbol': data.get('symbol', ''),
                'side': data.get('side', ''),
                'price': safe_float(data.get('price')),
                'qty': safe_float(data.get('qty')),  # renamed to `sum` later
                'asset_num': safe_float(data.get('asset_num')),
                'asset_profit': safe_float(data.get('asset_profit')),
                'leverage': safe_int(data.get('leverage')),
                'uptime': safe_int(data.get('uptime')),
                'profit_price': safe_float(data.get('profit_price')),
                'stop_price': safe_float(data.get('stop_price')),
                'liquidation_price': safe_float(data.get('liquidation_price'))
            }

        except Exception as e:
            logger.error(f"Error converting position data: {data}, error={e}")
            return {}

    async def sync(self):
        """Compatibility wrapper for the old interface."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)
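

# Minimal usage sketch (an assumption about how this class is driven; the real
# scheduler/entry point is not shown in this file). BaseSync is assumed to wire
# up `redis_client`, `db_manager` and `get_accounts_from_redis()`, as used by
# sync() above. Run via `python -m ...` because of the relative import.
if __name__ == "__main__":
    async def _main():
        syncer = PositionSyncBatch()
        # Old-style entry point: pulls the account list from Redis itself.
        await syncer.sync()
        # Or pass an explicit accounts dict (illustrative values only):
        # await syncer.sync_batch({"acc-1": {"exchange_id": "binance", "k_id": 1, "st_id": 0}})

    asyncio.run(_main())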