190 lines
6.2 KiB
Python
190 lines
6.2 KiB
Python
# sync/base_sync.py
|
||
from abc import ABC, abstractmethod
|
||
from loguru import logger
|
||
from typing import List, Dict, Any, Set, Optional
|
||
import json
|
||
import re
|
||
import time
|
||
|
||
from utils.redis_client import RedisClient
|
||
from utils.database_manager import DatabaseManager
|
||
from config.settings import COMPUTER_NAMES, COMPUTER_NAME_PATTERN
|
||
|
||
class BaseSync(ABC):
    """Base class for sync jobs.

    Holds the shared Redis and database clients, the configured computer-name
    list and compiled computer-name pattern, and simple per-instance sync
    statistics. Also provides value-conversion and SQL-literal helpers for
    concrete subclasses, which must implement ``sync`` and ``sync_batch``.
    """

    def __init__(self):
        self.redis_client = RedisClient()
        self.db_manager = DatabaseManager()
        self.computer_names = self._get_computer_names()
        self.computer_name_pattern = re.compile(COMPUTER_NAME_PATTERN)
        self.sync_stats = self._empty_stats()

    @staticmethod
    def _empty_stats() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics dict (used by __init__ and reset_stats)."""
        return {
            'total_accounts': 0,
            'success_count': 0,
            'error_count': 0,
            'last_sync_time': 0,
            'avg_sync_time': 0,
        }

    def _get_computer_names(self) -> List[str]:
        """Return the configured computer names.

        ``COMPUTER_NAMES`` is a single name or a comma-separated list.
        Whitespace around each name is stripped. In the list form, empty
        entries (e.g. from a trailing or doubled comma) are dropped.
        """
        if ',' in COMPUTER_NAMES:
            names = [name.strip() for name in COMPUTER_NAMES.split(',') if name.strip()]
            logger.info(f"使用配置的计算机名列表: {names}")
            return names
        return [COMPUTER_NAMES.strip()]

    @abstractmethod
    async def sync(self):
        """Run a sync (kept for compatibility with the old interface)."""
        pass

    @abstractmethod
    async def sync_batch(self, accounts: Dict[str, Dict]):
        """Synchronize a batch of accounts.

        NOTE(review): *accounts* is presumably keyed by account id with a
        per-account data dict as value; confirm against the subclasses.
        """
        pass

    def _safe_float(self, value: Any, default: float = 0.0) -> float:
        """Convert *value* to float, returning *default* if that is not possible.

        ``None``, empty or whitespace-only strings, and anything ``float()``
        rejects (including ints too large to represent as a float) yield
        *default*. Strings are stripped before conversion.
        """
        if value is None:
            return default
        if isinstance(value, str):
            value = value.strip()
            if not value:
                return default
        try:
            return float(value)
        except (ValueError, TypeError, OverflowError):
            return default

    def _safe_int(self, value: Any, default: int = 0) -> int:
        """Convert *value* to int, returning *default* if that is not possible.

        Integer strings and ints are converted exactly, without a float round
        trip that would lose precision above 2**53. Other values (e.g. ``"3.9"``,
        ``"1e3"``, floats) go through ``float`` and are truncated toward zero.
        ``None``, empty strings, unparsable input, NaN and +/-infinity yield
        *default*.
        """
        if value is None:
            return default
        if isinstance(value, str):
            value = value.strip()
            if not value:
                return default
            try:
                return int(value)
            except ValueError:
                pass  # not a plain integer literal; try the float route below
        elif isinstance(value, int):
            return int(value)  # also normalizes bool to 0/1
        try:
            return int(float(value))
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float('inf')); ValueError: NaN or bad text.
            return default

    def _safe_str(self, value: Any, default: str = '') -> str:
        """Convert *value* to a stripped str, using *default* for None/empty/failed conversion."""
        if value is None:
            return default
        try:
            text = str(value).strip()
        except Exception:
            # A broken __str__ on a foreign object should not abort the sync.
            return default
        return text or default

    def _escape_sql_value(self, value: Any) -> str:
        """Render *value* as an SQL literal for inline use in a statement.

        - ``None`` -> ``NULL``
        - ``bool`` -> ``1`` / ``0`` (checked before int, since bool subclasses int)
        - ``int`` / ``float`` -> ``str(value)``
        - anything else -> stringified and wrapped in single quotes, with
          embedded single quotes doubled.

        NOTE(review): only single quotes are escaped. Backends that treat a
        backslash inside string literals as an escape character (e.g. MySQL
        without NO_BACKSLASH_ESCAPES) can be broken by values containing
        backslashes. Floats NaN/inf are emitted as bare ``nan``/``inf``, which
        only some dialects accept. Prefer parameterized queries for untrusted
        data.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, bool):
            return '1' if value else '0'
        if isinstance(value, (int, float)):
            return str(value)
        text = value if isinstance(value, str) else str(value)
        escaped = text.replace("'", "''")
        return f"'{escaped}'"

    def _build_sql_values_list(self, data_list: List[Dict], fields_mapping: Optional[Dict[str, str]] = None) -> List[str]:
        """Render each dict in *data_list* as an SQL ``(v1, v2, ...)`` tuple string.

        Values are escaped with ``_escape_sql_value`` and emitted in each dict's
        own iteration (insertion) order. Callers must therefore build every row
        with the same key order, matching the column list of their INSERT.

        Args:
            data_list: rows to render.
            fields_mapping: accepted for backward compatibility. It would only
                rename keys to column names, so it has no effect on the VALUES
                tuples produced here.

        Returns:
            One tuple string per successfully rendered row. Rows that fail to
            render are logged and skipped.
        """
        values_list = []
        for data in data_list:
            try:
                values_str = ", ".join(self._escape_sql_value(value) for value in data.values())
            except Exception as e:
                logger.error(f"构建SQL值失败: {data}, error={e}")
                continue
            values_list.append(f"({values_str})")
        return values_list

    def _get_recent_dates(self, days: int) -> List[str]:
        """Return the last *days* dates as 'YYYY-MM-DD' strings, newest first.

        Today (local time, from ``datetime.now()``) is the first entry.
        ``days <= 0`` yields an empty list.
        """
        from datetime import datetime, timedelta

        today = datetime.now()
        return [(today - timedelta(days=offset)).strftime('%Y-%m-%d') for offset in range(days)]

    def _date_to_timestamp(self, date_str: str) -> int:
        """Convert a 'YYYY-MM-DD' string to the Unix timestamp of that day's 00:00.

        The parsed datetime is naive, so midnight is interpreted in the local
        timezone. Unparsable or non-string input returns 0.
        """
        from datetime import datetime

        try:
            return int(datetime.strptime(date_str, '%Y-%m-%d').timestamp())
        except (ValueError, TypeError):
            return 0

    def update_stats(self, success: bool = True, sync_time: float = 0):
        """Record one sync outcome and, if *sync_time* > 0, update the timing stats.

        ``avg_sync_time`` is an exponential moving average (weight 0.1 on the
        newest sample), seeded with the first positive sample.
        """
        counter = 'success_count' if success else 'error_count'
        self.sync_stats[counter] += 1

        if sync_time <= 0:
            return
        self.sync_stats['last_sync_time'] = sync_time
        previous_avg = self.sync_stats['avg_sync_time']
        if previous_avg == 0:
            self.sync_stats['avg_sync_time'] = sync_time
        else:
            self.sync_stats['avg_sync_time'] = previous_avg * 0.9 + sync_time * 0.1

    def print_stats(self, sync_type: str = ""):
        """Log a one-line stats summary.

        The line is logged at WARNING level if any errors were recorded, and at
        INFO level otherwise.
        """
        stats = self.sync_stats
        prefix = f"[{sync_type}] " if sync_type else ""

        stats_str = (
            f"{prefix}统计: 账号数={stats['total_accounts']}, "
            f"成功={stats['success_count']}, 失败={stats['error_count']}, "
            f"本次耗时={stats['last_sync_time']:.2f}s, "
            f"平均耗时={stats['avg_sync_time']:.2f}s"
        )

        if stats['error_count'] > 0:
            logger.warning(stats_str)
        else:
            logger.info(stats_str)

    def reset_stats(self):
        """Reset all statistics to zero."""
        self.sync_stats = self._empty_stats()