lz_db
2025-12-03 23:53:07 +08:00
parent 6958445ecf
commit f93f334256
12 changed files with 1036 additions and 2542 deletions

View File: account_sync.py

@@ -1,48 +1,99 @@
 from .base_sync import BaseSync
 from loguru import logger
-from typing import List, Dict
+from typing import List, Dict, Any, Set
 import json
+import asyncio
 import time
 from datetime import datetime, timedelta
+from sqlalchemy import text, and_
+from models.orm_models import StrategyKX
 
-class AccountSync(BaseSync):
-    """Synchronizer for account information"""
+class AccountSyncBatch(BaseSync):
+    """Batch synchronizer for account information"""
 
-    async def sync(self):
-        """Sync account-information data"""
+    async def sync_batch(self, accounts: Dict[str, Dict]):
+        """Batch-sync account information for every account"""
         try:
-            # Fetch all accounts
-            accounts = self.get_accounts_from_redis()
-            for k_id_str, account_info in accounts.items():
-                try:
-                    k_id = int(k_id_str)
-                    st_id = account_info.get('st_id', 0)
-                    exchange_id = account_info['exchange_id']
-                    if k_id <= 0 or st_id <= 0:
-                        continue
-                    # Fetch account info from Redis
-                    account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id)
-                    # Sync to the database
-                    if account_data:
-                        success = self._sync_account_info_to_db(account_data)
-                        if success:
-                            logger.debug(f"Account info synced: k_id={k_id}")
-                except Exception as e:
-                    logger.error(f"Failed to sync account info for account {k_id_str}: {e}")
-                    continue
-            logger.info("Account-info sync finished")
+            logger.info(f"Starting batch account-info sync for {len(accounts)} accounts")
+            # Collect data for every account
+            all_account_data = await self._collect_all_account_data(accounts)
+            if not all_account_data:
+                logger.info("No account-info data to sync")
+                return
+            # Batch-sync to the database
+            success = await self._sync_account_info_batch_to_db(all_account_data)
+            if success:
+                logger.info(f"Batch account-info sync done: {len(all_account_data)} records processed")
+            else:
+                logger.error("Batch account-info sync failed")
         except Exception as e:
-            logger.error(f"Account-info sync failed: {e}")
+            logger.error(f"Batch account-info sync failed: {e}")
 
+    async def _collect_all_account_data(self, accounts: Dict[str, Dict]) -> List[Dict]:
+        """Collect account-info data for every account"""
+        all_account_data = []
+        try:
+            # Group accounts by exchange
+            account_groups = self._group_accounts_by_exchange(accounts)
+            # Collect each exchange's data concurrently
+            tasks = []
+            for exchange_id, account_list in account_groups.items():
+                task = self._collect_exchange_account_data(exchange_id, account_list)
+                tasks.append(task)
+            # Wait for all tasks and merge the results
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+            for result in results:
+                if isinstance(result, list):
+                    all_account_data.extend(result)
+            logger.info(f"Collected {len(all_account_data)} account-info records")
+        except Exception as e:
+            logger.error(f"Failed to collect account-info data: {e}")
+        return all_account_data
+
+    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
+        """Group accounts by exchange"""
+        groups = {}
+        for account_id, account_info in accounts.items():
+            exchange_id = account_info.get('exchange_id')
+            if exchange_id:
+                if exchange_id not in groups:
+                    groups[exchange_id] = []
+                groups[exchange_id].append(account_info)
+        return groups
+
+    async def _collect_exchange_account_data(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
+        """Collect account-info data for a single exchange"""
+        account_data_list = []
+        try:
+            for account_info in account_list:
+                k_id = int(account_info['k_id'])
+                st_id = account_info.get('st_id', 0)
+                # Fetch account info from Redis
+                account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id)
+                account_data_list.extend(account_data)
+            logger.debug(f"Exchange {exchange_id}: collected {len(account_data_list)} account-info records")
+        except Exception as e:
+            logger.error(f"Failed to collect account info for exchange {exchange_id}: {e}")
+        return account_data_list
+
     async def _get_account_info_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
-        """Fetch account-info data from Redis"""
+        """Fetch account-info data from Redis (batch-optimized version)"""
         try:
             redis_key = f"{exchange_id}:balance:{k_id}"
             redis_funds = self.redis_client.client.hgetall(redis_key)
@@ -97,7 +148,9 @@ class AccountSync(BaseSync):
             # Convert into account-info records
             account_data_list = []
             sorted_dates = sorted(date_stats.keys())
-            prev_balance = 0.0
+            # Previous-day balances, used when computing profit
+            prev_balance_map = self._get_previous_balances(redis_funds, sorted_dates)
 
             for date_str in sorted_dates:
                 stats = date_stats[date_str]
@@ -111,6 +164,7 @@ class AccountSync(BaseSync):
                 withdrawal = stats['withdrawal']
                 # Compute profit
+                prev_balance = prev_balance_map.get(date_str, 0.0)
                 profit = balance - deposit - withdrawal - prev_balance
                 # Convert the timestamp
@@ -130,54 +184,183 @@ class AccountSync(BaseSync):
                 account_data_list.append(account_data)
-                # Track the previous day's balance
-                if stats['has_balance']:
-                    prev_balance = balance
             return account_data_list
         except Exception as e:
             logger.error(f"Failed to fetch account info from Redis: k_id={k_id}, error={e}")
             return []
 
-    def _sync_account_info_to_db(self, account_data_list: List[Dict]) -> bool:
-        """Sync account information to the database"""
+    def _get_previous_balances(self, redis_funds: Dict, sorted_dates: List[str]) -> Dict[str, float]:
+        """Look up the previous day's balance for each date"""
+        prev_balance_map = {}
+        prev_date = None
+        for date_str in sorted_dates:
+            # Find the previous day's balance
+            if prev_date:
+                for fund_key, fund_json in redis_funds.items():
+                    try:
+                        fund_data = json.loads(fund_json)
+                        if (fund_data.get('lz_time') == prev_date and
+                                fund_data.get('lz_type') == 'lz_balance'):
+                            prev_balance_map[date_str] = float(fund_data.get('lz_amount', 0))
+                            break
+                    except (json.JSONDecodeError, ValueError):
+                        continue
+            else:
+                prev_balance_map[date_str] = 0.0
+            prev_date = date_str
+        return prev_balance_map
+
+    async def _sync_account_info_batch_to_db(self, account_data_list: List[Dict]) -> bool:
+        """Batch-sync account information to the database (fastest path)"""
         session = self.db_manager.get_session()
         try:
+            if not account_data_list:
+                return True
             with session.begin():
-                for account_data in account_data_list:
-                    try:
-                        # Check whether the row already exists
-                        existing = session.execute(
-                            select(StrategyKX).where(
-                                and_(
-                                    StrategyKX.k_id == account_data['k_id'],
-                                    StrategyKX.st_id == account_data['st_id'],
-                                    StrategyKX.time == account_data['time']
-                                )
-                            )
-                        ).scalar_one_or_none()
-                        if existing:
-                            # Update
-                            existing.balance = account_data['balance']
-                            existing.withdrawal = account_data['withdrawal']
-                            existing.deposit = account_data['deposit']
-                            existing.other = account_data['other']
-                            existing.profit = account_data['profit']
-                        else:
-                            # Insert
-                            new_account = StrategyKX(**account_data)
-                            session.add(new_account)
-                    except Exception as e:
-                        logger.error(f"Failed to process account record: {account_data}, error={e}")
-                        continue
-            return True
+                # Approach 1: raw-SQL bulk insert/update (fastest)
+                success = self._batch_upsert_account_info(session, account_data_list)
+                if not success:
+                    # Approach 2: fall back to ORM bulk operations
+                    success = self._batch_orm_upsert_account_info(session, account_data_list)
+            return success
         except Exception as e:
-            logger.error(f"Failed to sync account info to the database: error={e}")
+            logger.error(f"Failed to batch-sync account info to the database: {e}")
             return False
         finally:
             session.close()
 
+    def _batch_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool:
+        """Bulk insert/update account info with raw SQL"""
+        try:
+            # Build the batch values
+            values_list = []
+            for data in account_data_list:
+                values = (
+                    f"({data['st_id']}, {data['k_id']}, 'USDT', "
+                    f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, "
+                    f"{data['other']}, {data['profit']}, {data['time']})"
+                )
+                values_list.append(values)
+            if not values_list:
+                return True
+            values_str = ", ".join(values_list)
+            # INSERT ... ON DUPLICATE KEY UPDATE
+            sql = f"""
+                INSERT INTO deh_strategy_kx_new
+                (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
+                VALUES {values_str}
+                ON DUPLICATE KEY UPDATE
+                    balance = VALUES(balance),
+                    withdrawal = VALUES(withdrawal),
+                    deposit = VALUES(deposit),
+                    other = VALUES(other),
+                    profit = VALUES(profit),
+                    up_time = NOW()
+            """
+            session.execute(text(sql))
+            logger.info(f"Raw-SQL bulk upsert of account info: {len(account_data_list)} records")
+            return True
+        except Exception as e:
+            logger.error(f"Raw-SQL bulk upsert of account info failed: {e}")
+            return False
+
+    def _batch_orm_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool:
+        """Bulk insert/update account info through the ORM"""
+        try:
+            # Key the data to deduplicate and speed up lookups
+            account_data_by_key = {}
+            for data in account_data_list:
+                key = (data['k_id'], data['st_id'], data['time'])
+                account_data_by_key[key] = data
+            # Bulk-query the existing rows
+            existing_records = self._batch_query_existing_records(session, list(account_data_by_key.keys()))
+            # Update or insert in bulk
+            to_insert = []
+            for key, data in account_data_by_key.items():
+                if key in existing_records:
+                    # Update
+                    record = existing_records[key]
+                    record.balance = data['balance']
+                    record.withdrawal = data['withdrawal']
+                    record.deposit = data['deposit']
+                    record.other = data['other']
+                    record.profit = data['profit']
+                else:
+                    # Insert
+                    to_insert.append(StrategyKX(**data))
+            # Bulk-insert the new rows
+            if to_insert:
+                session.add_all(to_insert)
+            logger.info(f"ORM bulk upsert of account info: {len(existing_records)} updated, {len(to_insert)} inserted")
+            return True
+        except Exception as e:
+            logger.error(f"ORM bulk upsert of account info failed: {e}")
+            return False
+
+    def _batch_query_existing_records(self, session, keys: List[tuple]) -> Dict[tuple, StrategyKX]:
+        """Bulk-query the existing rows"""
+        existing_records = {}
+        try:
+            if not keys:
+                return existing_records
+            # Build the query condition
+            conditions = []
+            for k_id, st_id, time_val in keys:
+                conditions.append(f"(k_id = {k_id} AND st_id = {st_id} AND time = {time_val})")
+            if conditions:
+                conditions_str = " OR ".join(conditions)
+                sql = f"""
+                    SELECT * FROM deh_strategy_kx_new
+                    WHERE {conditions_str}
+                """
+                results = session.execute(text(sql)).fetchall()
+                for row in results:
+                    key = (row.k_id, row.st_id, row.time)
+                    existing_records[key] = StrategyKX(
+                        id=row.id,
+                        st_id=row.st_id,
+                        k_id=row.k_id,
+                        asset=row.asset,
+                        balance=row.balance,
+                        withdrawal=row.withdrawal,
+                        deposit=row.deposit,
+                        other=row.other,
+                        profit=row.profit,
+                        time=row.time
+                    )
+        except Exception as e:
+            logger.error(f"Bulk query of existing rows failed: {e}")
+        return existing_records
+
+    async def sync(self):
+        """Backward-compatible entry point"""
+        accounts = self.get_accounts_from_redis()
+        await self.sync_batch(accounts)
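A note on the raw-SQL path above: `_batch_upsert_account_info` builds its VALUES list with f-strings, which only stays safe because every field is numeric. SQLAlchemy's `text()` also accepts a list of parameter dicts and expands it into one batched executemany, keeping the same ON DUPLICATE KEY UPDATE semantics without any string building. A minimal sketch, assuming the same deh_strategy_kx_new table and an ordinary SQLAlchemy session (not part of the commit):

    from typing import Dict, List
    from sqlalchemy import text

    UPSERT_SQL = text("""
        INSERT INTO deh_strategy_kx_new
            (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
        VALUES
            (:st_id, :k_id, 'USDT', :balance, :withdrawal, :deposit, :other, :profit, :time)
        ON DUPLICATE KEY UPDATE
            balance = VALUES(balance),
            withdrawal = VALUES(withdrawal),
            deposit = VALUES(deposit),
            other = VALUES(other),
            profit = VALUES(profit),
            up_time = NOW()
    """)

    def upsert_account_rows(session, rows: List[Dict]) -> int:
        """Executemany-style upsert: one round trip for the whole batch."""
        if not rows:
            return 0
        session.execute(UPSERT_SQL, rows)  # list of dicts -> batched execute
        return len(rows)

The same idea would also tidy `_batch_query_existing_records`, which currently interpolates its OR conditions the same way.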

View File: account_sync_batch.py (deleted)

@@ -1,366 +0,0 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Set
import json
import asyncio
import time
from datetime import datetime, timedelta
from sqlalchemy import text, and_
from models.orm_models import StrategyKX

class AccountSyncBatch(BaseSync):
    """Batch synchronizer for account information"""

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync account information for every account"""
        try:
            logger.info(f"Starting batch account-info sync for {len(accounts)} accounts")
            # Collect data for every account
            all_account_data = await self._collect_all_account_data(accounts)
            if not all_account_data:
                logger.info("No account-info data to sync")
                return
            # Batch-sync to the database
            success = await self._sync_account_info_batch_to_db(all_account_data)
            if success:
                logger.info(f"Batch account-info sync done: {len(all_account_data)} records processed")
            else:
                logger.error("Batch account-info sync failed")
        except Exception as e:
            logger.error(f"Batch account-info sync failed: {e}")

    async def _collect_all_account_data(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect account-info data for every account"""
        all_account_data = []
        try:
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)
            # Collect each exchange's data concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._collect_exchange_account_data(exchange_id, account_list)
                tasks.append(task)
            # Wait for all tasks and merge the results
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    all_account_data.extend(result)
            logger.info(f"Collected {len(all_account_data)} account-info records")
        except Exception as e:
            logger.error(f"Failed to collect account-info data: {e}")
        return all_account_data

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange"""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups

    async def _collect_exchange_account_data(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect account-info data for a single exchange"""
        account_data_list = []
        try:
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                # Fetch account info from Redis
                account_data = await self._get_account_info_from_redis(k_id, st_id, exchange_id)
                account_data_list.extend(account_data)
            logger.debug(f"Exchange {exchange_id}: collected {len(account_data_list)} account-info records")
        except Exception as e:
            logger.error(f"Failed to collect account info for exchange {exchange_id}: {e}")
        return account_data_list

    async def _get_account_info_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Fetch account-info data from Redis (batch-optimized version)"""
        try:
            redis_key = f"{exchange_id}:balance:{k_id}"
            redis_funds = self.redis_client.client.hgetall(redis_key)
            if not redis_funds:
                return []
            # Aggregate the data by day
            from config.settings import SYNC_CONFIG
            recent_days = SYNC_CONFIG['recent_days']
            today = datetime.now()
            date_stats = {}
            # Collect data for every date
            for fund_key, fund_json in redis_funds.items():
                try:
                    fund_data = json.loads(fund_json)
                    date_str = fund_data.get('lz_time', '')
                    lz_type = fund_data.get('lz_type', '')
                    if not date_str or lz_type not in ['lz_balance', 'deposit', 'withdrawal']:
                        continue
                    # Only process the most recent N days
                    date_obj = datetime.strptime(date_str, '%Y-%m-%d')
                    if (today - date_obj).days > recent_days:
                        continue
                    if date_str not in date_stats:
                        date_stats[date_str] = {
                            'balance': 0.0,
                            'deposit': 0.0,
                            'withdrawal': 0.0,
                            'has_balance': False
                        }
                    lz_amount = float(fund_data.get('lz_amount', 0))
                    if lz_type == 'lz_balance':
                        date_stats[date_str]['balance'] = lz_amount
                        date_stats[date_str]['has_balance'] = True
                    elif lz_type == 'deposit':
                        date_stats[date_str]['deposit'] += lz_amount
                    elif lz_type == 'withdrawal':
                        date_stats[date_str]['withdrawal'] += lz_amount
                except (json.JSONDecodeError, ValueError) as e:
                    logger.debug(f"Failed to parse Redis entry: {fund_key}, error={e}")
                    continue
            # Convert into account-info records
            account_data_list = []
            sorted_dates = sorted(date_stats.keys())
            # Previous-day balances, used when computing profit
            prev_balance_map = self._get_previous_balances(redis_funds, sorted_dates)
            for date_str in sorted_dates:
                stats = date_stats[date_str]
                # Days with deposits/withdrawals but no balance are still processed;
                # skip only days with no data at all
                if not stats['has_balance'] and stats['deposit'] == 0 and stats['withdrawal'] == 0:
                    continue
                balance = stats['balance']
                deposit = stats['deposit']
                withdrawal = stats['withdrawal']
                # Compute profit
                prev_balance = prev_balance_map.get(date_str, 0.0)
                profit = balance - deposit - withdrawal - prev_balance
                # Convert the timestamp
                date_obj = datetime.strptime(date_str, '%Y-%m-%d')
                time_timestamp = int(date_obj.timestamp())
                account_data = {
                    'st_id': st_id,
                    'k_id': k_id,
                    'balance': balance,
                    'withdrawal': withdrawal,
                    'deposit': deposit,
                    'other': 0.0,  # always 0 for now
                    'profit': profit,
                    'time': time_timestamp
                }
                account_data_list.append(account_data)
            return account_data_list
        except Exception as e:
            logger.error(f"Failed to fetch account info from Redis: k_id={k_id}, error={e}")
            return []

    def _get_previous_balances(self, redis_funds: Dict, sorted_dates: List[str]) -> Dict[str, float]:
        """Look up the previous day's balance for each date"""
        prev_balance_map = {}
        prev_date = None
        for date_str in sorted_dates:
            # Find the previous day's balance
            if prev_date:
                for fund_key, fund_json in redis_funds.items():
                    try:
                        fund_data = json.loads(fund_json)
                        if (fund_data.get('lz_time') == prev_date and
                                fund_data.get('lz_type') == 'lz_balance'):
                            prev_balance_map[date_str] = float(fund_data.get('lz_amount', 0))
                            break
                    except (json.JSONDecodeError, ValueError):
                        continue
            else:
                prev_balance_map[date_str] = 0.0
            prev_date = date_str
        return prev_balance_map

    async def _sync_account_info_batch_to_db(self, account_data_list: List[Dict]) -> bool:
        """Batch-sync account information to the database (fastest path)"""
        session = self.db_manager.get_session()
        try:
            if not account_data_list:
                return True
            with session.begin():
                # Approach 1: raw-SQL bulk insert/update (fastest)
                success = self._batch_upsert_account_info(session, account_data_list)
                if not success:
                    # Approach 2: fall back to ORM bulk operations
                    success = self._batch_orm_upsert_account_info(session, account_data_list)
            return success
        except Exception as e:
            logger.error(f"Failed to batch-sync account info to the database: {e}")
            return False
        finally:
            session.close()

    def _batch_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool:
        """Bulk insert/update account info with raw SQL"""
        try:
            # Build the batch values
            values_list = []
            for data in account_data_list:
                values = (
                    f"({data['st_id']}, {data['k_id']}, 'USDT', "
                    f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, "
                    f"{data['other']}, {data['profit']}, {data['time']})"
                )
                values_list.append(values)
            if not values_list:
                return True
            values_str = ", ".join(values_list)
            # INSERT ... ON DUPLICATE KEY UPDATE
            sql = f"""
                INSERT INTO deh_strategy_kx_new
                (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
                VALUES {values_str}
                ON DUPLICATE KEY UPDATE
                    balance = VALUES(balance),
                    withdrawal = VALUES(withdrawal),
                    deposit = VALUES(deposit),
                    other = VALUES(other),
                    profit = VALUES(profit),
                    up_time = NOW()
            """
            session.execute(text(sql))
            logger.info(f"Raw-SQL bulk upsert of account info: {len(account_data_list)} records")
            return True
        except Exception as e:
            logger.error(f"Raw-SQL bulk upsert of account info failed: {e}")
            return False

    def _batch_orm_upsert_account_info(self, session, account_data_list: List[Dict]) -> bool:
        """Bulk insert/update account info through the ORM"""
        try:
            # Key the data to deduplicate and speed up lookups
            account_data_by_key = {}
            for data in account_data_list:
                key = (data['k_id'], data['st_id'], data['time'])
                account_data_by_key[key] = data
            # Bulk-query the existing rows
            existing_records = self._batch_query_existing_records(session, list(account_data_by_key.keys()))
            # Update or insert in bulk
            to_insert = []
            for key, data in account_data_by_key.items():
                if key in existing_records:
                    # Update
                    record = existing_records[key]
                    record.balance = data['balance']
                    record.withdrawal = data['withdrawal']
                    record.deposit = data['deposit']
                    record.other = data['other']
                    record.profit = data['profit']
                else:
                    # Insert
                    to_insert.append(StrategyKX(**data))
            # Bulk-insert the new rows
            if to_insert:
                session.add_all(to_insert)
            logger.info(f"ORM bulk upsert of account info: {len(existing_records)} updated, {len(to_insert)} inserted")
            return True
        except Exception as e:
            logger.error(f"ORM bulk upsert of account info failed: {e}")
            return False

    def _batch_query_existing_records(self, session, keys: List[tuple]) -> Dict[tuple, StrategyKX]:
        """Bulk-query the existing rows"""
        existing_records = {}
        try:
            if not keys:
                return existing_records
            # Build the query condition
            conditions = []
            for k_id, st_id, time_val in keys:
                conditions.append(f"(k_id = {k_id} AND st_id = {st_id} AND time = {time_val})")
            if conditions:
                conditions_str = " OR ".join(conditions)
                sql = f"""
                    SELECT * FROM deh_strategy_kx_new
                    WHERE {conditions_str}
                """
                results = session.execute(text(sql)).fetchall()
                for row in results:
                    key = (row.k_id, row.st_id, row.time)
                    existing_records[key] = StrategyKX(
                        id=row.id,
                        st_id=row.st_id,
                        k_id=row.k_id,
                        asset=row.asset,
                        balance=row.balance,
                        withdrawal=row.withdrawal,
                        deposit=row.deposit,
                        other=row.other,
                        profit=row.profit,
                        time=row.time
                    )
        except Exception as e:
            logger.error(f"Bulk query of existing rows failed: {e}")
        return existing_records

    async def sync(self):
        """Backward-compatible entry point"""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)

View File: base_sync.py

@@ -44,211 +44,6 @@ class BaseSync(ABC):
"""批量同步数据""" """批量同步数据"""
pass pass
def get_accounts_from_redis(self) -> Dict[str, Dict]:
"""从Redis获取所有计算机名的账号配置"""
try:
accounts_dict = {}
total_keys_processed = 0
# 方法1使用配置的计算机名列表
for computer_name in self.computer_names:
accounts = self._get_accounts_by_computer_name(computer_name)
total_keys_processed += 1
accounts_dict.update(accounts)
# 方法2如果配置的计算机名没有数据尝试自动发现备用方案
if not accounts_dict:
logger.warning("配置的计算机名未找到数据,尝试自动发现...")
accounts_dict = self._discover_all_accounts()
self.sync_stats['total_accounts'] = len(accounts_dict)
logger.info(f"{len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
return accounts_dict
except Exception as e:
logger.error(f"获取账户信息失败: {e}")
return {}
def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
"""获取指定计算机名的账号"""
accounts_dict = {}
try:
# 构建key
redis_key = f"{computer_name}_strategy_api"
# 从Redis获取数据
result = self.redis_client.client.hgetall(redis_key)
if not result:
logger.debug(f"未找到 {redis_key} 的策略API配置")
return {}
logger.info(f"{redis_key} 获取到 {len(result)} 个交易所配置")
for exchange_name, accounts_json in result.items():
try:
accounts = json.loads(accounts_json)
if not accounts:
continue
# 格式化交易所ID
exchange_id = self.format_exchange_id(exchange_name)
for account_id, account_info in accounts.items():
parsed_account = self.parse_account(exchange_id, account_id, account_info)
if parsed_account:
# 添加计算机名标记
parsed_account['computer_name'] = computer_name
accounts_dict[account_id] = parsed_account
except json.JSONDecodeError as e:
logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
continue
except Exception as e:
logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")
continue
logger.info(f"{redis_key} 解析到 {len(accounts_dict)} 个账号")
except Exception as e:
logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
return accounts_dict
def _discover_all_accounts(self) -> Dict[str, Dict]:
"""自动发现所有匹配的账号key"""
accounts_dict = {}
discovered_keys = []
try:
# 获取所有匹配模式的key
pattern = "*_strategy_api"
cursor = 0
while True:
cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
for key in keys:
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
discovered_keys.append(key_str)
if cursor == 0:
break
logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")
# 处理每个发现的key
for key_str in discovered_keys:
# 提取计算机名
computer_name = key_str.replace('_strategy_api', '')
# 验证计算机名格式
if self.computer_name_pattern.match(computer_name):
accounts = self._get_accounts_by_computer_name(computer_name)
accounts_dict.update(accounts)
else:
logger.warning(f"跳过不符合格式的计算机名: {computer_name}")
logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
except Exception as e:
logger.error(f"自动发现账号失败: {e}")
return accounts_dict
def format_exchange_id(self, key: str) -> str:
"""格式化交易所ID"""
key = key.lower().strip()
# 交易所名称映射
exchange_mapping = {
'metatrader': 'mt5',
'binance_spot_test': 'binance',
'binance_spot': 'binance',
'binance': 'binance',
'gate_spot': 'gate',
'okex': 'okx',
'okx': 'okx',
'bybit': 'bybit',
'bybit_spot': 'bybit',
'bybit_test': 'bybit',
'huobi': 'huobi',
'huobi_spot': 'huobi',
'gate': 'gate',
'gateio': 'gate',
'kucoin': 'kucoin',
'kucoin_spot': 'kucoin',
'mexc': 'mexc',
'mexc_spot': 'mexc',
'bitget': 'bitget',
'bitget_spot': 'bitget'
}
normalized_key = exchange_mapping.get(key, key)
# 记录未映射的交易所
if normalized_key == key and key not in exchange_mapping.values():
logger.debug(f"未映射的交易所名称: {key}")
return normalized_key
def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]:
"""解析账号信息"""
try:
source_account_info = json.loads(account_info)
# 基础信息
account_data = {
'exchange_id': exchange_id,
'k_id': account_id,
'st_id': self._safe_int(source_account_info.get('st_id'), 0),
'add_time': self._safe_int(source_account_info.get('add_time'), 0),
'account_type': source_account_info.get('account_type', 'real'),
'api_key': source_account_info.get('api_key', ''),
'secret_key': source_account_info.get('secret_key', ''),
'password': source_account_info.get('password', ''),
'access_token': source_account_info.get('access_token', ''),
'remark': source_account_info.get('remark', '')
}
# MT5特殊处理
if exchange_id == 'mt5':
# 解析服务器地址和端口
server_info = source_account_info.get('secret_key', '')
if ':' in server_info:
host, port = server_info.split(':', 1)
account_data['mt5_host'] = host
account_data['mt5_port'] = self._safe_int(port, 0)
# 合并原始信息
result = {**source_account_info, **account_data}
# 验证必要字段
if not result.get('st_id') or not result.get('exchange_id'):
logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
return None
return result
except json.JSONDecodeError as e:
logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...")
return None
except Exception as e:
logger.error(f"处理账号 {account_id} 数据异常: {e}")
return None
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
"""按交易所分组账号"""
groups = {}
for account_id, account_info in accounts.items():
exchange_id = account_info.get('exchange_id')
if exchange_id:
if exchange_id not in groups:
groups[exchange_id] = []
groups[exchange_id].append(account_info)
return groups
def _safe_float(self, value: Any, default: float = 0.0) -> float: def _safe_float(self, value: Any, default: float = 0.0) -> float:
"""安全转换为float""" """安全转换为float"""
if value is None: if value is None:

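For reference, the Redis layout these removed helpers assumed: one hash per computer name, keyed `{computer}_strategy_api`, with one field per exchange whose value is a JSON object mapping account IDs to account configs. A minimal sketch that seeds and reads such a hash with redis-py; the key and field names come from the code above, while the host name and sample values are made up:

    import json
    import redis

    r = redis.Redis(host="localhost", port=6379, decode_responses=True)

    # One field per exchange; the value is a JSON map of account_id -> account config.
    accounts = {
        "1001": {"st_id": 7, "api_key": "k", "secret_key": "s", "account_type": "real"},
    }
    r.hset("HOST01_strategy_api", "binance_spot", json.dumps(accounts))

    # _get_accounts_by_computer_name("HOST01") would then iterate it like this:
    for exchange_name, accounts_json in r.hgetall("HOST01_strategy_api").items():
        print(exchange_name, json.loads(accounts_json))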
View File: sync_manager.py

@@ -3,26 +3,30 @@ from loguru import logger
 import signal
 import sys
 import time
+import json
 from typing import Dict
+import re
+from utils.redis_client import RedisClient
 from config.settings import SYNC_CONFIG
-from .position_sync_batch import PositionSyncBatch
-from .order_sync_batch import OrderSyncBatch  # use the batch version
-from .account_sync_batch import AccountSyncBatch
+from .position_sync import PositionSyncBatch
+from .order_sync import OrderSyncBatch  # use the batch version
+from .account_sync import AccountSyncBatch
-from utils.batch_position_sync import BatchPositionSync
-from utils.batch_order_sync import BatchOrderSync
-from utils.batch_account_sync import BatchAccountSync
 from utils.redis_batch_helper import RedisBatchHelper
+from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN
+from typing import List, Dict, Any, Set, Optional
 
 class SyncManager:
     """Sync manager (full batch version)"""
 
     def __init__(self):
         self.is_running = True
+        self.redis_client = RedisClient()
         self.sync_interval = SYNC_CONFIG['interval']
+        self.computer_names = self._get_computer_names()
+        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
         # Initialize the batch-sync tools
-        self.batch_tools = {}
         self.redis_helper = None
         # Initialize the synchronizers
@@ -31,24 +35,16 @@ class SyncManager:
         if SYNC_CONFIG['enable_position_sync']:
             position_sync = PositionSyncBatch()
             self.syncers.append(position_sync)
-            self.batch_tools['position'] = BatchPositionSync(position_sync.db_manager)
             logger.info("Position batch sync enabled")
         if SYNC_CONFIG['enable_order_sync']:
             order_sync = OrderSyncBatch()
             self.syncers.append(order_sync)
-            self.batch_tools['order'] = BatchOrderSync(order_sync.db_manager)
-            # Initialize the Redis batch helper
-            if order_sync.redis_client:
-                self.redis_helper = RedisBatchHelper(order_sync.redis_client.client)
             logger.info("Order batch sync enabled")
         if SYNC_CONFIG['enable_account_sync']:
             account_sync = AccountSyncBatch()
             self.syncers.append(account_sync)
-            self.batch_tools['account'] = BatchAccountSync(account_sync.db_manager)
             logger.info("Account-info batch sync enabled")
         # Performance stats
@@ -71,21 +67,25 @@ class SyncManager:
         while self.is_running:
             try:
-                self.stats['total_syncs'] += 1
-                sync_start = time.time()
                 # Fetch all accounts (only once per cycle)
-                accounts = await self._get_all_accounts()
+                accounts = self.get_accounts_from_redis()
                 if not accounts:
                     logger.warning("No accounts found; waiting for the next sync")
                     await asyncio.sleep(self.sync_interval)
                     continue
+                self.stats['total_syncs'] += 1
+                sync_start = time.time()
                 logger.info(f"Sync #{self.stats['total_syncs']} starting, {len(accounts)} accounts")
-                # Run all syncs concurrently
-                await self._execute_all_syncers_concurrent(accounts)
+                # Run every synchronizer
+                tasks = [syncer.sync_batch(accounts) for syncer in self.syncers]
+                await asyncio.gather(*tasks, return_exceptions=True)
                 # Update stats
                 sync_time = time.time() - sync_start
@@ -101,136 +101,254 @@ class SyncManager:
logger.error(f"同步任务异常: {e}") logger.error(f"同步任务异常: {e}")
await asyncio.sleep(30) await asyncio.sleep(30)
async def _get_all_accounts(self) -> Dict[str, Dict]: def get_accounts_from_redis(self) -> Dict[str, Dict]:
"""获取所有账号""" """从Redis获取所有计算机名的账号配置"""
if not self.syncers: try:
accounts_dict = {}
total_keys_processed = 0
# 方法1使用配置的计算机名列表
for computer_name in self.computer_names:
accounts = self._get_accounts_by_computer_name(computer_name)
total_keys_processed += 1
accounts_dict.update(accounts)
# 方法2如果配置的计算机名没有数据尝试自动发现备用方案
if not accounts_dict:
logger.warning("配置的计算机名未找到数据,尝试自动发现...")
accounts_dict = self._discover_all_accounts()
self.sync_stats['total_accounts'] = len(accounts_dict)
logger.info(f"{len(self.computer_names)} 个计算机名获取到 {len(accounts_dict)} 个账号")
return accounts_dict
except Exception as e:
logger.error(f"获取账户信息失败: {e}")
return {} return {}
# 使用第一个同步器获取账号 def _get_computer_names(self) -> List[str]:
return self.syncers[0].get_accounts_from_redis() """获取计算机名列表"""
if ',' in COMPUTER_NAMES:
names = [name.strip() for name in COMPUTER_NAMES.split(',')]
logger.info(f"使用配置的计算机名列表: {names}")
return names
return [COMPUTER_NAMES.strip()]
async def _execute_all_syncers_concurrent(self, accounts: Dict[str, Dict]): def _get_accounts_by_computer_name(self, computer_name: str) -> Dict[str, Dict]:
"""并发执行所有同步器""" """获取指定计算机名的账号"""
tasks = [] accounts_dict = {}
# 持仓批量同步
if 'position' in self.batch_tools:
task = self._sync_positions_batch(accounts)
tasks.append(task)
# 订单批量同步
if 'order' in self.batch_tools:
task = self._sync_orders_batch(accounts)
tasks.append(task)
# 账户信息批量同步
if 'account' in self.batch_tools:
task = self._sync_accounts_batch(accounts)
tasks.append(task)
# 并发执行所有任务
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
# 检查结果
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"同步任务 {i} 失败: {result}")
async def _sync_positions_batch(self, accounts: Dict[str, Dict]):
"""批量同步持仓数据"""
try: try:
start_time = time.time() # 构建key
redis_key = f"{computer_name}_strategy_api"
# 收集所有持仓数据 # 从Redis获取数据
position_sync = next((s for s in self.syncers if isinstance(s, PositionSyncBatch)), None) result = self.redis_client.client.hgetall(redis_key)
if not position_sync: if not result:
return logger.debug(f"未找到 {redis_key} 的策略API配置")
return {}
all_positions = await position_sync._collect_all_positions(accounts) logger.info(f"{redis_key} 获取到 {len(result)} 个交易所配置")
if not all_positions: for exchange_name, accounts_json in result.items():
self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0} try:
return accounts = json.loads(accounts_json)
if not accounts:
continue
# 使用批量工具同步 # 格式化交易所ID
batch_tool = self.batch_tools['position'] exchange_id = self.format_exchange_id(exchange_name)
success, stats = batch_tool.sync_positions_batch(all_positions)
if success: for account_id, account_info in accounts.items():
elapsed = time.time() - start_time parsed_account = self.parse_account(exchange_id, account_id, account_info)
self.stats['position'] = { if parsed_account:
'accounts': len(accounts), # 添加计算机名标记
'positions': stats['total'], parsed_account['computer_name'] = computer_name
'time': elapsed accounts_dict[account_id] = parsed_account
}
except json.JSONDecodeError as e:
logger.error(f"解析交易所 {exchange_name} 的JSON数据失败: {e}")
continue
except Exception as e:
logger.error(f"处理交易所 {exchange_name} 数据异常: {e}")
continue
logger.info(f"{redis_key} 解析到 {len(accounts_dict)} 个账号")
except Exception as e: except Exception as e:
logger.error(f"批量同步持仓失败: {e}") logger.error(f"获取计算机名 {computer_name} 的账号失败: {e}")
self.stats['position'] = {'accounts': 0, 'positions': 0, 'time': 0}
return accounts_dict
def _discover_all_accounts(self) -> Dict[str, Dict]:
"""自动发现所有匹配的账号key"""
accounts_dict = {}
discovered_keys = []
async def _sync_orders_batch(self, accounts: Dict[str, Dict]):
"""批量同步订单数据"""
try: try:
start_time = time.time() # 获取所有匹配模式的key
pattern = "*_strategy_api"
cursor = 0
# 收集所有订单数据 while True:
order_sync = next((s for s in self.syncers if isinstance(s, OrderSyncBatch)), None) cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
if not order_sync:
return
all_orders = await order_sync._collect_all_orders(accounts) for key in keys:
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
discovered_keys.append(key_str)
if not all_orders: if cursor == 0:
self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0} break
return
# 使用批量工具同步 logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")
batch_tool = self.batch_tools['order']
success, processed_count = batch_tool.sync_orders_batch(all_orders)
if success: # 处理每个发现的key
elapsed = time.time() - start_time for key_str in discovered_keys:
self.stats['order'] = { # 提取计算机名
'accounts': len(accounts), computer_name = key_str.replace('_strategy_api', '')
'orders': processed_count,
'time': elapsed # 验证计算机名格式
} if self.computer_name_pattern.match(computer_name):
accounts = self._get_accounts_by_computer_name(computer_name)
accounts_dict.update(accounts)
else:
logger.warning(f"跳过不符合格式的计算机名: {computer_name}")
logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
except Exception as e: except Exception as e:
logger.error(f"批量同步订单失败: {e}") logger.error(f"自动发现账号失败: {e}")
self.stats['order'] = {'accounts': 0, 'orders': 0, 'time': 0}
return accounts_dict
def _discover_all_accounts(self) -> Dict[str, Dict]:
"""自动发现所有匹配的账号key"""
accounts_dict = {}
discovered_keys = []
async def _sync_accounts_batch(self, accounts: Dict[str, Dict]):
"""批量同步账户信息数据"""
try: try:
start_time = time.time() # 获取所有匹配模式的key
pattern = "*_strategy_api"
cursor = 0
# 收集所有账户数据 while True:
account_sync = next((s for s in self.syncers if isinstance(s, AccountSyncBatch)), None) cursor, keys = self.redis_client.client.scan(cursor, match=pattern, count=100)
if not account_sync:
return
all_account_data = await account_sync._collect_all_account_data(accounts) for key in keys:
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
discovered_keys.append(key_str)
if not all_account_data: if cursor == 0:
self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0} break
return
# 使用批量工具同步 logger.info(f"自动发现 {len(discovered_keys)} 个策略API key")
batch_tool = self.batch_tools['account']
updated, inserted = batch_tool.sync_accounts_batch(all_account_data)
elapsed = time.time() - start_time # 处理每个发现的key
self.stats['account'] = { for key_str in discovered_keys:
'accounts': len(accounts), # 提取计算机名
'records': len(all_account_data), computer_name = key_str.replace('_strategy_api', '')
'time': elapsed
# 验证计算机名格式
if self.computer_name_pattern.match(computer_name):
accounts = self._get_accounts_by_computer_name(computer_name)
accounts_dict.update(accounts)
else:
logger.warning(f"跳过不符合格式的计算机名: {computer_name}")
logger.info(f"自动发现共获取到 {len(accounts_dict)} 个账号")
except Exception as e:
logger.error(f"自动发现账号失败: {e}")
return accounts_dict
def format_exchange_id(self, key: str) -> str:
"""格式化交易所ID"""
key = key.lower().strip()
# 交易所名称映射
exchange_mapping = {
'metatrader': 'mt5',
'binance_spot_test': 'binance',
'binance_spot': 'binance',
'binance': 'binance',
'gate_spot': 'gate',
'okex': 'okx',
'okx': 'okx',
'bybit': 'bybit',
'bybit_spot': 'bybit',
'bybit_test': 'bybit',
'huobi': 'huobi',
'huobi_spot': 'huobi',
'gate': 'gate',
'gateio': 'gate',
'kucoin': 'kucoin',
'kucoin_spot': 'kucoin',
'mexc': 'mexc',
'mexc_spot': 'mexc',
'bitget': 'bitget',
'bitget_spot': 'bitget'
}
normalized_key = exchange_mapping.get(key, key)
# 记录未映射的交易所
if normalized_key == key and key not in exchange_mapping.values():
logger.debug(f"未映射的交易所名称: {key}")
return normalized_key
def parse_account(self, exchange_id: str, account_id: str, account_info: str) -> Optional[Dict]:
"""解析账号信息"""
try:
source_account_info = json.loads(account_info)
# 基础信息
account_data = {
'exchange_id': exchange_id,
'k_id': account_id,
'st_id': self._safe_int(source_account_info.get('st_id'), 0),
'add_time': self._safe_int(source_account_info.get('add_time'), 0),
'account_type': source_account_info.get('account_type', 'real'),
'api_key': source_account_info.get('api_key', ''),
'secret_key': source_account_info.get('secret_key', ''),
'password': source_account_info.get('password', ''),
'access_token': source_account_info.get('access_token', ''),
'remark': source_account_info.get('remark', '')
} }
# 合并原始信息
result = {**source_account_info, **account_data}
# 验证必要字段
if not result.get('st_id') or not result.get('exchange_id'):
logger.warning(f"账号 {account_id} 缺少必要字段: st_id={result.get('st_id')}, exchange_id={result.get('exchange_id')}")
return None
return result
except json.JSONDecodeError as e:
logger.error(f"解析账号 {account_id} JSON数据失败: {e}, 原始数据: {account_info[:100]}...")
return None
except Exception as e: except Exception as e:
logger.error(f"批量同步账户信息失败: {e}") logger.error(f"处理账号 {account_id} 数据异常: {e}")
self.stats['account'] = {'accounts': 0, 'records': 0, 'time': 0} return None
def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
"""按交易所分组账号"""
groups = {}
for account_id, account_info in accounts.items():
exchange_id = account_info.get('exchange_id')
if exchange_id:
if exchange_id not in groups:
groups[exchange_id] = []
groups[exchange_id].append(account_info)
return groups
def _update_stats(self, sync_time: float): def _update_stats(self, sync_time: float):
"""更新统计信息""" """更新统计信息"""

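The auto-discovery fallback above drives SCAN by hand with a cursor loop; redis-py's scan_iter wraps the same protocol. A minimal standalone sketch of the discovery step, assuming the same `*_strategy_api` naming convention (the regex is a placeholder for COMPUTER_NAME_PATTERN, which is not shown in this diff):

    import re
    import redis

    def discover_computer_names(r: redis.Redis, pattern: str = "*_strategy_api") -> list:
        """Return the computer names embedded in matching Redis keys."""
        name_re = re.compile(r"^[A-Za-z0-9_-]+$")  # placeholder for COMPUTER_NAME_PATTERN
        names = []
        for key in r.scan_iter(match=pattern, count=100):
            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            computer_name = key_str.replace("_strategy_api", "")
            if name_re.match(computer_name):
                names.append(computer_name)
        return names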
View File: order_sync.py

@@ -1,166 +1,269 @@
 from .base_sync import BaseSync
 from loguru import logger
-from typing import List, Dict
+from typing import List, Dict, Any, Tuple
 import json
+import asyncio
 import time
 from datetime import datetime, timedelta
+from sqlalchemy import text
+import redis
 
-class OrderSync(BaseSync):
-    """Order-data synchronizer"""
+class OrderSyncBatch(BaseSync):
+    """Batch synchronizer for order data"""
 
-    async def sync(self):
-        """Sync order data"""
-        try:
-            # Fetch all accounts
-            accounts = self.get_accounts_from_redis()
-            for k_id_str, account_info in accounts.items():
-                try:
-                    k_id = int(k_id_str)
-                    st_id = account_info.get('st_id', 0)
-                    exchange_id = account_info['exchange_id']
-                    if k_id <= 0 or st_id <= 0:
-                        continue
-                    # Fetch the most recent N days of orders from Redis
-                    orders = await self._get_recent_orders_from_redis(k_id, exchange_id)
-                    # Sync to the database
-                    if orders:
-                        success = self._sync_orders_to_db(k_id, st_id, orders)
-                        if success:
-                            logger.debug(f"Orders synced: k_id={k_id}, count={len(orders)}")
-                except Exception as e:
-                    logger.error(f"Failed to sync orders for account {k_id_str}: {e}")
-                    continue
-            logger.info("Order-data sync finished")
-        except Exception as e:
-            logger.error(f"Order sync failed: {e}")
+    def __init__(self):
+        super().__init__()
+        self.batch_size = 1000   # rows per batch
+        self.recent_days = 3     # how many recent days to sync
+
+    async def sync_batch(self, accounts: Dict[str, Dict]):
+        """Batch-sync order data for every account"""
+        try:
+            logger.info(f"Starting batch order sync for {len(accounts)} accounts")
+            start_time = time.time()
+            # 1. Collect every account's orders
+            all_orders = await self._collect_all_orders(accounts)
+            if not all_orders:
+                logger.info("No order data to sync")
+                return
+            logger.info(f"Collected {len(all_orders)} orders")
+            # 2. Batch-sync to the database
+            success, processed_count = await self._sync_orders_batch_to_db(all_orders)
+            elapsed = time.time() - start_time
+            if success:
+                logger.info(f"Batch order sync done: {processed_count} orders in {elapsed:.2f}s")
+            else:
+                logger.error("Batch order sync failed")
+        except Exception as e:
+            logger.error(f"Batch order sync failed: {e}")
+
+    async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]:
+        """Collect order data for every account"""
+        all_orders = []
+        try:
+            # Group accounts by exchange
+            account_groups = self._group_accounts_by_exchange(accounts)
+            # Collect each exchange's data concurrently
+            tasks = []
+            for exchange_id, account_list in account_groups.items():
+                task = self._collect_exchange_orders(exchange_id, account_list)
+                tasks.append(task)
+            # Wait for all tasks and merge the results
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+            for result in results:
+                if isinstance(result, list):
+                    all_orders.extend(result)
+        except Exception as e:
+            logger.error(f"Failed to collect order data: {e}")
+        return all_orders
+
+    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
+        """Group accounts by exchange"""
+        groups = {}
+        for account_id, account_info in accounts.items():
+            exchange_id = account_info.get('exchange_id')
+            if exchange_id:
+                if exchange_id not in groups:
+                    groups[exchange_id] = []
+                groups[exchange_id].append(account_info)
+        return groups
+
+    async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
+        """Collect order data for a single exchange"""
+        orders_list = []
+        try:
+            # Fetch each account's data concurrently
+            tasks = []
+            for account_info in account_list:
+                k_id = int(account_info['k_id'])
+                st_id = account_info.get('st_id', 0)
+                task = self._get_recent_orders_from_redis(k_id, st_id, exchange_id)
+                tasks.append(task)
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+            for result in results:
+                if isinstance(result, list):
+                    orders_list.extend(result)
+            logger.debug(f"Exchange {exchange_id}: collected {len(orders_list)} orders")
+        except Exception as e:
+            logger.error(f"Failed to collect orders for exchange {exchange_id}: {e}")
+        return orders_list
+
+    async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
         """Fetch the most recent N days of orders from Redis"""
         try:
             redis_key = f"{exchange_id}:orders:{k_id}"
             # Compute the most recent N days of dates
-            from config.settings import SYNC_CONFIG
-            recent_days = SYNC_CONFIG['recent_days']
             today = datetime.now()
             recent_dates = []
-            for i in range(recent_days):
+            for i in range(self.recent_days):
                 date = today - timedelta(days=i)
                 date_format = date.strftime('%Y-%m-%d')
                 recent_dates.append(date_format)
-            # Fetch every key
-            all_keys = self.redis_client.client.hkeys(redis_key)
-            orders_list = []
-            for key in all_keys:
-                key_str = key.decode('utf-8') if isinstance(key, bytes) else key
-                if key_str == 'positions':
-                    continue
-                # Does the key start with one of the recent dates?
-                for date_format in recent_dates:
-                    if key_str.startswith(date_format + '_'):
-                        try:
-                            order_json = self.redis_client.client.hget(redis_key, key_str)
-                            if order_json:
-                                order = json.loads(order_json)
-                                # Validate the timestamp
-                                order_time = order.get('time', 0)
-                                if order_time >= int(time.time()) - recent_days * 24 * 3600:
-                                    orders_list.append(order)
-                            break
-                        except:
-                            break
+            # Use hscan to find all matching keys
+            cursor = 0
+            recent_keys = []
+            while True:
+                cursor, keys = self.redis_client.client.hscan(redis_key, cursor, count=1000)
+                for key, _ in keys.items():
+                    key_str = key.decode('utf-8') if isinstance(key, bytes) else key
+                    if key_str == 'positions':
+                        continue
+                    # Does the key start with one of the recent dates?
+                    for date_format in recent_dates:
+                        if key_str.startswith(date_format + '_'):
+                            recent_keys.append(key_str)
+                            break
+                if cursor == 0:
+                    break
+            if not recent_keys:
+                return []
+            # Fetch the order payloads in bulk
+            orders_list = []
+            # Chunked fetches keep any single read from getting too large
+            chunk_size = 500
+            for i in range(0, len(recent_keys), chunk_size):
+                chunk_keys = recent_keys[i:i + chunk_size]
+                # Bulk fetch with hmget
+                chunk_values = self.redis_client.client.hmget(redis_key, chunk_keys)
+                for key, order_json in zip(chunk_keys, chunk_values):
+                    if not order_json:
+                        continue
+                    try:
+                        order = json.loads(order_json)
+                        # Validate the timestamp
+                        order_time = order.get('time', 0)
+                        if order_time >= int(time.time()) - self.recent_days * 24 * 3600:
+                            # Attach the account info
+                            order['k_id'] = k_id
+                            order['st_id'] = st_id
+                            order['exchange_id'] = exchange_id
+                            orders_list.append(order)
+                    except json.JSONDecodeError as e:
+                        logger.debug(f"Failed to parse order JSON: key={key}, error={e}")
+                        continue
             return orders_list
         except Exception as e:
             logger.error(f"Failed to fetch orders from Redis: k_id={k_id}, error={e}")
             return []
 
-    def _sync_orders_to_db(self, k_id: int, st_id: int, orders_data: List[Dict]) -> bool:
-        """Sync order data to the database"""
-        session = self.db_manager.get_session()
-        try:
-            # Build the batch payload
-            insert_data = []
-            for order_data in orders_data:
-                try:
-                    order_dict = self._convert_order_data(order_data)
-                    # Completeness check
-                    required_fields = ['order_id', 'symbol', 'side', 'time']
-                    if not all(order_dict.get(field) for field in required_fields):
-                        continue
-                    insert_data.append(order_dict)
-                except Exception as e:
-                    logger.error(f"Failed to convert order: {order_data}, error={e}")
-                    continue
-            if not insert_data:
-                return True
-            with session.begin():
-                # Parameterized bulk insert
-                sql = """
-                    INSERT INTO deh_strategy_order_new
-                    (st_id, k_id, asset, order_id, symbol, side, price, time,
-                     order_qty, last_qty, avg_price, exchange_id)
-                    VALUES
-                    (:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time,
-                     :order_qty, :last_qty, :avg_price, :exchange_id)
-                    ON DUPLICATE KEY UPDATE
-                        side = VALUES(side),
-                        price = VALUES(price),
-                        time = VALUES(time),
-                        order_qty = VALUES(order_qty),
-                        last_qty = VALUES(last_qty),
-                        avg_price = VALUES(avg_price)
-                """
-                # Execute in chunks
-                from config.settings import SYNC_CONFIG
-                chunk_size = SYNC_CONFIG['chunk_size']
-                for i in range(0, len(insert_data), chunk_size):
-                    chunk = insert_data[i:i + chunk_size]
-                    session.execute(text(sql), chunk)
-            return True
-        except Exception as e:
-            logger.error(f"Failed to sync orders to the database: k_id={k_id}, error={e}")
-            return False
-        finally:
-            session.close()
+    async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]:
+        """Batch-sync order data to the database"""
+        try:
+            if not all_orders:
+                return True, 0
+            # Convert the data
+            converted_orders = []
+            for order in all_orders:
+                try:
+                    order_dict = self._convert_order_data(order)
+                    # Completeness check
+                    required_fields = ['order_id', 'symbol', 'side', 'time']
+                    if not all(order_dict.get(field) for field in required_fields):
+                        continue
+                    converted_orders.append(order_dict)
+                except Exception as e:
+                    logger.error(f"Failed to convert order: {order}, error={e}")
+                    continue
+            if not converted_orders:
+                return True, 0
+            # Sync through the batch tool
+            from utils.batch_order_sync import BatchOrderSync
+            batch_tool = BatchOrderSync(self.db_manager, self.batch_size)
+            success, processed_count = batch_tool.sync_orders_batch(converted_orders)
+            return success, processed_count
+        except Exception as e:
+            logger.error(f"Failed to batch-sync orders to the database: {e}")
+            return False, 0
 
     def _convert_order_data(self, data: Dict) -> Dict:
         """Convert an order into the row format"""
-        return {
-            'st_id': int(data.get('st_id', 0)),
-            'k_id': int(data.get('k_id', 0)),
-            'asset': 'USDT',
-            'order_id': str(data.get('order_id', '')),
-            'symbol': data.get('symbol', ''),
-            'side': data.get('side', ''),
-            'price': float(data.get('price', 0)) if data.get('price') is not None else None,
-            'time': int(data.get('time', 0)) if data.get('time') is not None else None,
-            'order_qty': float(data.get('order_qty', 0)) if data.get('order_qty') is not None else None,
-            'last_qty': float(data.get('last_qty', 0)) if data.get('last_qty') is not None else None,
-            'avg_price': float(data.get('avg_price', 0)) if data.get('avg_price') is not None else None,
-            'exchange_id': None  # field intentionally ignored
-        }
+        try:
+            # Safe converters
+            def safe_float(value, default=None):
+                if value is None:
+                    return default
+                try:
+                    return float(value)
+                except (ValueError, TypeError):
+                    return default
+
+            def safe_int(value, default=None):
+                if value is None:
+                    return default
+                try:
+                    return int(float(value))
+                except (ValueError, TypeError):
+                    return default
+
+            def safe_str(value):
+                if value is None:
+                    return ''
+                return str(value)
+
+            return {
+                'st_id': safe_int(data.get('st_id'), 0),
+                'k_id': safe_int(data.get('k_id'), 0),
+                'asset': 'USDT',
+                'order_id': safe_str(data.get('order_id')),
+                'symbol': safe_str(data.get('symbol')),
+                'side': safe_str(data.get('side')),
+                'price': safe_float(data.get('price')),
+                'time': safe_int(data.get('time')),
+                'order_qty': safe_float(data.get('order_qty')),
+                'last_qty': safe_float(data.get('last_qty')),
+                'avg_price': safe_float(data.get('avg_price')),
+                'exchange_id': None  # field intentionally ignored
+            }
+        except Exception as e:
+            logger.error(f"Error while converting order data: {data}, error={e}")
+            return {}
+
+    async def sync(self):
+        """Backward-compatible entry point"""
+        accounts = self.get_accounts_from_redis()
+        await self.sync_batch(accounts)

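The rewritten reader above replaces a per-key HGET loop with HSCAN plus chunked HMGET, bounding the number of round trips. The same access pattern as a standalone helper, for clarity; the 500-key chunk size mirrors the code above:

    import redis

    def hmget_chunked(r: redis.Redis, name: str, fields: list, chunk_size: int = 500) -> dict:
        """Fetch many hash fields with a bounded number of round trips."""
        out = {}
        for i in range(0, len(fields), chunk_size):
            chunk = fields[i:i + chunk_size]
            values = r.hmget(name, chunk)  # one round trip per chunk
            out.update(dict(zip(chunk, values)))
        return out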
View File: order_sync_batch.py (deleted)

@@ -1,269 +0,0 @@
from .base_sync import BaseSync
from loguru import logger
from typing import List, Dict, Any, Tuple
import json
import asyncio
import time
from datetime import datetime, timedelta
from sqlalchemy import text
import redis

class OrderSyncBatch(BaseSync):
    """Batch synchronizer for order data"""

    def __init__(self):
        super().__init__()
        self.batch_size = 1000   # rows per batch
        self.recent_days = 3     # how many recent days to sync

    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Batch-sync order data for every account"""
        try:
            logger.info(f"Starting batch order sync for {len(accounts)} accounts")
            start_time = time.time()
            # 1. Collect every account's orders
            all_orders = await self._collect_all_orders(accounts)
            if not all_orders:
                logger.info("No order data to sync")
                return
            logger.info(f"Collected {len(all_orders)} orders")
            # 2. Batch-sync to the database
            success, processed_count = await self._sync_orders_batch_to_db(all_orders)
            elapsed = time.time() - start_time
            if success:
                logger.info(f"Batch order sync done: {processed_count} orders in {elapsed:.2f}s")
            else:
                logger.error("Batch order sync failed")
        except Exception as e:
            logger.error(f"Batch order sync failed: {e}")

    async def _collect_all_orders(self, accounts: Dict[str, Dict]) -> List[Dict]:
        """Collect order data for every account"""
        all_orders = []
        try:
            # Group accounts by exchange
            account_groups = self._group_accounts_by_exchange(accounts)
            # Collect each exchange's data concurrently
            tasks = []
            for exchange_id, account_list in account_groups.items():
                task = self._collect_exchange_orders(exchange_id, account_list)
                tasks.append(task)
            # Wait for all tasks and merge the results
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    all_orders.extend(result)
        except Exception as e:
            logger.error(f"Failed to collect order data: {e}")
        return all_orders

    def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
        """Group accounts by exchange"""
        groups = {}
        for account_id, account_info in accounts.items():
            exchange_id = account_info.get('exchange_id')
            if exchange_id:
                if exchange_id not in groups:
                    groups[exchange_id] = []
                groups[exchange_id].append(account_info)
        return groups

    async def _collect_exchange_orders(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
        """Collect order data for a single exchange"""
        orders_list = []
        try:
            # Fetch each account's data concurrently
            tasks = []
            for account_info in account_list:
                k_id = int(account_info['k_id'])
                st_id = account_info.get('st_id', 0)
                task = self._get_recent_orders_from_redis(k_id, st_id, exchange_id)
                tasks.append(task)
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, list):
                    orders_list.extend(result)
            logger.debug(f"Exchange {exchange_id}: collected {len(orders_list)} orders")
        except Exception as e:
            logger.error(f"Failed to collect orders for exchange {exchange_id}: {e}")
        return orders_list

    async def _get_recent_orders_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
        """Fetch the most recent N days of orders from Redis"""
        try:
            redis_key = f"{exchange_id}:orders:{k_id}"
            # Compute the most recent N days of dates
            today = datetime.now()
            recent_dates = []
            for i in range(self.recent_days):
                date = today - timedelta(days=i)
                date_format = date.strftime('%Y-%m-%d')
                recent_dates.append(date_format)
            # Use hscan to find all matching keys
            cursor = 0
            recent_keys = []
            while True:
                cursor, keys = self.redis_client.client.hscan(redis_key, cursor, count=1000)
                for key, _ in keys.items():
                    key_str = key.decode('utf-8') if isinstance(key, bytes) else key
                    if key_str == 'positions':
                        continue
                    # Does the key start with one of the recent dates?
                    for date_format in recent_dates:
                        if key_str.startswith(date_format + '_'):
                            recent_keys.append(key_str)
                            break
                if cursor == 0:
                    break
            if not recent_keys:
                return []
            # Fetch the order payloads in bulk
            orders_list = []
            # Chunked fetches keep any single read from getting too large
            chunk_size = 500
            for i in range(0, len(recent_keys), chunk_size):
                chunk_keys = recent_keys[i:i + chunk_size]
                # Bulk fetch with hmget
                chunk_values = self.redis_client.client.hmget(redis_key, chunk_keys)
                for key, order_json in zip(chunk_keys, chunk_values):
                    if not order_json:
                        continue
                    try:
                        order = json.loads(order_json)
                        # Validate the timestamp
                        order_time = order.get('time', 0)
                        if order_time >= int(time.time()) - self.recent_days * 24 * 3600:
                            # Attach the account info
                            order['k_id'] = k_id
                            order['st_id'] = st_id
                            order['exchange_id'] = exchange_id
                            orders_list.append(order)
                    except json.JSONDecodeError as e:
                        logger.debug(f"Failed to parse order JSON: key={key}, error={e}")
                        continue
            return orders_list
        except Exception as e:
            logger.error(f"Failed to fetch orders from Redis: k_id={k_id}, error={e}")
            return []

    async def _sync_orders_batch_to_db(self, all_orders: List[Dict]) -> Tuple[bool, int]:
        """Batch-sync order data to the database"""
        try:
            if not all_orders:
                return True, 0
            # Convert the data
            converted_orders = []
            for order in all_orders:
                try:
                    order_dict = self._convert_order_data(order)
                    # Completeness check
                    required_fields = ['order_id', 'symbol', 'side', 'time']
                    if not all(order_dict.get(field) for field in required_fields):
                        continue
                    converted_orders.append(order_dict)
                except Exception as e:
                    logger.error(f"Failed to convert order: {order}, error={e}")
                    continue
            if not converted_orders:
                return True, 0
            # Sync through the batch tool
            from utils.batch_order_sync import BatchOrderSync
            batch_tool = BatchOrderSync(self.db_manager, self.batch_size)
            success, processed_count = batch_tool.sync_orders_batch(converted_orders)
            return success, processed_count
        except Exception as e:
            logger.error(f"Failed to batch-sync orders to the database: {e}")
            return False, 0

    def _convert_order_data(self, data: Dict) -> Dict:
        """Convert an order into the row format"""
        try:
            # Safe converters
            def safe_float(value, default=None):
                if value is None:
                    return default
                try:
                    return float(value)
                except (ValueError, TypeError):
                    return default

            def safe_int(value, default=None):
                if value is None:
                    return default
                try:
                    return int(float(value))
                except (ValueError, TypeError):
                    return default

            def safe_str(value):
                if value is None:
                    return ''
                return str(value)

            return {
                'st_id': safe_int(data.get('st_id'), 0),
                'k_id': safe_int(data.get('k_id'), 0),
                'asset': 'USDT',
                'order_id': safe_str(data.get('order_id')),
                'symbol': safe_str(data.get('symbol')),
                'side': safe_str(data.get('side')),
                'price': safe_float(data.get('price')),
                'time': safe_int(data.get('time')),
                'order_qty': safe_float(data.get('order_qty')),
                'last_qty': safe_float(data.get('last_qty')),
                'avg_price': safe_float(data.get('avg_price')),
                'exchange_id': None  # field intentionally ignored
            }
        except Exception as e:
            logger.error(f"Error while converting order data: {data}, error={e}")
            return {}

    async def sync(self):
        """Backward-compatible entry point"""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)

View File: position_sync.py

@@ -1,41 +1,74 @@
 from .base_sync import BaseSync
 from loguru import logger
-from typing import List, Dict
+from typing import List, Dict, Any, Set, Tuple
 import json
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from sqlalchemy import text, and_, select, delete
+from models.orm_models import StrategyPosition
+import time
 
-class PositionSync(BaseSync):
-    """Position-data synchronizer (batch version)"""
+class PositionSyncBatch(BaseSync):
+    """Batch synchronizer for position data"""
 
     def __init__(self):
         super().__init__()
-        self.max_concurrent = 10  # max concurrency per synchronizer
+        self.batch_size = 500  # rows per batch
 
     async def sync_batch(self, accounts: Dict[str, Dict]):
         """Batch-sync position data for every account"""
         try:
             logger.info(f"Starting batch position sync for {len(accounts)} accounts")
+            start_time = time.time()
-            # Group by account
-            account_groups = self._group_accounts_by_exchange(accounts)
-            # Process each exchange's accounts concurrently
-            tasks = []
-            for exchange_id, account_list in account_groups.items():
-                task = self._sync_exchange_accounts(exchange_id, account_list)
-                tasks.append(task)
-            # Wait for every task
-            results = await asyncio.gather(*tasks, return_exceptions=True)
-            # Tally the results
-            success_count = sum(1 for r in results if isinstance(r, bool) and r)
-            logger.info(f"Batch position sync done: {success_count}/{len(results)} exchange groups succeeded")
+            # 1. Collect every account's positions
+            all_positions = await self._collect_all_positions(accounts)
+            if not all_positions:
+                logger.info("No position data to sync")
+                return
+            logger.info(f"Collected {len(all_positions)} positions")
+            # 2. Batch-sync to the database
+            success, stats = await self._sync_positions_batch_to_db(all_positions)
+            elapsed = time.time() - start_time
+            if success:
+                logger.info(f"Batch position sync done: {stats['total']} processed, {stats['updated']} updated, "
+                            f"{stats['inserted']} inserted, {stats['deleted']} deleted, in {elapsed:.2f}s")
+            else:
+                logger.error("Batch position sync failed")
         except Exception as e:
             logger.error(f"Batch position sync failed: {e}")
 
+    async def _collect_all_positions(self, accounts: Dict[str, Dict]) -> List[Dict]:
+        """Collect position data for every account"""
+        all_positions = []
+        try:
+            # Group accounts by exchange
+            account_groups = self._group_accounts_by_exchange(accounts)
+            # Collect each exchange's data concurrently
+            tasks = []
+            for exchange_id, account_list in account_groups.items():
+                task = self._collect_exchange_positions(exchange_id, account_list)
+                tasks.append(task)
+            # Wait for all tasks and merge the results
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+            for result in results:
+                if isinstance(result, list):
+                    all_positions.extend(result)
+        except Exception as e:
+            logger.error(f"Failed to collect position data: {e}")
+        return all_positions
+
     def _group_accounts_by_exchange(self, accounts: Dict[str, Dict]) -> Dict[str, List[Dict]]:
         """Group accounts by exchange"""
         groups = {}
@@ -47,46 +80,60 @@ class PositionSync(BaseSync):
                 groups[exchange_id].append(account_info)
         return groups
 
-    async def _sync_exchange_accounts(self, exchange_id: str, account_list: List[Dict]):
-        """Sync every account on one exchange"""
-        try:
-            # Collect every account's positions
-            all_positions = []
-            for account_info in account_list:
-                k_id = int(account_info['k_id'])
-                st_id = account_info.get('st_id', 0)
-                # Fetch positions from Redis
-                positions = await self._get_positions_from_redis(k_id, exchange_id)
-                if positions:
-                    # Attach the account info
-                    for position in positions:
-                        position['k_id'] = k_id
-                        position['st_id'] = st_id
-                    all_positions.extend(positions)
-            if not all_positions:
-                logger.debug(f"Exchange {exchange_id} has no position data")
-                return True
-            # Batch-sync to the database
-            success = self._sync_positions_batch_to_db(all_positions)
-            if success:
-                logger.info(f"Exchange {exchange_id} positions synced: {len(all_positions)} rows")
-            return success
-        except Exception as e:
-            logger.error(f"Failed to sync positions for exchange {exchange_id}: {e}")
-            return False
+    async def _collect_exchange_positions(self, exchange_id: str, account_list: List[Dict]) -> List[Dict]:
+        """Collect position data for a single exchange"""
+        positions_list = []
+        try:
+            tasks = []
+            for account_info in account_list:
+                k_id = int(account_info['k_id'])
+                st_id = account_info.get('st_id', 0)
+                task = self._get_positions_from_redis(k_id, st_id, exchange_id)
+                tasks.append(task)
+            # Fetch concurrently
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+            for result in results:
+                if isinstance(result, list):
+                    positions_list.extend(result)
+        except Exception as e:
+            logger.error(f"Failed to collect positions for exchange {exchange_id}: {e}")
+        return positions_list
 
-    def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> bool:
-        """Batch-sync position data to the database (optimized)"""
-        session = self.db_manager.get_session()
-        try:
-            # Group by k_id
+    async def _get_positions_from_redis(self, k_id: int, st_id: int, exchange_id: str) -> List[Dict]:
+        """Fetch position data from Redis"""
+        try:
+            redis_key = f"{exchange_id}:positions:{k_id}"
+            redis_data = self.redis_client.client.hget(redis_key, 'positions')
+            if not redis_data:
+                return []
+            positions = json.loads(redis_data)
+            # Attach the account info
+            for position in positions:
+                position['k_id'] = k_id
+                position['st_id'] = st_id
+                position['exchange_id'] = exchange_id
+            return positions
+        except Exception as e:
+            logger.error(f"Failed to fetch positions from Redis: k_id={k_id}, error={e}")
+            return []
 
+    async def _sync_positions_batch_to_db(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
+        """Batch-sync position data to the database"""
+        try:
+            if not all_positions:
+                return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
+            # Group by account
             positions_by_account = {}
             for position in all_positions:
                 k_id = position['k_id']
@@ -94,98 +141,239 @@ class PositionSync(BaseSync):
positions_by_account[k_id] = [] positions_by_account[k_id] = []
positions_by_account[k_id].append(position) positions_by_account[k_id].append(position)
success_count = 0 logger.info(f"开始批量处理 {len(positions_by_account)} 个账号的持仓数据")
with session.begin(): # 批量处理每个账号
for k_id, positions in positions_by_account.items(): total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
try:
st_id = positions[0]['st_id'] if positions else 0
# 准备数据 for k_id, positions in positions_by_account.items():
insert_data = [] st_id = positions[0]['st_id'] if positions else 0
keep_keys = set()
for pos_data in positions: # 处理单个账号的批量同步
try: success, stats = await self._sync_single_account_batch(k_id, st_id, positions)
pos_dict = self._convert_position_data(pos_data)
if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
continue
# 重命名qty为sum if success:
if 'qty' in pos_dict: total_stats['total'] += stats['total']
pos_dict['sum'] = pos_dict.pop('qty') total_stats['updated'] += stats['updated']
total_stats['inserted'] += stats['inserted']
total_stats['deleted'] += stats['deleted']
insert_data.append(pos_dict) return True, total_stats
keep_keys.add((pos_dict['symbol'], pos_dict['side']))
except Exception as e:
logger.error(f"转换持仓数据失败: {pos_data}, error={e}")
continue
if not insert_data:
continue
# 批量插入/更新
from sqlalchemy.dialects.mysql import insert
stmt = insert(StrategyPosition.__table__).values(insert_data)
update_dict = {
'price': stmt.inserted.price,
'sum': stmt.inserted.sum,
'asset_num': stmt.inserted.asset_num,
'asset_profit': stmt.inserted.asset_profit,
'leverage': stmt.inserted.leverage,
'uptime': stmt.inserted.uptime,
'profit_price': stmt.inserted.profit_price,
'stop_price': stmt.inserted.stop_price,
'liquidation_price': stmt.inserted.liquidation_price
}
stmt = stmt.on_duplicate_key_update(**update_dict)
session.execute(stmt)
# 删除多余持仓
if keep_keys:
existing_positions = session.execute(
select(StrategyPosition).where(
and_(
StrategyPosition.k_id == k_id,
StrategyPosition.st_id == st_id
)
)
).scalars().all()
to_delete_ids = []
for existing in existing_positions:
key = (existing.symbol, existing.side)
if key not in keep_keys:
to_delete_ids.append(existing.id)
if to_delete_ids:
# 分块删除
chunk_size = 100
for i in range(0, len(to_delete_ids), chunk_size):
chunk = to_delete_ids[i:i + chunk_size]
session.execute(
delete(StrategyPosition).where(
StrategyPosition.id.in_(chunk)
)
)
success_count += 1
except Exception as e:
logger.error(f"同步账号 {k_id} 持仓失败: {e}")
continue
logger.info(f"批量同步完成: 成功 {success_count}/{len(positions_by_account)} 个账号")
return success_count > 0
except Exception as e: except Exception as e:
logger.error(f"批量同步持仓到数据库失败: {e}") logger.error(f"批量同步持仓到数据库失败: {e}")
return False return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
async def _sync_single_account_batch(self, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
"""批量同步单个账号的持仓数据"""
session = self.db_manager.get_session()
try:
# 准备数据
insert_data = []
new_positions_map = {} # (symbol, side) -> position_id (用于删除)
for position_data in positions:
try:
position_dict = self._convert_position_data(position_data)
if not all([position_dict.get('symbol'), position_dict.get('side')]):
continue
symbol = position_dict['symbol']
side = position_dict['side']
key = (symbol, side)
# 重命名qty为sum
if 'qty' in position_dict:
position_dict['sum'] = position_dict.pop('qty')
insert_data.append(position_dict)
new_positions_map[key] = position_dict.get('id') # 如果有id的话
except Exception as e:
logger.error(f"转换持仓数据失败: {position_data}, error={e}")
continue
with session.begin():
if not insert_data:
# 清空该账号所有持仓
result = session.execute(
delete(StrategyPosition).where(
and_(
StrategyPosition.k_id == k_id,
StrategyPosition.st_id == st_id
)
)
)
deleted_count = result.rowcount
return True, {
'total': 0,
'updated': 0,
'inserted': 0,
'deleted': deleted_count
}
# 1. 批量插入/更新持仓数据
processed_count = self._batch_upsert_positions(session, insert_data)
# 2. 批量删除多余持仓
deleted_count = self._batch_delete_extra_positions(session, k_id, st_id, new_positions_map)
# 注意这里无法区分插入和更新的数量processed_count是总处理数
inserted_count = processed_count # 简化处理
updated_count = 0 # 需要更复杂的逻辑来区分
stats = {
'total': len(insert_data),
'updated': updated_count,
'inserted': inserted_count,
'deleted': deleted_count
}
return True, stats
except Exception as e:
logger.error(f"批量同步账号 {k_id} 持仓失败: {e}")
return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
finally: finally:
session.close() session.close()
# 其他方法保持不变... def _batch_upsert_positions(self, session, insert_data: List[Dict]) -> int:
"""批量插入/更新持仓数据"""
try:
# 分块处理
chunk_size = self.batch_size
total_processed = 0
for i in range(0, len(insert_data), chunk_size):
chunk = insert_data[i:i + chunk_size]
values_list = []
for data in chunk:
symbol = data.get('symbol').replace("'", "''") if data.get('symbol') else ''
values = (
f"({data['st_id']}, {data['k_id']}, '{data.get('asset', 'USDT')}', "
f"'{symbol}', "
f"{data.get('price') or 'NULL'}, {data.get('sum') or 'NULL'}, "
f"{data.get('asset_num') or 'NULL'}, {data.get('asset_profit') or 'NULL'}, "
f"{data.get('leverage') or 'NULL'}, {data.get('uptime') or 'NULL'}, "
f"{data.get('profit_price') or 'NULL'}, {data.get('stop_price') or 'NULL'}, "
f"{data.get('liquidation_price') or 'NULL'})"
)
values_list.append(values)
if values_list:
values_str = ", ".join(values_list)
sql = f"""
INSERT INTO deh_strategy_position_new
(st_id, k_id, asset, symbol, side, price, `sum`,
asset_num, asset_profit, leverage, uptime,
profit_price, stop_price, liquidation_price)
VALUES {values_str}
ON DUPLICATE KEY UPDATE
price = VALUES(price),
`sum` = VALUES(`sum`),
asset_num = VALUES(asset_num),
asset_profit = VALUES(asset_profit),
leverage = VALUES(leverage),
uptime = VALUES(uptime),
profit_price = VALUES(profit_price),
stop_price = VALUES(stop_price),
liquidation_price = VALUES(liquidation_price)
"""
session.execute(text(sql))
total_processed += len(chunk)
return total_processed
except Exception as e:
logger.error(f"批量插入/更新持仓失败: {e}")
raise
    def _batch_delete_extra_positions(self, session, k_id: int, st_id: int, new_positions_map: Dict) -> int:
        """Batch-delete positions that are no longer present"""
        try:
            if not new_positions_map:
                # Delete all positions for this account
                result = session.execute(
                    delete(StrategyPosition).where(
                        and_(
                            StrategyPosition.k_id == k_id,
                            StrategyPosition.st_id == st_id
                        )
                    )
                )
                return result.rowcount

            # Build the keep conditions
            conditions = []
            for (symbol, side) in new_positions_map.keys():
                safe_symbol = symbol.replace("'", "''") if symbol else ''
                safe_side = side.replace("'", "''") if side else ''
                conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")

            if conditions:
                conditions_str = " OR ".join(conditions)
                sql = f"""
                    DELETE FROM deh_strategy_position_new
                    WHERE k_id = {k_id} AND st_id = {st_id}
                    AND NOT ({conditions_str})
                """
                result = session.execute(text(sql))
                return result.rowcount
            return 0
        except Exception as e:
            logger.error(f"Batch delete of positions failed: k_id={k_id}, error={e}")
            return 0
    def _convert_position_data(self, data: Dict) -> Dict:
        """Convert the position data format"""
        try:
            # Safe conversion helpers
            def safe_float(value, default=None):
                if value is None:
                    return default
                try:
                    return float(value)
                except (ValueError, TypeError):
                    return default

            def safe_int(value, default=None):
                if value is None:
                    return default
                try:
                    return int(float(value))
                except (ValueError, TypeError):
                    return default

            return {
                'st_id': safe_int(data.get('st_id'), 0),
                'k_id': safe_int(data.get('k_id'), 0),
                'asset': data.get('asset', 'USDT'),
                'symbol': data.get('symbol', ''),
                'side': data.get('side', ''),
                'price': safe_float(data.get('price')),
                'qty': safe_float(data.get('qty')),  # renamed to sum later
                'asset_num': safe_float(data.get('asset_num')),
                'asset_profit': safe_float(data.get('asset_profit')),
                'leverage': safe_int(data.get('leverage')),
                'uptime': safe_int(data.get('uptime')),
                'profit_price': safe_float(data.get('profit_price')),
                'stop_price': safe_float(data.get('stop_price')),
                'liquidation_price': safe_float(data.get('liquidation_price'))
            }
        except Exception as e:
            logger.error(f"Error converting position data: {data}, error={e}")
            return {}
    async def sync(self):
        """Compatibility shim for the old interface"""
        accounts = self.get_accounts_from_redis()
        await self.sync_batch(accounts)
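For reference, a minimal driver for the class above is sketched below. The scheduler loop, interval, and entry point are illustrative assumptions; get_accounts_from_redis() and sync_batch() are the methods shown in this file.

# Hypothetical driver sketch: PositionSyncBatch comes from this commit, but the
# polling loop, interval, and __main__ entry point are illustrative assumptions.
import asyncio

async def run_position_sync_forever(interval_seconds: int = 30):
    syncer = PositionSyncBatch()
    while True:
        # get_accounts_from_redis() is provided by BaseSync in this codebase
        accounts = syncer.get_accounts_from_redis()
        await syncer.sync_batch(accounts)
        await asyncio.sleep(interval_seconds)

if __name__ == "__main__":
    asyncio.run(run_position_sync_forever())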

@@ -1,174 +0,0 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
import time


class BatchAccountSync:
    """Batch sync helper for account information"""

    def __init__(self, db_manager):
        self.db_manager = db_manager

    def sync_accounts_batch(self, all_account_data: List[Dict]) -> Tuple[int, int]:
        """Batch-sync account information (fastest variant)"""
        if not all_account_data:
            return 0, 0

        session = self.db_manager.get_session()
        try:
            start_time = time.time()

            # Method 1: batch operations through a temporary table (best performance)
            updated_count, inserted_count = self._sync_using_temp_table(session, all_account_data)

            elapsed = time.time() - start_time
            logger.info(f"Batch account-info sync finished: {updated_count} updated, "
                        f"{inserted_count} inserted, {elapsed:.2f}s elapsed")
            return updated_count, inserted_count
        except Exception as e:
            logger.error(f"Batch account-info sync failed: {e}")
            return 0, 0
        finally:
            session.close()

    def _sync_using_temp_table(self, session, all_account_data: List[Dict]) -> Tuple[int, int]:
        """Batch-sync through a temporary table"""
        try:
            # 1. Create the temporary table
            session.execute(text("""
                CREATE TEMPORARY TABLE IF NOT EXISTS temp_account_info (
                    st_id INT,
                    k_id INT,
                    asset VARCHAR(32),
                    balance DECIMAL(20, 8),
                    withdrawal DECIMAL(20, 8),
                    deposit DECIMAL(20, 8),
                    other DECIMAL(20, 8),
                    profit DECIMAL(20, 8),
                    time INT,
                    PRIMARY KEY (k_id, st_id, time)
                )
            """))

            # 2. Empty the temporary table
            session.execute(text("TRUNCATE TABLE temp_account_info"))

            # 3. Bulk-insert the data into the temporary table
            chunk_size = 1000
            total_inserted = 0
            for i in range(0, len(all_account_data), chunk_size):
                chunk = all_account_data[i:i + chunk_size]
                values_list = []
                for data in chunk:
                    values = (
                        f"({data['st_id']}, {data['k_id']}, 'USDT', "
                        f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, "
                        f"{data['other']}, {data['profit']}, {data['time']})"
                    )
                    values_list.append(values)

                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO temp_account_info
                        (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
                        VALUES {values_str}
                    """
                    session.execute(text(sql))
                    total_inserted += len(chunk)

            # 4. Update the main table from the temporary table
            # Update rows that already exist
            update_result = session.execute(text("""
                UPDATE deh_strategy_kx_new main
                INNER JOIN temp_account_info temp
                    ON main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.time = temp.time
                SET main.balance = temp.balance,
                    main.withdrawal = temp.withdrawal,
                    main.deposit = temp.deposit,
                    main.other = temp.other,
                    main.profit = temp.profit,
                    main.up_time = NOW()
            """))
            updated_count = update_result.rowcount

            # Insert new rows
            insert_result = session.execute(text("""
                INSERT INTO deh_strategy_kx_new
                (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, up_time)
                SELECT
                    st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time, NOW()
                FROM temp_account_info temp
                WHERE NOT EXISTS (
                    SELECT 1 FROM deh_strategy_kx_new main
                    WHERE main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.time = temp.time
                )
            """))
            inserted_count = insert_result.rowcount

            # 5. Drop the temporary table
            session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_account_info"))
            session.commit()
            return updated_count, inserted_count
        except Exception as e:
            session.rollback()
            logger.error(f"Temp-table sync failed: {e}")
            raise

    def _sync_using_on_duplicate(self, session, all_account_data: List[Dict]) -> Tuple[int, int]:
        """Batch-sync with ON DUPLICATE KEY UPDATE (simpler variant)"""
        try:
            # Execute in chunks to keep each statement short enough
            chunk_size = 1000
            total_processed = 0
            for i in range(0, len(all_account_data), chunk_size):
                chunk = all_account_data[i:i + chunk_size]
                values_list = []
                for data in chunk:
                    values = (
                        f"({data['st_id']}, {data['k_id']}, 'USDT', "
                        f"{data['balance']}, {data['withdrawal']}, {data['deposit']}, "
                        f"{data['other']}, {data['profit']}, {data['time']})"
                    )
                    values_list.append(values)

                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO deh_strategy_kx_new
                        (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
                        VALUES {values_str}
                        ON DUPLICATE KEY UPDATE
                        balance = VALUES(balance),
                        withdrawal = VALUES(withdrawal),
                        deposit = VALUES(deposit),
                        other = VALUES(other),
                        profit = VALUES(profit),
                        up_time = NOW()
                    """
                    session.execute(text(sql))
                    total_processed += len(chunk)

            session.commit()
            # Note: updates and inserts cannot be told apart here
            return total_processed, 0
        except Exception as e:
            session.rollback()
            logger.error(f"ON DUPLICATE sync failed: {e}")
            raise
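The VALUES strings above are spliced together by hand, which only works while every field is numeric and trusted. A parameterized variant of step 3 is sketched below, assuming the same temp_account_info schema; it uses the SQLAlchemy executemany form that also appears later in this commit, so the driver handles quoting and NULLs.

# A sketch of a parameterized insert for step 3 above; not the commit's code.
# chunk is assumed to be a list of dicts with the keys named in the statement.
from sqlalchemy import text

def insert_chunk_parameterized(session, chunk):
    # Bound parameters let the driver handle quoting and NULLs instead of
    # splicing raw values into the SQL string.
    sql = text("""
        INSERT INTO temp_account_info
        (st_id, k_id, asset, balance, withdrawal, deposit, other, profit, time)
        VALUES (:st_id, :k_id, 'USDT', :balance, :withdrawal, :deposit,
                :other, :profit, :time)
    """)
    session.execute(sql, chunk)  # executemany when chunk is a list of dicts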


@@ -1,138 +0,0 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
from .database_manager import DatabaseManager


class BatchOperations:
    """Batch database operation helpers"""

    def __init__(self):
        self.db_manager = DatabaseManager()

    def batch_insert_update_positions(self, positions_data: List[Dict]) -> Tuple[int, int]:
        """Batch insert/update position data"""
        session = self.db_manager.get_session()
        try:
            if not positions_data:
                return 0, 0

            # Group by account
            positions_by_account = {}
            for position in positions_data:
                k_id = position.get('k_id')
                if k_id not in positions_by_account:
                    positions_by_account[k_id] = []
                positions_by_account[k_id].append(position)

            total_processed = 0
            total_deleted = 0
            with session.begin():
                for k_id, positions in positions_by_account.items():
                    processed, deleted = self._process_account_positions(session, k_id, positions)
                    total_processed += processed
                    total_deleted += deleted

            logger.info(f"Batch position processing finished: {total_processed} processed, {total_deleted} deleted")
            return total_processed, total_deleted
        except Exception as e:
            logger.error(f"Batch position processing failed: {e}")
            return 0, 0
        finally:
            session.close()

    def _process_account_positions(self, session, k_id: int, positions: List[Dict]) -> Tuple[int, int]:
        """Process one account's position data"""
        try:
            st_id = positions[0].get('st_id', 0) if positions else 0

            # Prepare the data
            insert_data = []
            keep_keys = set()
            for pos_data in positions:
                # Convert the record
                pos_dict = self._convert_position_data(pos_data)
                if not all([pos_dict.get('symbol'), pos_dict.get('side')]):
                    continue

                # Rename qty to sum
                if 'qty' in pos_dict:
                    pos_dict['sum'] = pos_dict.pop('qty')

                insert_data.append(pos_dict)
                keep_keys.add((pos_dict['symbol'], pos_dict['side']))

            if not insert_data:
                # Clear this account's positions
                result = session.execute(
                    text("DELETE FROM deh_strategy_position_new WHERE k_id = :k_id AND st_id = :st_id"),
                    {'k_id': k_id, 'st_id': st_id}
                )
                return 0, result.rowcount

            # Batch insert/update
            sql = """
                INSERT INTO deh_strategy_position_new
                (st_id, k_id, asset, symbol, side, price, `sum`,
                 asset_num, asset_profit, leverage, uptime,
                 profit_price, stop_price, liquidation_price)
                VALUES
                (:st_id, :k_id, :asset, :symbol, :side, :price, :sum,
                 :asset_num, :asset_profit, :leverage, :uptime,
                 :profit_price, :stop_price, :liquidation_price)
                ON DUPLICATE KEY UPDATE
                price = VALUES(price),
                `sum` = VALUES(`sum`),
                asset_num = VALUES(asset_num),
                asset_profit = VALUES(asset_profit),
                leverage = VALUES(leverage),
                uptime = VALUES(uptime),
                profit_price = VALUES(profit_price),
                stop_price = VALUES(stop_price),
                liquidation_price = VALUES(liquidation_price)
            """

            # Execute in chunks
            chunk_size = 500
            processed_count = 0
            for i in range(0, len(insert_data), chunk_size):
                chunk = insert_data[i:i + chunk_size]
                session.execute(text(sql), chunk)
                processed_count += len(chunk)

            # Delete stale positions
            deleted_count = 0
            if keep_keys:
                # Build the delete condition
                conditions = []
                for symbol, side in keep_keys:
                    safe_symbol = symbol.replace("'", "''") if symbol else ''
                    safe_side = side.replace("'", "''") if side else ''
                    conditions.append(f"(symbol = '{safe_symbol}' AND side = '{safe_side}')")

                if conditions:
                    conditions_str = " OR ".join(conditions)
                    delete_sql = f"""
                        DELETE FROM deh_strategy_position_new
                        WHERE k_id = {k_id} AND st_id = {st_id}
                        AND NOT ({conditions_str})
                    """
                    result = session.execute(text(delete_sql))
                    deleted_count = result.rowcount

            return processed_count, deleted_count
        except Exception as e:
            logger.error(f"Failed to process positions for account {k_id}: {e}")
            return 0, 0

    def _convert_position_data(self, data: Dict) -> Dict:
        """Convert the position data format"""
        # Conversion logic...
        pass

# Similar batch methods for orders and account info...
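The stale-row delete above concatenates one NOT (...) condition per (symbol, side) pair. A bound-parameter alternative is sketched below, assuming SQLAlchemy's composite tuple_(...).in_(...) works with the MySQL dialect in use; it keeps the same semantics without manual escaping.

# A sketch of the same "delete stale rows" step using SQLAlchemy Core instead of
# string concatenation. StrategyPosition is the ORM model used elsewhere in this
# commit; the composite tuple_ IN form is assumed to be supported by the dialect.
from sqlalchemy import and_, delete, not_, tuple_
from models.orm_models import StrategyPosition

def delete_stale_positions(session, k_id: int, st_id: int, keep_keys: set) -> int:
    # keep_keys is a set of (symbol, side) tuples that must survive
    stmt = delete(StrategyPosition).where(
        and_(
            StrategyPosition.k_id == k_id,
            StrategyPosition.st_id == st_id,
            not_(tuple_(StrategyPosition.symbol, StrategyPosition.side).in_(list(keep_keys))),
        )
    )
    return session.execute(stmt).rowcount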


@@ -1,313 +0,0 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
import time


class BatchOrderSync:
    """Batch order sync helper (highest throughput)"""

    def __init__(self, db_manager, batch_size: int = 1000):
        self.db_manager = db_manager
        self.batch_size = batch_size

    def sync_orders_batch(self, all_orders: List[Dict]) -> Tuple[bool, int]:
        """Batch-sync order data"""
        if not all_orders:
            return True, 0

        session = self.db_manager.get_session()
        try:
            start_time = time.time()

            # Method 1: temporary table (best performance)
            processed_count = self._sync_using_temp_table(session, all_orders)

            elapsed = time.time() - start_time
            logger.info(f"Batch order sync finished: {processed_count} orders processed, {elapsed:.2f}s elapsed")
            return True, processed_count
        except Exception as e:
            logger.error(f"Batch order sync failed: {e}")
            return False, 0
        finally:
            session.close()

    def _sync_using_temp_table(self, session, all_orders: List[Dict]) -> int:
        """Batch-sync orders through a temporary table"""
        try:
            # 1. Create the temporary table
            session.execute(text("""
                CREATE TEMPORARY TABLE IF NOT EXISTS temp_orders (
                    st_id INT,
                    k_id INT,
                    asset VARCHAR(32),
                    order_id VARCHAR(765),
                    symbol VARCHAR(120),
                    side VARCHAR(120),
                    price FLOAT,
                    time INT,
                    order_qty FLOAT,
                    last_qty FLOAT,
                    avg_price FLOAT,
                    exchange_id INT,
                    UNIQUE KEY idx_unique_order (order_id, symbol, k_id, side)
                )
            """))

            # 2. Empty the temporary table
            session.execute(text("TRUNCATE TABLE temp_orders"))

            # 3. Bulk-insert the data into the temporary table (in chunks)
            inserted_count = self._batch_insert_to_temp_table(session, all_orders)
            if inserted_count == 0:
                session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders"))
                return 0

            # 4. Update the main table from the temporary table
            # Update existing rows (only when a compared field differs)
            update_result = session.execute(text("""
                UPDATE deh_strategy_order_new main
                INNER JOIN temp_orders temp
                    ON main.order_id = temp.order_id
                    AND main.symbol = temp.symbol
                    AND main.k_id = temp.k_id
                    AND main.side = temp.side
                SET main.side = temp.side,
                    main.price = temp.price,
                    main.time = temp.time,
                    main.order_qty = temp.order_qty,
                    main.last_qty = temp.last_qty,
                    main.avg_price = temp.avg_price
                WHERE main.side != temp.side
                    OR main.price != temp.price
                    OR main.time != temp.time
                    OR main.order_qty != temp.order_qty
                    OR main.last_qty != temp.last_qty
                    OR main.avg_price != temp.avg_price
            """))
            updated_count = update_result.rowcount

            # Insert new rows
            insert_result = session.execute(text("""
                INSERT INTO deh_strategy_order_new
                (st_id, k_id, asset, order_id, symbol, side, price, time,
                 order_qty, last_qty, avg_price, exchange_id)
                SELECT
                    st_id, k_id, asset, order_id, symbol, side, price, time,
                    order_qty, last_qty, avg_price, exchange_id
                FROM temp_orders temp
                WHERE NOT EXISTS (
                    SELECT 1 FROM deh_strategy_order_new main
                    WHERE main.order_id = temp.order_id
                    AND main.symbol = temp.symbol
                    AND main.k_id = temp.k_id
                    AND main.side = temp.side
                )
            """))
            inserted_count = insert_result.rowcount

            # 5. Drop the temporary table
            session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_orders"))
            session.commit()

            total_processed = updated_count + inserted_count
            logger.info(f"Batch order sync: {updated_count} updated, {inserted_count} inserted")
            return total_processed
        except Exception as e:
            session.rollback()
            logger.error(f"Temp-table order sync failed: {e}")
            raise

    def _batch_insert_to_temp_table(self, session, all_orders: List[Dict]) -> int:
        """Bulk-insert the data into the temporary table"""
        total_inserted = 0
        try:
            # Process in chunks
            for i in range(0, len(all_orders), self.batch_size):
                chunk = all_orders[i:i + self.batch_size]
                values_list = []
                for order in chunk:
                    try:
                        # Handle NULL values
                        price = order.get('price')
                        time_val = order.get('time')
                        order_qty = order.get('order_qty')
                        last_qty = order.get('last_qty')
                        avg_price = order.get('avg_price')

                        # Escape single quotes
                        symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else ''
                        order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else ''

                        values = (
                            f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', "
                            f"'{order_id}', "
                            f"'{symbol}', "
                            f"'{order['side']}', "
                            f"{price if price is not None else 'NULL'}, "
                            f"{time_val if time_val is not None else 'NULL'}, "
                            f"{order_qty if order_qty is not None else 'NULL'}, "
                            f"{last_qty if last_qty is not None else 'NULL'}, "
                            f"{avg_price if avg_price is not None else 'NULL'}, "
                            "NULL)"
                        )
                        values_list.append(values)
                    except Exception as e:
                        logger.error(f"Failed to build order values: {order}, error={e}")
                        continue

                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO temp_orders
                        (st_id, k_id, asset, order_id, symbol, side, price, time,
                         order_qty, last_qty, avg_price, exchange_id)
                        VALUES {values_str}
                    """
                    session.execute(text(sql))
                    total_inserted += len(chunk)
            return total_inserted
        except Exception as e:
            logger.error(f"Bulk insert into the temporary table failed: {e}")
            raise

    def _batch_insert_to_temp_table1(self, session, all_orders: List[Dict]) -> int:
        """Bulk-insert into the temporary table (parameterized variant, temp_orders)"""
        total_inserted = 0
        try:
            # Process in chunks
            for i in range(0, len(all_orders), self.batch_size):
                chunk = all_orders[i:i + self.batch_size]

                # Prepare parameterized rows
                insert_data = []
                for order in chunk:
                    try:
                        insert_data.append({
                            'st_id': order['st_id'],
                            'k_id': order['k_id'],
                            'asset': order.get('asset', 'USDT'),
                            'order_id': order['order_id'],
                            'symbol': order['symbol'],
                            'side': order['side'],
                            'price': order.get('price'),
                            'time': order.get('time'),
                            'order_qty': order.get('order_qty'),
                            'last_qty': order.get('last_qty'),
                            'avg_price': order.get('avg_price')
                            # exchange_id omitted; defaults to NULL
                        })
                    except KeyError as e:
                        logger.error(f"Order data missing a required field: {order}, missing={e}")
                        continue
                    except Exception as e:
                        logger.error(f"Failed to process order data: {order}, error={e}")
                        continue

                if insert_data:
                    sql = text("""
                        INSERT INTO temp_orders
                        (st_id, k_id, asset, order_id, symbol, side, price, time,
                         order_qty, last_qty, avg_price)
                        VALUES
                        (:st_id, :k_id, :asset, :order_id, :symbol, :side, :price, :time,
                         :order_qty, :last_qty, :avg_price)
                    """)
                    try:
                        session.execute(sql, insert_data)
                        session.commit()
                        total_inserted += len(insert_data)
                        logger.debug(f"Inserted {len(insert_data)} rows into the temporary table")
                    except Exception as e:
                        session.rollback()
                        logger.error(f"Bulk insert execution failed: {e}")
                        raise

            logger.info(f"Inserted {total_inserted} rows into the temporary table in total")
            return total_inserted
        except Exception as e:
            logger.error(f"Bulk insert into the temporary table failed: {e}")
            session.rollback()
            raise

    def _sync_using_on_duplicate(self, session, all_orders: List[Dict]) -> int:
        """Batch-sync with ON DUPLICATE KEY UPDATE (simpler variant)"""
        try:
            total_processed = 0
            # Execute in chunks
            for i in range(0, len(all_orders), self.batch_size):
                chunk = all_orders[i:i + self.batch_size]
                values_list = []
                for order in chunk:
                    try:
                        # Handle NULL values
                        price = order.get('price')
                        time_val = order.get('time')
                        order_qty = order.get('order_qty')
                        last_qty = order.get('last_qty')
                        avg_price = order.get('avg_price')
                        symbol = order.get('symbol').replace("'", "''") if order.get('symbol') else ''
                        order_id = order.get('order_id').replace("'", "''") if order.get('order_id') else ''

                        values = (
                            f"({order['st_id']}, {order['k_id']}, '{order.get('asset', 'USDT')}', "
                            f"'{order_id}', "
                            f"'{symbol}', "
                            f"'{order['side']}', "
                            f"{price if price is not None else 'NULL'}, "
                            f"{time_val if time_val is not None else 'NULL'}, "
                            f"{order_qty if order_qty is not None else 'NULL'}, "
                            f"{last_qty if last_qty is not None else 'NULL'}, "
                            f"{avg_price if avg_price is not None else 'NULL'}, "
                            "NULL)"
                        )
                        values_list.append(values)
                    except Exception as e:
                        logger.error(f"Failed to build order values: {order}, error={e}")
                        continue

                if values_list:
                    values_str = ", ".join(values_list)
                    sql = f"""
                        INSERT INTO deh_strategy_order_new
                        (st_id, k_id, asset, order_id, symbol, side, price, time,
                         order_qty, last_qty, avg_price, exchange_id)
                        VALUES {values_str}
                        ON DUPLICATE KEY UPDATE
                        side = VALUES(side),
                        price = VALUES(price),
                        time = VALUES(time),
                        order_qty = VALUES(order_qty),
                        last_qty = VALUES(last_qty),
                        avg_price = VALUES(avg_price)
                    """
                    session.execute(text(sql))
                    total_processed += len(chunk)

            session.commit()
            return total_processed
        except Exception as e:
            session.rollback()
            logger.error(f"ON DUPLICATE order sync failed: {e}")
            raise
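One caveat with the change-detection UPDATE above: in MySQL, != yields NULL rather than TRUE when either operand is NULL, so a row whose price or avg_price is NULL on one side is never treated as changed. A NULL-safe rewrite using MySQL's <=> operator is sketched below, with the same tables and columns as above.

# A sketch of a NULL-safe change filter for the UPDATE above. MySQL's <=>
# operator treats NULL <=> NULL as true, so NOT (a <=> b) catches NULL-to-value
# changes that a plain != silently skips. Table and column names follow the
# code above.
from sqlalchemy import text

NULL_SAFE_UPDATE = text("""
    UPDATE deh_strategy_order_new main
    INNER JOIN temp_orders temp
        ON main.order_id = temp.order_id
        AND main.symbol = temp.symbol
        AND main.k_id = temp.k_id
        AND main.side = temp.side
    SET main.price = temp.price,
        main.time = temp.time,
        main.order_qty = temp.order_qty,
        main.last_qty = temp.last_qty,
        main.avg_price = temp.avg_price
    WHERE NOT (main.price <=> temp.price)
        OR NOT (main.time <=> temp.time)
        OR NOT (main.order_qty <=> temp.order_qty)
        OR NOT (main.last_qty <=> temp.last_qty)
        OR NOT (main.avg_price <=> temp.avg_price)
""")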


@@ -1,254 +0,0 @@
from typing import List, Dict, Any, Tuple
from loguru import logger
from sqlalchemy import text
import time


class BatchPositionSync:
    """Batch position sync helper (temp-table based, highest throughput)"""

    def __init__(self, db_manager, batch_size: int = 500):
        self.db_manager = db_manager
        self.batch_size = batch_size

    def sync_positions_batch(self, all_positions: List[Dict]) -> Tuple[bool, Dict]:
        """Batch-sync position data (fastest variant)"""
        if not all_positions:
            return True, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

        session = self.db_manager.get_session()
        try:
            start_time = time.time()

            # Group by account
            positions_by_account = self._group_positions_by_account(all_positions)
            total_stats = {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

            with session.begin():
                # Process each account
                for (k_id, st_id), positions in positions_by_account.items():
                    success, stats = self._sync_account_using_temp_table(
                        session, k_id, st_id, positions
                    )
                    if success:
                        total_stats['total'] += stats['total']
                        total_stats['updated'] += stats['updated']
                        total_stats['inserted'] += stats['inserted']
                        total_stats['deleted'] += stats['deleted']

            elapsed = time.time() - start_time
            logger.info(f"Batch position sync finished: {len(positions_by_account)} accounts, "
                        f"{total_stats['total']} positions in total, {elapsed:.2f}s elapsed")
            return True, total_stats
        except Exception as e:
            logger.error(f"Batch position sync failed: {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}
        finally:
            session.close()

    def _group_positions_by_account(self, all_positions: List[Dict]) -> Dict[Tuple[int, int], List[Dict]]:
        """Group position data by account"""
        groups = {}
        for position in all_positions:
            k_id = position.get('k_id')
            st_id = position.get('st_id', 0)
            key = (k_id, st_id)
            if key not in groups:
                groups[key] = []
            groups[key].append(position)
        return groups

    def _sync_account_using_temp_table(self, session, k_id: int, st_id: int, positions: List[Dict]) -> Tuple[bool, Dict]:
        """Sync one account's positions through a temporary table"""
        try:
            # 1. Create the temporary table
            session.execute(text("""
                CREATE TEMPORARY TABLE IF NOT EXISTS temp_positions (
                    st_id INT,
                    k_id INT,
                    asset VARCHAR(32),
                    symbol VARCHAR(50),
                    side VARCHAR(10),
                    price FLOAT,
                    `sum` FLOAT,
                    asset_num DECIMAL(20, 8),
                    asset_profit DECIMAL(20, 8),
                    leverage INT,
                    uptime INT,
                    profit_price DECIMAL(20, 8),
                    stop_price DECIMAL(20, 8),
                    liquidation_price DECIMAL(20, 8),
                    PRIMARY KEY (k_id, st_id, symbol, side)
                )
            """))

            # 2. Empty the temporary table
            session.execute(text("TRUNCATE TABLE temp_positions"))

            # 3. Bulk-insert the data into the temporary table
            self._batch_insert_to_temp_table(session, positions)

            # 4. Update the main table from the temporary table
            # Update rows that already exist
            update_result = session.execute(text(f"""
                UPDATE deh_strategy_position_new main
                INNER JOIN temp_positions temp
                    ON main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.symbol = temp.symbol
                    AND main.side = temp.side
                SET main.price = temp.price,
                    main.`sum` = temp.`sum`,
                    main.asset_num = temp.asset_num,
                    main.asset_profit = temp.asset_profit,
                    main.leverage = temp.leverage,
                    main.uptime = temp.uptime,
                    main.profit_price = temp.profit_price,
                    main.stop_price = temp.stop_price,
                    main.liquidation_price = temp.liquidation_price
                WHERE main.k_id = {k_id} AND main.st_id = {st_id}
            """))
            updated_count = update_result.rowcount

            # Insert new rows
            insert_result = session.execute(text(f"""
                INSERT INTO deh_strategy_position_new
                (st_id, k_id, asset, symbol, side, price, `sum`,
                 asset_num, asset_profit, leverage, uptime,
                 profit_price, stop_price, liquidation_price)
                SELECT
                    st_id, k_id, asset, symbol, side, price, `sum`,
                    asset_num, asset_profit, leverage, uptime,
                    profit_price, stop_price, liquidation_price
                FROM temp_positions temp
                WHERE NOT EXISTS (
                    SELECT 1 FROM deh_strategy_position_new main
                    WHERE main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.symbol = temp.symbol
                    AND main.side = temp.side
                )
                AND temp.k_id = {k_id} AND temp.st_id = {st_id}
            """))
            inserted_count = insert_result.rowcount

            # 5. Delete stale positions (present in the main table but not in the temp table)
            delete_result = session.execute(text(f"""
                DELETE main
                FROM deh_strategy_position_new main
                LEFT JOIN temp_positions temp
                    ON main.k_id = temp.k_id
                    AND main.st_id = temp.st_id
                    AND main.symbol = temp.symbol
                    AND main.side = temp.side
                WHERE main.k_id = {k_id} AND main.st_id = {st_id}
                AND temp.symbol IS NULL
            """))
            deleted_count = delete_result.rowcount

            # 6. Drop the temporary table
            session.execute(text("DROP TEMPORARY TABLE IF EXISTS temp_positions"))

            stats = {
                'total': len(positions),
                'updated': updated_count,
                'inserted': inserted_count,
                'deleted': deleted_count
            }
            logger.debug(f"Account ({k_id},{st_id}) position sync: {updated_count} updated, "
                         f"{inserted_count} inserted, {deleted_count} deleted")
            return True, stats
        except Exception as e:
            logger.error(f"Temp-table position sync failed for account ({k_id},{st_id}): {e}")
            return False, {'total': 0, 'updated': 0, 'inserted': 0, 'deleted': 0}

    def _batch_insert_to_temp_table(self, session, positions: List[Dict]):
        """Bulk-insert the data into the temporary table (parameterized)"""
        if not positions:
            return

        # Process in chunks
        for i in range(0, len(positions), self.batch_size):
            chunk = positions[i:i + self.batch_size]

            # Prepare parameterized rows
            insert_data = []
            for position in chunk:
                try:
                    data = self._convert_position_for_temp(position)
                    if not all([data.get('symbol'), data.get('side')]):
                        continue
                    insert_data.append({
                        'st_id': data['st_id'],
                        'k_id': data['k_id'],
                        'asset': data.get('asset', 'USDT'),
                        'symbol': data['symbol'],
                        'side': data['side'],
                        'price': data.get('price'),
                        'sum_val': data.get('sum'),  # note the parameter name
                        'asset_num': data.get('asset_num'),
                        'asset_profit': data.get('asset_profit'),
                        'leverage': data.get('leverage'),
                        'uptime': data.get('uptime'),
                        'profit_price': data.get('profit_price'),
                        'stop_price': data.get('stop_price'),
                        'liquidation_price': data.get('liquidation_price')
                    })
                except Exception as e:
                    logger.error(f"Failed to convert position data: {position}, error={e}")
                    continue

            if insert_data:
                sql = """
                    INSERT INTO temp_positions
                    (st_id, k_id, asset, symbol, side, price, `sum`,
                     asset_num, asset_profit, leverage, uptime,
                     profit_price, stop_price, liquidation_price)
                    VALUES
                    (:st_id, :k_id, :asset, :symbol, :side, :price, :sum_val,
                     :asset_num, :asset_profit, :leverage, :uptime,
                     :profit_price, :stop_price, :liquidation_price)
                """
                session.execute(text(sql), insert_data)

    def _convert_position_for_temp(self, data: Dict) -> Dict:
        """Convert the position data format for the temporary table"""
        # Safe conversion helpers
        def safe_float(value):
            try:
                return float(value) if value is not None else None
            except (ValueError, TypeError):
                return None

        def safe_int(value):
            try:
                return int(value) if value is not None else None
            except (ValueError, TypeError):
                return None

        return {
            'st_id': safe_int(data.get('st_id')) or 0,
            'k_id': safe_int(data.get('k_id')) or 0,
            'asset': data.get('asset', 'USDT'),
            'symbol': str(data.get('symbol', '')),
            'side': str(data.get('side', '')),
            'price': safe_float(data.get('price')),
            'sum': safe_float(data.get('qty')),  # note: qty is mapped straight to sum here
            'asset_num': safe_float(data.get('asset_num')),
            'asset_profit': safe_float(data.get('asset_profit')),
            'leverage': safe_int(data.get('leverage')),
            'uptime': safe_int(data.get('uptime')),
            'profit_price': safe_float(data.get('profit_price')),
            'stop_price': safe_float(data.get('stop_price')),
            'liquidation_price': safe_float(data.get('liquidation_price'))
        }
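A minimal usage sketch for the class above. The sample dict mirrors the keys consumed by _convert_position_for_temp; DatabaseManager is the manager class imported elsewhere in this commit, and the import path here is an assumption.

# Hypothetical usage sketch, not part of the commit.
from utils.database_manager import DatabaseManager  # assumed path

db_manager = DatabaseManager()
syncer = BatchPositionSync(db_manager, batch_size=500)

sample_positions = [
    {'k_id': 101, 'st_id': 7, 'symbol': 'BTCUSDT', 'side': 'long',
     'qty': 0.5, 'price': 65000.0, 'leverage': 10, 'uptime': 1733240000},
]

ok, stats = syncer.sync_positions_batch(sample_positions)
print(ok, stats)  # e.g. True {'total': 1, 'updated': 0, 'inserted': 1, 'deleted': 0}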