from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set, Tuple
import json
import asyncio
import utils.helpers as helpers
from datetime import datetime
from sqlalchemy import text, and_, select, delete
from models.orm_models import StrategyPosition
import time


class PositionSyncBatch(BaseSync):
    """Batch synchronizer for position data."""

    def __init__(self):
        super().__init__()
        self.batch_size = 500  # number of rows handled per batch

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync position data for all accounts."""
        try:
            logger.info(f"Starting batch position sync for {len(accounts)} accounts")

            start_time = time.time()

            # 1. Collect position data for every account
            all_positions = await self._collect_all_positions(accounts)

            if not all_positions:
                logger.info("No position data to sync")
                return

            logger.info(f"Collected {len(all_positions)} position records")

            # 2. Batch-sync to the database
            success, stats = await self._sync_positions_batch_to_db_optimized_v3(all_positions)

            elapsed = time.time() - start_time
            if success:
                logger.info(f"Batch position sync finished: processed {stats['total']} rows, "
                            f"{stats['affected']} affected, {stats['deleted']} deleted, "
                            f"took {elapsed:.2f}s")
            else:
                logger.error("Batch position sync failed")

        except Exception as e:
            logger.error(f"Batch position sync failed: {e}")
            # Capture the full error details, including the traceback
            import traceback
            error_details = {
                'error_type': type(e).__name__,
                'error_message': str(e),
                'traceback': traceback.format_exc()
            }
            logger.error("Full traceback:\n{traceback}", traceback=error_details['traceback'])

    async def _sync_positions_batch_to_db_optimized(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """
        Batch-sync position data (without a temporary table).

        Args:
            all_positions: list of all positions; each item contains k_id (account id) among other fields

        Returns:
            Tuple[bool, Dict]: (success flag, result statistics)
        """
        if not all_positions:
            return True, {'total': 0, 'affected': 0, 'deleted': 0, 'errors': []}

        session = self.db_manager.get_session()

        results = {
            'total': 0,
            'affected': 0,
            'deleted': 0,
            'errors': []
        }

        # Group positions by account
        positions_by_account = {}
        for position in all_positions:
            k_id = position['k_id']
            if k_id not in positions_by_account:
                positions_by_account[k_id] = []
            positions_by_account[k_id].append(position)

        logger.info(f"Starting batch processing of positions for {len(positions_by_account)} accounts")

        try:
            # Process accounts in groups of 10
            account_ids = list(positions_by_account.keys())

            for group_idx in range(0, len(account_ids), 10):
                group_account_ids = account_ids[group_idx:group_idx + 10]
                logger.info(f"Processing account group {group_idx//10 + 1}: {group_account_ids}")

                # Collect all positions belonging to this group
                group_positions = []
                for k_id in group_account_ids:
                    group_positions.extend(positions_by_account[k_id])

                if not group_positions:
                    continue

                # Normalise the position data
                processed_positions = []
                account_position_keys = {}  # position keys currently held, per account

                for raw_position in group_positions:
                    try:
                        k_id = raw_position['k_id']
                        processed = self._convert_position_data(raw_position)

                        # Skip records missing required fields
                        if not all([processed.get('symbol'), processed.get('side')]):
                            continue

                        # Ensure st_id is present
                        if 'st_id' not in processed:
                            processed['st_id'] = raw_position.get('st_id', 0)

                        # Ensure k_id is present
                        if 'k_id' not in processed:
                            processed['k_id'] = k_id

                        # Rename qty to sum (if present)
                        if 'qty' in processed:
                            processed['sum'] = processed.pop('qty')

                        processed_positions.append(processed)

                        # Record the unique key of this position
                        if k_id not in account_position_keys:
                            account_position_keys[k_id] = set()

                        position_key = f"{processed['st_id']}&{processed['symbol']}&{processed['side']}"
                        account_position_keys[k_id].add(position_key)

                    except Exception as e:
                        logger.error(f"Failed to process position record: {raw_position}, error={e}")
                        continue

                # Batch insert or update
                if processed_positions:
                    try:
                        # Use ON DUPLICATE KEY UPDATE for the batch upsert
                        upsert_sql = text("""
                            INSERT INTO deh_strategy_position_new
                            (st_id, k_id, asset, symbol, side, price, `sum`,
                             asset_num, asset_profit, leverage, uptime,
                             profit_price, stop_price, liquidation_price)
                            VALUES
                            (:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
                             :asset_num, :asset_profit, :leverage, :uptime,
                             :profit_price, :stop_price, :liquidation_price)
                            ON DUPLICATE KEY UPDATE
                                price = VALUES(price),
                                `sum` = VALUES(`sum`),
                                asset_num = VALUES(asset_num),
                                asset_profit = VALUES(asset_profit),
                                leverage = VALUES(leverage),
                                uptime = VALUES(uptime),
                                profit_price = VALUES(profit_price),
                                stop_price = VALUES(stop_price),
                                liquidation_price = VALUES(liquidation_price)
                        """)

                        result = session.execute(upsert_sql, processed_positions)

                        total_affected = result.rowcount       # total affected rows
                        batch_size = len(processed_positions)  # rows attempted in this batch

                        # Accumulate into the overall results
                        results['total'] += batch_size
                        results['affected'] += total_affected

                        logger.debug(f"Group {group_idx//10 + 1}: "
                                     f"processed {batch_size} rows, "
                                     f"{total_affected} affected")

                    except Exception as e:
                        logger.exception(f"Batch insert/update failed: {e}")
                        session.rollback()
                        results['errors'].append(f"Batch insert/update failed: {str(e)}")
                        # Move on to the next group
                        continue

                # Delete positions that no longer exist for each account in this group
                for k_id in group_account_ids:
                    try:
                        if k_id not in account_position_keys or not account_position_keys[k_id]:
                            # The account holds no positions at all: delete everything it has
                            delete_sql = text("""
                                DELETE FROM deh_strategy_position_new
                                WHERE k_id = :k_id
                            """)

                            result = session.execute(delete_sql, {'k_id': k_id})
                            deleted_count = result.rowcount
                            results['deleted'] += deleted_count

                            if deleted_count > 0:
                                logger.debug(f"Account {k_id}: deleted all stale positions, {deleted_count} rows")

                        else:
                            # Build conditions describing the positions currently held
                            current_keys = account_position_keys[k_id]

                            # Use multiple OR conditions to avoid IN-clause limitations
                            conditions = []
                            params = {'k_id': k_id}

                            for idx, key in enumerate(current_keys):
                                parts = key.split('&')
                                if len(parts) >= 3:  # expect st_id, symbol and side
                                    st_id_val = parts[0]
                                    symbol_val = parts[1]
                                    side_val = parts[2]

                                    conditions.append(f"(st_id = :st_id_{idx} AND symbol = :symbol_{idx} AND side = :side_{idx})")
                                    params[f'st_id_{idx}'] = int(st_id_val) if st_id_val.isdigit() else st_id_val
                                    params[f'symbol_{idx}'] = symbol_val
                                    params[f'side_{idx}'] = side_val

                            if conditions:
                                conditions_str = " OR ".join(conditions)

                                # Delete rows not present in the current position set
                                delete_sql = text(f"""
                                    DELETE FROM deh_strategy_position_new
                                    WHERE k_id = :k_id
                                      AND NOT ({conditions_str})
                                """)

                                result = session.execute(delete_sql, params)
                                deleted_count = result.rowcount
                                results['deleted'] += deleted_count

                                if deleted_count > 0:
                                    logger.debug(f"Account {k_id}: deleted {deleted_count} stale positions")

                    except Exception as e:
                        logger.error(f"Failed to delete stale positions for account {k_id}: {e}")
                        # Record the error but keep processing other accounts
                        results['errors'].append(f"Failed to delete stale positions for account {k_id}: {str(e)}")

                # Commit after each group
                try:
                    session.commit()
                    logger.debug(f"Group {group_idx//10 + 1} processed and committed")
                except Exception as e:
                    session.rollback()
                    logger.error(f"Commit for group {group_idx//10 + 1} failed: {e}")
                    results['errors'].append(f"Commit for group {group_idx//10 + 1} failed: {str(e)}")

            logger.info(f"Batch sync finished: "
                        f"total={results['total']}, "
                        f"affected={results['affected']}, "
                        f"deleted={results['deleted']}, "
                        f"errors={len(results['errors'])}")

            success = len(results['errors']) == 0
            return success, results

        except Exception as e:
            session.rollback()
            logger.exception(f"Error during batch sync: {e}")
            results['errors'].append(f"Sync error: {str(e)}")
            return False, results

        finally:
            session.close()
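    # Note on the `affected` counters (documented MySQL behaviour): for
    # INSERT ... ON DUPLICATE KEY UPDATE, the affected-row count reported by
    # the driver is 1 per newly inserted row, 2 per existing row updated to new
    # values, and 0 when the incoming values already match the stored row
    # (under default client flags). `affected` can therefore legitimately
    # exceed `total` both here and in the V3 variant below; it is not a
    # one-to-one row count.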

    async def _sync_positions_batch_to_db_optimized_v3(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """
        Most optimised batch sync (compatible with all MySQL versions).

        Strategy:
        1. UPSERT all position data in a single statement.
        2. Build a virtual table with UNION ALL and delete stale rows via a JOIN.

        Args:
            all_positions: list of all positions

        Returns:
            Tuple[bool, Dict]: (success flag, result statistics)
        """
        if not all_positions:
            return True, {'total': 0, 'affected': 0, 'deleted': 0, 'errors': []}

        session = self.db_manager.get_session()

        results = {
            'total': 0,
            'affected': 0,
            'deleted': 0,
            'errors': []
        }

        try:
            session.begin()

            # Prepare the data
            processed_positions = []
            current_position_records = set()  # use a set to deduplicate records

            for raw_position in all_positions:
                try:
                    processed = self._convert_position_data(raw_position)

                    if not all([processed.get('symbol'), processed.get('side')]):
                        continue

                    if 'qty' in processed:
                        processed['sum'] = processed.pop('qty')

                    k_id = processed.get('k_id', raw_position['k_id'])
                    st_id = processed.get('st_id', raw_position.get('st_id', 0))
                    symbol = processed.get('symbol')
                    side = processed.get('side')

                    processed_positions.append(processed)

                    # Record the current position (deduplicated)
                    record_key = (k_id, st_id, symbol, side)
                    current_position_records.add(record_key)

                except Exception as e:
                    logger.error(f"Failed to process position record: {raw_position}, error={e}")
                    continue

            if not processed_positions:
                session.commit()
                return True, results

            results['total'] = len(processed_positions)
            logger.info(f"Preparing to sync {results['total']} position rows, "
                        f"{len(current_position_records)} unique records after deduplication")

            # Batch UPSERT
            upsert_sql = text("""
                INSERT INTO deh_strategy_position_new
                (st_id, k_id, asset, symbol, side, price, `sum`,
                 asset_num, asset_profit, leverage, uptime,
                 profit_price, stop_price, liquidation_price)
                VALUES
                (:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
                 :asset_num, :asset_profit, :leverage, :uptime,
                 :profit_price, :stop_price, :liquidation_price)
                ON DUPLICATE KEY UPDATE
                    price = VALUES(price),
                    `sum` = VALUES(`sum`),
                    asset_num = VALUES(asset_num),
                    asset_profit = VALUES(asset_profit),
                    leverage = VALUES(leverage),
                    uptime = VALUES(uptime),
                    profit_price = VALUES(profit_price),
                    stop_price = VALUES(stop_price),
                    liquidation_price = VALUES(liquidation_price)
            """)

            result = session.execute(upsert_sql, processed_positions)

            total_affected = result.rowcount
            results['affected'] = total_affected

            logger.info(f"UPSERT finished: {results['total']} rows processed, {results['affected']} affected")

            # Batch delete (build a virtual table with UNION ALL)
            if current_position_records:
                # Build the UNION ALL query
                union_parts = []
                for record in current_position_records:
                    k_id, st_id, symbol, side = record
                    # Escape single quotes
                    symbol_escaped = symbol.replace("'", "''")
                    side_escaped = side.replace("'", "''")
                    union_parts.append(f"SELECT {k_id} as k_id, {st_id} as st_id, '{symbol_escaped}' as symbol, '{side_escaped}' as side")

                if union_parts:
                    union_sql = " UNION ALL ".join(union_parts)

                    # Delete via a LEFT JOIN against the virtual table
                    delete_sql_join = text(f"""
                        DELETE p FROM deh_strategy_position_new p
                        LEFT JOIN (
                            {union_sql}
                        ) AS current_pos ON
                            p.k_id = current_pos.k_id
                            AND p.st_id = current_pos.st_id
                            AND p.symbol = current_pos.symbol
                            AND p.side = current_pos.side
                        WHERE current_pos.k_id IS NULL
                    """)

                    result = session.execute(delete_sql_join)
                    deleted_count = result.rowcount
                    results['deleted'] = deleted_count

                    logger.info(f"Deleted {deleted_count} stale positions")

            session.commit()

            logger.info(f"Batch sync V3 finished: total={results['total']}, "
                        f"affected={results['affected']}, "
                        f"deleted={results['deleted']}")

            return True, results

        except Exception as e:
            session.rollback()
            logger.exception(f"Batch sync V3 failed: {e}")
            results['errors'].append(f"Sync failed: {str(e)}")
            return False, results

        finally:
            session.close()
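    # Illustration of the delete statement built above (the values are
    # hypothetical, shown only to make the UNION ALL virtual-table shape
    # concrete):
    #
    #   DELETE p FROM deh_strategy_position_new p
    #   LEFT JOIN (
    #       SELECT 101 as k_id, 5 as st_id, 'BTCUSDT' as symbol, 'long' as side
    #       UNION ALL
    #       SELECT 101 as k_id, 5 as st_id, 'ETHUSDT' as symbol, 'short' as side
    #   ) AS current_pos ON
    #       p.k_id = current_pos.k_id AND p.st_id = current_pos.st_id
    #       AND p.symbol = current_pos.symbol AND p.side = current_pos.side
    #   WHERE current_pos.k_id IS NULL
    #
    # Every row of the target table with no match in the in-memory set of
    # current positions is removed in a single round trip.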

    async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect position data for all accounts."""
        all_positions = []

        try:
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)

            # Collect each exchange's data concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._collect_exchange_positions(exchange_id, account_list)
                tasks.append(task)

            # Wait for all tasks to finish and merge the results
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for result in results:
                if isinstance(result, list):
                    all_positions.extend(result)

        except Exception as e:
            logger.error(f"Failed to collect position data: {e}")

        return all_positions

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange."""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups

    async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect position data for a single exchange."""
        positions_list = []

        try:
            tasks = []
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                task = self._get_positions_from_redis(k_id, st_id, exchange_id)
                tasks.append(task)

            # Fetch concurrently
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for result in results:
                if isinstance(result, list):
                    positions_list.extend(result)

        except Exception as e:
            logger.error(f"Failed to collect positions for exchange {exchange_id}: {e}")

        return positions_list

    async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Fetch position data from Redis."""
        try:
            redis_key = f"{exchange_id}:positions:{k_id}"
            redis_data = self.redis_client.client.hget(redis_key, 'positions')

            if not redis_data:
                return []

            positions = json.loads(redis_data)

            # Attach account information
            for position in positions:
                position['k_id'] = k_id
                position['st_id'] = st_id
                position['exchange_id'] = exchange_id

            return positions

        except Exception as e:
            logger.error(f"Failed to fetch positions from Redis: k_id={k_id}, error={e}")
            return []
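    # Assumed Redis layout read by _get_positions_from_redis (illustrative only;
    # the exact payload fields depend on the upstream writer):
    #
    #   HGET "binance:positions:101" "positions"
    #   -> '[{"asset": "USDT", "symbol": "BTCUSDT", "side": "long",
    #         "price": "65000", "qty": "0.5", "leverage": 10, ...}]'
    #
    # i.e. one hash per exchange/account pair whose "positions" field holds the
    # account's open positions as a JSON array; k_id, st_id and exchange_id are
    # attached locally after parsing.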

    def _convert_position_data(self, data: Dict) -> Dict:
        """Convert raw position data into the table's column format."""
        try:
            return {
                'st_id': helpers.safe_int(data.get('st_id'), 0),
                'k_id': helpers.safe_int(data.get('k_id'), 0),
                'asset': data.get('asset', 'USDT'),
                'symbol': data.get('symbol', ''),
                'side': data.get('side', ''),
                'price': helpers.safe_float(data.get('price')),
                'qty': helpers.safe_float(data.get('qty')),  # renamed to `sum` later
                'asset_num': helpers.safe_float(data.get('asset_num')),
                'asset_profit': helpers.safe_float(data.get('asset_profit')),
                'leverage': helpers.safe_int(data.get('leverage')),
                'uptime': helpers.safe_int(data.get('uptime')),
                'profit_price': helpers.safe_float(data.get('profit_price')),
                'stop_price': helpers.safe_float(data.get('stop_price')),
                'liquidation_price': helpers.safe_float(data.get('liquidation_price'))
            }

        except Exception as e:
            logger.error(f"Failed to convert position data: {data}, error={e}")
            return {}

    async def sync(self):
        """Compatibility wrapper for the old interface."""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)
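
# Minimal usage sketch (illustrative only). It assumes BaseSync wires up
# self.db_manager and self.redis_client from the project's configuration and
# that get_accounts_from_redis() returns the account map sync_batch() expects;
# the account values below are hypothetical. Run as a module (python -m ...)
# so the relative import at the top of the file resolves.
if __name__ == "__main__":
    async def _demo():
        syncer = PositionSyncBatch()
        # Legacy entry point: pull the account map from Redis, then sync.
        await syncer.sync()
        # Or pass an explicit account map keyed by account id:
        # accounts = {"101": {"k_id": 101, "st_id": 5, "exchange_id": "binance"}}
        # await syncer.sync_batch(accounts)

    asyncio.run(_demo())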