"""Database model definitions."""
from datetime import datetime, timezone, timedelta
import json
import logging

from database.connection import db

logger = logging.getLogger(__name__)

# Beijing time zone (UTC+8). Every timestamp this module writes is Beijing
# wall-clock time stored as a naive datetime (see get_beijing_time).
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current Beijing time (UTC+8) as a naive datetime.

    The tzinfo is stripped so the value can be written directly into naive
    DATETIME columns. All ``*_time`` / ``expires_at`` values written below
    come from this clock, so time comparisons against those columns should
    also use it rather than the DB server's ``NOW()``.
    """
    return datetime.now(BEIJING_TZ).replace(tzinfo=None)


class TradingConfig:
    """Key/value trading configuration stored in the ``trading_config`` table."""

    @staticmethod
    def get_all():
        """Return all config rows, ordered by category and then key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw config row for ``key`` (as returned by ``db.execute_one``)."""
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert or update (upsert) a config entry.

        ``value`` is serialized according to ``config_type`` (JSON for
        ``'json'``, ``str()`` otherwise). When the key already exists, the
        value, type, category and description are overwritten and
        ``updated_at`` is refreshed.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            """INSERT INTO trading_config (config_key, config_value, config_type, category, description)
               VALUES (%s, %s, %s, %s, %s)
               ON DUPLICATE KEY UPDATE
                   config_value = VALUES(config_value),
                   config_type = VALUES(config_type),
                   category = VALUES(category),
                   description = VALUES(description),
                   updated_at = CURRENT_TIMESTAMP""",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the value for ``key`` converted to its declared type.

        Returns ``default`` when no (truthy) row exists for ``key``.
        """
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored config value to its Python type.

        - ``'number'``: ``float`` when the text contains a ``.``, otherwise
          ``int``; if that fails, ``float`` is tried (so ``"1e3"`` works);
          ``0`` when the value is not numeric at all.
        - ``'boolean'``: ``True`` for ``true/1/yes/on`` (case-insensitive).
        - ``'json'``: the parsed JSON, or ``{}`` when it cannot be parsed.
        - any other type: the value unchanged.
        """
        if config_type == 'number':
            try:
                return float(value) if '.' in str(value) else int(value)
            except (TypeError, ValueError):
                pass
            # int() rejects forms such as "1e3" that float() accepts.
            try:
                return float(value)
            except (TypeError, ValueError):
                logger.warning("Invalid number config value %r, using 0", value)
                return 0
        if config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        if config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):  # JSONDecodeError is a ValueError
                logger.warning("Invalid JSON config value %r, using {}", value)
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage: JSON for ``'json'``, ``str()`` otherwise."""
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)


class Trade:
    """Trade record model (``trades`` table)."""

    @staticmethod
    def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None, entry_order_id=None):
        """Create an open trade record stamped with Beijing time.

        Args:
            symbol: Trading pair.
            side: Direction.
            quantity: Quantity.
            entry_price: Entry price.
            leverage: Leverage.
            entry_reason: Reason for entering.
            entry_order_id: Binance entry order ID (optional, for reconciliation).

        Returns:
            The id of the inserted row.
        """
        entry_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO trades (symbol, side, quantity, entry_price, leverage, entry_reason, status, entry_time, entry_order_id)
               VALUES (%s, %s, %s, %s, %s, %s, 'open', %s, %s)""",
            (symbol, side, quantity, entry_price, leverage, entry_reason, entry_time, entry_order_id)
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection in MySQL; this
        # assumes db reuses the same connection as the INSERT above — confirm.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def _close_without_exit_order_id(trade_id, exit_price, exit_time, exit_reason, pnl, pnl_percent):
        """Mark a trade closed, leaving its exit_order_id column untouched."""
        db.execute_update(
            """UPDATE trades
               SET exit_price = %s, exit_time = %s, exit_reason = %s,
                   pnl = %s, pnl_percent = %s, status = 'closed'
               WHERE id = %s""",
            (exit_price, exit_time, exit_reason, pnl, pnl_percent, trade_id)
        )

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent, exit_order_id=None):
        """Record the exit of a trade (stamped with Beijing time) and mark it closed.

        Args:
            trade_id: Trade record ID.
            exit_price: Exit price.
            exit_reason: Reason for closing.
            pnl: Profit and loss.
            pnl_percent: Profit and loss percentage.
            exit_order_id: Binance exit order ID (optional, for reconciliation).

        If ``exit_order_id`` already belongs to a different trade record,
        either per the pre-check lookup or as reported by a
        "Duplicate entry ... exit_order_id" database error, the trade is
        closed without touching exit_order_id to avoid violating its unique
        constraint. Any other database error is re-raised.
        """
        exit_time = get_beijing_time()

        # If an exit_order_id was given, first check whether another trade already uses it.
        if exit_order_id is not None:
            try:
                existing_trade = Trade.get_by_exit_order_id(exit_order_id)
            except Exception as e:
                # Lookup failure is non-fatal: fall through to the normal update.
                logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新")
            else:
                if existing_trade and existing_trade['id'] != trade_id:
                    logger.warning(
                        f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用,"
                        f"跳过更新 exit_order_id,只更新其他字段"
                    )
                    Trade._close_without_exit_order_id(
                        trade_id, exit_price, exit_time, exit_reason, pnl, pnl_percent
                    )
                    return

        # Normal path: also write exit_order_id (NULL when not provided).
        try:
            db.execute_update(
                """UPDATE trades
                   SET exit_price = %s, exit_time = %s, exit_reason = %s,
                       pnl = %s, pnl_percent = %s, status = 'closed', exit_order_id = %s
                   WHERE id = %s""",
                (exit_price, exit_time, exit_reason, pnl, pnl_percent, exit_order_id, trade_id)
            )
        except Exception as e:
            # A unique-constraint clash (e.g. a concurrent writer took the
            # exit_order_id after the pre-check) is retried without that column.
            error_str = str(e)
            if "Duplicate entry" in error_str and "exit_order_id" in error_str:
                logger.warning(
                    f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突,"
                    f"尝试不更新 exit_order_id"
                )
                Trade._close_without_exit_order_id(
                    trade_id, exit_price, exit_time, exit_reason, pnl, pnl_percent
                )
            else:
                raise

    @staticmethod
    def get_by_entry_order_id(entry_order_id):
        """Return the trade record with the given entry order ID."""
        return db.execute_one(
            "SELECT * FROM trades WHERE entry_order_id = %s",
            (entry_order_id,)
        )

    @staticmethod
    def get_by_exit_order_id(exit_order_id):
        """Return the trade record with the given exit order ID."""
        return db.execute_one(
            "SELECT * FROM trades WHERE exit_order_id = %s",
            (exit_order_id,)
        )

    @staticmethod
    def get_all(start_date=None, end_date=None, symbol=None, status=None, trade_type=None, exit_reason=None):
        """Return trade records matching the given (truthy) filters.

        ``start_date`` / ``end_date`` bound ``entry_time`` inclusively;
        ``trade_type`` filters on the ``side`` column.
        """
        query = "SELECT * FROM trades WHERE 1=1"
        params = []

        if start_date:
            query += " AND entry_time >= %s"
            params.append(start_date)
        if end_date:
            query += " AND entry_time <= %s"
            params.append(end_date)
        if symbol:
            query += " AND symbol = %s"
            params.append(symbol)
        if status:
            query += " AND status = %s"
            params.append(status)
        if trade_type:
            query += " AND side = %s"
            params.append(trade_type)
        if exit_reason:
            query += " AND exit_reason = %s"
            params.append(exit_reason)

        # Newest records first, by primary key.
        query += " ORDER BY id DESC"

        return db.execute_query(query, params)

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return the trades for ``symbol`` with the given status (open by default)."""
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            (symbol, status)
        )


class AccountSnapshot:
    """Account snapshot model (``account_snapshots`` table)."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert an account snapshot stamped with Beijing time."""
        snapshot_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO account_snapshots
               (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
               VALUES (%s, %s, %s, %s, %s, %s)""",
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
        )

    @staticmethod
    def get_recent(days=7):
        """Return snapshots from the last ``days`` days, newest first.

        The cutoff is computed from Beijing time (the clock used by
        ``create``) instead of the DB server's ``NOW()``, so the window is
        correct regardless of the server's time zone.
        """
        cutoff = get_beijing_time() - timedelta(days=days)
        return db.execute_query(
            """SELECT * FROM account_snapshots
               WHERE snapshot_time >= %s
               ORDER BY snapshot_time DESC""",
            (cutoff,)
        )


class MarketScan:
    """Market scan record model (``market_scans`` table)."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Insert a scan record stamped with Beijing time; ``top_symbols`` is stored as JSON."""
        scan_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO market_scans (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)
               VALUES (%s, %s, %s, %s, %s)""",
            (symbols_scanned, symbols_found, json.dumps(top_symbols), scan_duration, scan_time)
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scan records, newest first."""
        return db.execute_query(
            "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s",
            (limit,)
        )


class TradingSignal:
    """Trading signal model (``trading_signals`` table)."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason, rsi=None, macd_histogram=None, market_regime=None):
        """Insert a trading signal stamped with Beijing time."""
        signal_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO trading_signals
               (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime, signal_time)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
            (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime, signal_time)
        )

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed = TRUE`` on the given signal."""
        db.execute_update(
            "UPDATE trading_signals SET executed = TRUE WHERE id = %s",
            (signal_id,)
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, newest first."""
        return db.execute_query(
            "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s",
            (limit,)
        )


class TradeRecommendation:
    """Recommended trading pair model (``trade_recommendations`` table)."""

    @staticmethod
    def create(
        symbol, direction, current_price, change_percent, recommendation_reason,
        signal_strength, market_regime=None, trend_4h=None, rsi=None,
        macd_histogram=None, bollinger_upper=None, bollinger_middle=None,
        bollinger_lower=None, ema20=None, ema50=None, ema20_4h=None, atr=None,
        suggested_stop_loss=None, suggested_take_profit_1=None,
        suggested_take_profit_2=None, suggested_position_percent=None,
        suggested_leverage=10, order_type='LIMIT', suggested_limit_price=None,
        volume_24h=None, volatility=None, notes=None
    ):
        """Insert a recommendation stamped with Beijing time; return its id.

        The recommendation expires 24 hours after creation. ``order_type``
        and ``suggested_limit_price`` are only written when the table has
        those columns (older schemas lack them and they are then ignored).
        """
        recommendation_time = get_beijing_time()
        # Expire 24 hours after creation by default.
        expires_at = recommendation_time + timedelta(hours=24)

        # Probe for the order_type column to stay compatible with old schemas.
        try:
            db.execute_one("SELECT order_type FROM trade_recommendations LIMIT 1")
            has_order_fields = True
        except Exception:
            has_order_fields = False

        # Ordered (column, value) pairs; keeping them together prevents the
        # column list and the parameter tuple from drifting apart.
        fields = [
            ('symbol', symbol),
            ('direction', direction),
            ('recommendation_time', recommendation_time),
            ('current_price', current_price),
            ('change_percent', change_percent),
            ('recommendation_reason', recommendation_reason),
            ('signal_strength', signal_strength),
            ('market_regime', market_regime),
            ('trend_4h', trend_4h),
            ('rsi', rsi),
            ('macd_histogram', macd_histogram),
            ('bollinger_upper', bollinger_upper),
            ('bollinger_middle', bollinger_middle),
            ('bollinger_lower', bollinger_lower),
            ('ema20', ema20),
            ('ema50', ema50),
            ('ema20_4h', ema20_4h),
            ('atr', atr),
            ('suggested_stop_loss', suggested_stop_loss),
            ('suggested_take_profit_1', suggested_take_profit_1),
            ('suggested_take_profit_2', suggested_take_profit_2),
            ('suggested_position_percent', suggested_position_percent),
            ('suggested_leverage', suggested_leverage),
        ]
        if has_order_fields:
            fields += [
                ('order_type', order_type),
                ('suggested_limit_price', suggested_limit_price),
            ]
        fields += [
            ('volume_24h', volume_24h),
            ('volatility', volatility),
            ('expires_at', expires_at),
            ('notes', notes),
        ]

        # Column names are the fixed literals above (never caller input), so
        # interpolating them is safe; every value is passed as a parameter.
        columns = ', '.join(name for name, _ in fields)
        placeholders = ', '.join(['%s'] * len(fields))
        db.execute_update(
            f"INSERT INTO trade_recommendations ({columns}) VALUES ({placeholders})",
            tuple(value for _, value in fields)
        )
        # NOTE(review): relies on db reusing the INSERT's connection — confirm.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
        """Mark a recommendation executed, recording the time, result and resulting trade id."""
        executed_at = get_beijing_time()
        db.execute_update(
            """UPDATE trade_recommendations
               SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
               WHERE id = %s""",
            (executed_at, execution_result, trade_id, recommendation_id)
        )

    @staticmethod
    def mark_expired(recommendation_id):
        """Mark a recommendation expired."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def mark_cancelled(recommendation_id, notes=None):
        """Mark a recommendation cancelled, overwriting its notes with ``notes``."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'cancelled', notes = %s WHERE id = %s",
            (notes, recommendation_id)
        )

    @staticmethod
    def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
        """Return up to ``limit`` recommendations matching the given (truthy) filters.

        Ordered by recommendation time (newest first), then signal strength.
        """
        query = "SELECT * FROM trade_recommendations WHERE 1=1"
        params = []

        if status:
            query += " AND status = %s"
            params.append(status)
        if direction:
            query += " AND direction = %s"
            params.append(direction)
        if start_date:
            query += " AND recommendation_time >= %s"
            params.append(start_date)
        if end_date:
            query += " AND recommendation_time <= %s"
            params.append(end_date)

        query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
        params.append(limit)

        return db.execute_query(query, params)

    @staticmethod
    def get_active():
        """Return active recommendations (not expired, executed or cancelled).

        Only the latest recommendation per symbol is returned (rows tied on
        the latest time are all kept). Expiry is checked against Beijing time
        (the clock ``create`` uses for ``expires_at``) rather than the DB
        server's ``NOW()``, which may run in a different time zone.
        """
        now = get_beijing_time()
        return db.execute_query(
            """SELECT t1.* FROM trade_recommendations t1
               INNER JOIN (
                   SELECT symbol, MAX(recommendation_time) as max_time
                   FROM trade_recommendations
                   WHERE status = 'active'
                     AND (expires_at IS NULL OR expires_at > %s)
                   GROUP BY symbol
               ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time
               WHERE t1.status = 'active'
                 AND (t1.expires_at IS NULL OR t1.expires_at > %s)
               ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC""",
            (now, now)
        )

    @staticmethod
    def get_by_id(recommendation_id):
        """Return the recommendation with the given id."""
        return db.execute_one(
            "SELECT * FROM trade_recommendations WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def get_by_symbol(symbol, limit=10):
        """Return up to ``limit`` recommendations for ``symbol``, newest first."""
        return db.execute_query(
            """SELECT * FROM trade_recommendations
               WHERE symbol = %s
               ORDER BY recommendation_time DESC
               LIMIT %s""",
            (symbol, limit)
        )