"""Database model definitions.

Each model is a namespace of static methods that run SQL through the shared
``db`` helper. Time values passed in from Python are integer Unix timestamps
produced by ``get_beijing_time()``.
"""
from database.connection import db
from datetime import datetime, timezone, timedelta
import json
import logging

logger = logging.getLogger(__name__)

# Beijing time zone (UTC+8).
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current time as an integer Unix timestamp (seconds).

    The clock is read through ``BEIJING_TZ`` (UTC+8), but a Unix timestamp is
    time-zone independent, so this is the same value as ``int(time.time())``;
    the Beijing zone only matters when the value is later formatted.
    """
    return int(datetime.now(BEIJING_TZ).timestamp())


def _table_has_column(table, column):
    """Return True if ``table`` has ``column``, probing with a 1-row SELECT.

    Used to stay compatible with older schemas that lack newer columns.
    ``table`` and ``column`` are interpolated into SQL, so only pass internal
    constants here, never user input.
    """
    try:
        db.execute_one(f"SELECT {column} FROM {table} LIMIT 1")
        return True
    except Exception:
        return False


def _insert_row(table, pairs):
    """Insert one row into ``table`` and return the new auto-increment id.

    Args:
        table: table name (internal constant; interpolated into SQL).
        pairs: ordered sequence of ``(column, value)``. Column names are
            interpolated into SQL; values are passed as query parameters.

    Returns:
        The ``id`` reported by ``SELECT LAST_INSERT_ID()``.
    """
    columns = [column for column, _ in pairs]
    placeholders = ", ".join(["%s"] * len(columns))
    db.execute_update(
        f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})",
        tuple(value for _, value in pairs),
    )
    return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']


class TradingConfig:
    """Trading configuration model (``trading_config`` table)."""

    @staticmethod
    def get_all():
        """Return every config row, ordered by category and then key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw config row for ``key``."""
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert or overwrite a config entry (upsert on the config key).

        ``value`` is serialized with ``_convert_to_string`` according to
        ``config_type`` before being stored.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            """INSERT INTO trading_config
               (config_key, config_value, config_type, category, description)
               VALUES (%s, %s, %s, %s, %s)
               ON DUPLICATE KEY UPDATE
               config_value = VALUES(config_value),
               config_type = VALUES(config_type),
               category = VALUES(category),
               description = VALUES(description),
               updated_at = CURRENT_TIMESTAMP""",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the config value for ``key`` converted to its declared type.

        Returns ``default`` when no row is found.
        """
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored config string to its declared ``config_type``.

        - ``number``: int when the text has no '.', float when it does. Text
          that this rejects (e.g. '1e-3') is retried as float; 0 if the value
          cannot be parsed at all.
        - ``boolean``: True for 'true'/'1'/'yes'/'on' (case-insensitive).
        - ``json``: the decoded value, or {} if it cannot be decoded.
        - any other type: returned unchanged.
        """
        if config_type == 'number':
            try:
                return float(value) if '.' in str(value) else int(value)
            except (TypeError, ValueError):
                pass
            # int() rejects valid numeric forms such as exponents ('1e-3').
            try:
                return float(value)
            except (TypeError, ValueError):
                return 0
        elif config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        elif config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):  # JSONDecodeError is a ValueError
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage (JSON for 'json', ``str`` otherwise)."""
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)


class Trade:
    """Trade record model (``trades`` table)."""

    # Columns that older schemas may lack; written only when present.
    _OPTIONAL_CREATE_COLUMNS = (
        "entry_order_id", "notional_usdt", "margin_usdt", "atr",
        "stop_loss_price", "take_profit_price", "take_profit_1", "take_profit_2",
    )

    @staticmethod
    def create(
        symbol,
        side,
        quantity,
        entry_price,
        leverage=10,
        entry_reason=None,
        entry_order_id=None,
        stop_loss_price=None,
        take_profit_price=None,
        take_profit_1=None,
        take_profit_2=None,
        atr=None,
        notional_usdt=None,
        margin_usdt=None,
    ):
        """Create an open trade record and return its id.

        ``entry_time`` is set to ``get_beijing_time()`` and ``status`` to 'open'.

        Args:
            symbol: trading pair.
            side: direction.
            quantity: position size.
            entry_price: entry price.
            leverage: leverage.
            entry_reason: why the position was opened.
            entry_order_id: Binance entry order id (optional, for reconciliation).
            stop_loss_price: stop-loss price actually used (e.g. ATR-based).
            take_profit_price: take-profit price actually used (e.g. ATR-based).
            take_profit_1: first take-profit target (optional).
            take_profit_2: second take-profit target (optional).
            atr: ATR value used at entry (optional).
            notional_usdt: notional size in USDT (optional; derived as
                quantity * entry_price when omitted).
            margin_usdt: margin in USDT (optional; derived as
                notional / leverage, or the notional itself when leverage is
                not positive, when omitted).

        Optional columns are only written if the ``trades`` table has them.
        """
        entry_time = get_beijing_time()

        # Best-effort derivation of notional/margin; unparseable inputs leave
        # the value as None rather than failing the insert.
        if notional_usdt is None and quantity is not None and entry_price is not None:
            try:
                notional_usdt = float(quantity) * float(entry_price)
            except (TypeError, ValueError, ArithmeticError):
                pass
        if margin_usdt is None and notional_usdt is not None:
            try:
                lv = float(leverage) if leverage else 0
                margin_usdt = float(notional_usdt) / lv if lv > 0 else float(notional_usdt)
            except (TypeError, ValueError, ArithmeticError):
                pass

        optional_values = {
            "entry_order_id": entry_order_id,
            "notional_usdt": notional_usdt,
            "margin_usdt": margin_usdt,
            "atr": atr,
            "stop_loss_price": stop_loss_price,
            "take_profit_price": take_profit_price,
            "take_profit_1": take_profit_1,
            "take_profit_2": take_profit_2,
        }
        pairs = [
            ("symbol", symbol),
            ("side", side),
            ("quantity", quantity),
            ("entry_price", entry_price),
            ("leverage", leverage),
            ("entry_reason", entry_reason),
            ("status", "open"),
            ("entry_time", entry_time),
        ]
        # Build the INSERT dynamically so older schemas keep working.
        pairs.extend(
            (column, optional_values[column])
            for column in Trade._OPTIONAL_CREATE_COLUMNS
            if _table_has_column("trades", column)
        )
        return _insert_row("trades", pairs)

    @staticmethod
    def _update_exit_fields(trade_id, exit_price, exit_time, exit_reason, pnl, pnl_percent,
                            strategy_type, duration_minutes,
                            include_exit_order_id=False, exit_order_id=None):
        """Run the closing UPDATE for one trade (helper for ``update_exit``).

        Sets the exit fields and ``status = 'closed'``; ``exit_order_id`` is
        written only when ``include_exit_order_id`` is true, and
        ``strategy_type`` / ``duration_minutes`` only when not None.
        """
        fields = [
            "exit_price = %s",
            "exit_time = %s",
            "exit_reason = %s",
            "pnl = %s",
            "pnl_percent = %s",
            "status = 'closed'",
        ]
        params = [exit_price, exit_time, exit_reason, pnl, pnl_percent]
        if include_exit_order_id:
            fields.append("exit_order_id = %s")
            params.append(exit_order_id)
        if strategy_type is not None:
            fields.append("strategy_type = %s")
            params.append(strategy_type)
        if duration_minutes is not None:
            fields.append("duration_minutes = %s")
            params.append(duration_minutes)
        params.append(trade_id)
        db.execute_update(
            f"UPDATE trades SET {', '.join(fields)} WHERE id = %s",
            tuple(params)
        )

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent, exit_order_id=None,
                    strategy_type=None, duration_minutes=None):
        """Record the exit of a trade and mark it closed.

        ``exit_time`` is set to ``get_beijing_time()``.

        Args:
            trade_id: trade record id.
            exit_price: exit price.
            exit_reason: why the position was closed.
            pnl: profit and loss.
            pnl_percent: profit and loss in percent.
            exit_order_id: Binance exit order id (optional, for reconciliation).
            strategy_type: written only when not None.
            duration_minutes: written only when not None.

        ``exit_order_id`` handling (it is subject to a uniqueness constraint):
        - already stored on this trade: nothing is updated (treated as done);
        - already stored on another trade: the other fields are updated but
          ``exit_order_id`` is not;
        - UPDATE fails with a duplicate-entry error on ``exit_order_id``: the
          UPDATE is retried without ``exit_order_id``. Other errors propagate.
        """
        exit_time = get_beijing_time()

        def _run(include_exit_order_id):
            Trade._update_exit_fields(
                trade_id, exit_price, exit_time, exit_reason, pnl, pnl_percent,
                strategy_type, duration_minutes,
                include_exit_order_id=include_exit_order_id,
                exit_order_id=exit_order_id,
            )

        if exit_order_id is not None:
            existing_trade = None
            # Only the lookup is guarded: a failed check falls through to the
            # normal update path.
            try:
                existing_trade = Trade.get_by_exit_order_id(exit_order_id)
            except Exception as e:
                logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新")

            if existing_trade:
                if existing_trade['id'] == trade_id:
                    logger.debug(f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已存在,跳过更新")
                    return
                logger.warning(
                    f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用,"
                    f"跳过更新 exit_order_id,只更新其他字段"
                )
                _run(include_exit_order_id=False)
                return

        try:
            _run(include_exit_order_id=True)
        except Exception as e:
            error_str = str(e)
            if "Duplicate entry" in error_str and "exit_order_id" in error_str:
                logger.warning(
                    f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突,"
                    f"尝试不更新 exit_order_id"
                )
                _run(include_exit_order_id=False)
            else:
                raise

    @staticmethod
    def get_by_entry_order_id(entry_order_id):
        """Return the trade whose entry order id is ``entry_order_id``."""
        return db.execute_one(
            "SELECT * FROM trades WHERE entry_order_id = %s",
            (entry_order_id,)
        )

    @staticmethod
    def get_by_exit_order_id(exit_order_id):
        """Return the trade whose exit order id is ``exit_order_id``."""
        return db.execute_one(
            "SELECT * FROM trades WHERE exit_order_id = %s",
            (exit_order_id,)
        )

    @staticmethod
    def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None,
                trade_type=None, exit_reason=None):
        """Return trades matching the given filters, newest id first.

        Args:
            start_timestamp: lower bound on ``created_at`` (Unix seconds, optional).
            end_timestamp: upper bound on ``created_at`` (Unix seconds, optional).
            symbol: trading pair (optional).
            status: status (optional).
            trade_type: matched against the ``side`` column (optional).
            exit_reason: exit reason (optional).
        """
        # NOTE(review): Unix-second bounds are compared against created_at;
        # confirm created_at is stored as a Unix timestamp (create() writes
        # entry_time as one but does not set created_at itself).
        query = "SELECT * FROM trades WHERE 1=1"
        params = []
        if start_timestamp is not None:
            query += " AND created_at >= %s"
            params.append(start_timestamp)
        if end_timestamp is not None:
            query += " AND created_at <= %s"
            params.append(end_timestamp)
        if symbol:
            query += " AND symbol = %s"
            params.append(symbol)
        if status:
            query += " AND status = %s"
            params.append(status)
        if trade_type:
            query += " AND side = %s"
            params.append(trade_type)
        if exit_reason:
            query += " AND exit_reason = %s"
            params.append(exit_reason)

        query += " ORDER BY id DESC"
        logger.info("查询交易记录: %s, %s", query, params)
        return db.execute_query(query, params)

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return trades for ``symbol`` with the given status (default 'open')."""
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            (symbol, status)
        )


class AccountSnapshot:
    """Account snapshot model (``account_snapshots`` table)."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert a snapshot stamped with ``get_beijing_time()``."""
        snapshot_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO account_snapshots
               (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
               VALUES (%s, %s, %s, %s, %s, %s)""",
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
        )

    @staticmethod
    def get_recent(days=7):
        """Return snapshots from the last ``days`` days, newest first."""
        # NOTE(review): create() writes snapshot_time as an int Unix timestamp,
        # but this compares it with a DATETIME expression; if the column is
        # numeric, consider UNIX_TIMESTAMP(DATE_SUB(...)) — confirm schema.
        return db.execute_query(
            """SELECT * FROM account_snapshots
               WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
               ORDER BY snapshot_time DESC""",
            (days,)
        )


class MarketScan:
    """Market scan record model (``market_scans`` table)."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Insert a scan record stamped with ``get_beijing_time()``.

        ``top_symbols`` is stored JSON-encoded.
        """
        scan_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO market_scans
               (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)
               VALUES (%s, %s, %s, %s, %s)""",
            (symbols_scanned, symbols_found, json.dumps(top_symbols), scan_duration, scan_time)
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scan records, newest first."""
        return db.execute_query(
            "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s",
            (limit,)
        )


class TradingSignal:
    """Trading signal model (``trading_signals`` table)."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert a signal stamped with ``get_beijing_time()``."""
        signal_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO trading_signals
               (symbol, signal_direction, signal_strength, signal_reason,
                rsi, macd_histogram, market_regime, signal_time)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
            (symbol, signal_direction, signal_strength, signal_reason,
             rsi, macd_histogram, market_regime, signal_time)
        )

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed = TRUE`` on the given signal."""
        db.execute_update(
            "UPDATE trading_signals SET executed = TRUE WHERE id = %s",
            (signal_id,)
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, newest first."""
        return db.execute_query(
            "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s",
            (limit,)
        )


class TradeRecommendation:
    """Trade recommendation model (``trade_recommendations`` table)."""

    @staticmethod
    def create(
        symbol,
        direction,
        current_price,
        change_percent,
        recommendation_reason,
        signal_strength,
        market_regime=None,
        trend_4h=None,
        rsi=None,
        macd_histogram=None,
        bollinger_upper=None,
        bollinger_middle=None,
        bollinger_lower=None,
        ema20=None,
        ema50=None,
        ema20_4h=None,
        atr=None,
        suggested_stop_loss=None,
        suggested_take_profit_1=None,
        suggested_take_profit_2=None,
        suggested_position_percent=None,
        suggested_leverage=10,
        order_type='LIMIT',
        suggested_limit_price=None,
        volume_24h=None,
        volatility=None,
        notes=None,
        user_guide=None,
        recommendation_category=None,
        risk_level=None,
        expected_hold_time=None,
        trading_tutorial=None,
        max_hold_days=3
    ):
        """Insert a recommendation and return its id.

        ``recommendation_time`` is ``get_beijing_time()`` and ``expires_at`` is
        24 hours later (both int Unix seconds). Columns are chosen by schema:
        - ``user_guide`` present: every column, including the order and
          user-guide columns;
        - only ``order_type`` present: adds ``order_type`` and
          ``suggested_limit_price`` to the base columns;
        - neither: base columns only (older schema).
        """
        recommendation_time = get_beijing_time()
        # recommendation_time is an int Unix timestamp, so the 24h offset must
        # be in seconds (int + timedelta raises TypeError).
        # NOTE(review): assumes expires_at is a numeric column like
        # recommendation_time — confirm schema.
        expires_at = recommendation_time + int(timedelta(hours=24).total_seconds())

        has_user_guide_fields = _table_has_column("trade_recommendations", "user_guide")
        # The user-guide schema always includes the order columns as well.
        has_order_fields = (
            has_user_guide_fields
            or _table_has_column("trade_recommendations", "order_type")
        )

        pairs = [
            ("symbol", symbol),
            ("direction", direction),
            ("recommendation_time", recommendation_time),
            ("current_price", current_price),
            ("change_percent", change_percent),
            ("recommendation_reason", recommendation_reason),
            ("signal_strength", signal_strength),
            ("market_regime", market_regime),
            ("trend_4h", trend_4h),
            ("rsi", rsi),
            ("macd_histogram", macd_histogram),
            ("bollinger_upper", bollinger_upper),
            ("bollinger_middle", bollinger_middle),
            ("bollinger_lower", bollinger_lower),
            ("ema20", ema20),
            ("ema50", ema50),
            ("ema20_4h", ema20_4h),
            ("atr", atr),
            ("suggested_stop_loss", suggested_stop_loss),
            ("suggested_take_profit_1", suggested_take_profit_1),
            ("suggested_take_profit_2", suggested_take_profit_2),
            ("suggested_position_percent", suggested_position_percent),
            ("suggested_leverage", suggested_leverage),
        ]
        if has_order_fields:
            pairs += [
                ("order_type", order_type),
                ("suggested_limit_price", suggested_limit_price),
            ]
        pairs += [
            ("volume_24h", volume_24h),
            ("volatility", volatility),
            ("expires_at", expires_at),
            ("notes", notes),
        ]
        if has_user_guide_fields:
            pairs += [
                ("user_guide", user_guide),
                ("recommendation_category", recommendation_category),
                ("risk_level", risk_level),
                ("expected_hold_time", expected_hold_time),
                ("trading_tutorial", trading_tutorial),
                ("max_hold_days", max_hold_days),
            ]
        return _insert_row("trade_recommendations", pairs)

    @staticmethod
    def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
        """Mark a recommendation executed, stamping ``executed_at`` with ``get_beijing_time()``."""
        executed_at = get_beijing_time()
        db.execute_update(
            """UPDATE trade_recommendations
               SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
               WHERE id = %s""",
            (executed_at, execution_result, trade_id, recommendation_id)
        )

    @staticmethod
    def mark_expired(recommendation_id):
        """Set the recommendation's status to 'expired'."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def mark_cancelled(recommendation_id, notes=None):
        """Set the status to 'cancelled' and overwrite ``notes`` (None clears it)."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'cancelled', notes = %s WHERE id = %s",
            (notes, recommendation_id)
        )

    @staticmethod
    def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
        """Return up to ``limit`` recommendations matching the filters.

        ``start_date`` / ``end_date`` bound ``recommendation_time`` and should
        use the same units (int Unix seconds as written by ``create``).
        Ordered newest first, then by signal strength.
        """
        query = "SELECT * FROM trade_recommendations WHERE 1=1"
        params = []
        if status:
            query += " AND status = %s"
            params.append(status)
        if direction:
            query += " AND direction = %s"
            params.append(direction)
        if start_date:
            query += " AND recommendation_time >= %s"
            params.append(start_date)
        if end_date:
            query += " AND recommendation_time <= %s"
            params.append(end_date)
        query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
        params.append(limit)
        return db.execute_query(query, params)

    @staticmethod
    def get_active():
        """Return active, unexpired recommendations, one (the newest) per symbol.

        Ordered by signal strength, then recency.
        """
        # NOTE(review): expires_at is compared with NOW() (a DATETIME) while
        # create() writes int Unix timestamps; if the column is numeric this
        # filter likely needs UNIX_TIMESTAMP() instead — confirm schema.
        return db.execute_query(
            """SELECT t1.* FROM trade_recommendations t1
               INNER JOIN (
                   SELECT symbol, MAX(recommendation_time) as max_time
                   FROM trade_recommendations
                   WHERE status = 'active'
                   AND (expires_at IS NULL OR expires_at > NOW())
                   GROUP BY symbol
               ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time
               WHERE t1.status = 'active'
               AND (t1.expires_at IS NULL OR t1.expires_at > NOW())
               ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC"""
        )

    @staticmethod
    def get_by_id(recommendation_id):
        """Return the recommendation with the given id."""
        return db.execute_one(
            "SELECT * FROM trade_recommendations WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def get_by_symbol(symbol, limit=10):
        """Return up to ``limit`` recommendations for ``symbol``, newest first."""
        return db.execute_query(
            """SELECT * FROM trade_recommendations
               WHERE symbol = %s
               ORDER BY recommendation_time DESC
               LIMIT %s""",
            (symbol, limit)
        )