""" 数据库模型定义 """ from database.connection import db from datetime import datetime, timezone, timedelta import json import logging import os # 北京时间时区(UTC+8) BEIJING_TZ = timezone(timedelta(hours=8)) def get_beijing_time(): """获取当前北京时间(UTC+8)的Unix时间戳(秒)""" return int(datetime.now(BEIJING_TZ).timestamp()) logger = logging.getLogger(__name__) def _resolve_default_account_id() -> int: """ 默认账号ID: - trading_system 多进程:每个进程可通过 ATS_ACCOUNT_ID 指定自己的 account_id - backend:未传 account_id 时默认走 1(兼容单账号) """ for k in ("ATS_ACCOUNT_ID", "ACCOUNT_ID", "ATS_DEFAULT_ACCOUNT_ID"): v = (os.getenv(k, "") or "").strip() if v: try: return int(v) except Exception: continue return 1 DEFAULT_ACCOUNT_ID = _resolve_default_account_id() def _table_has_column(table: str, col: str) -> bool: try: db.execute_one(f"SELECT {col} FROM {table} LIMIT 1") return True except Exception: return False class Account: """ 账号模型(多账号) - API Key/Secret 建议加密存储在 accounts 表中,而不是 trading_config """ @staticmethod def get(account_id: int): return db.execute_one("SELECT * FROM accounts WHERE id = %s", (int(account_id),)) @staticmethod def list_all(): return db.execute_query("SELECT id, name, status, created_at, updated_at FROM accounts ORDER BY id ASC") @staticmethod def create(name: str, api_key: str = "", api_secret: str = "", use_testnet: bool = False, status: str = "active"): from security.crypto import encrypt_str # 延迟导入,避免无依赖时直接崩 api_key_enc = encrypt_str(api_key or "") api_secret_enc = encrypt_str(api_secret or "") db.execute_update( """INSERT INTO accounts (name, status, api_key_enc, api_secret_enc, use_testnet) VALUES (%s, %s, %s, %s, %s)""", (name, status, api_key_enc, api_secret_enc, bool(use_testnet)), ) return db.execute_one("SELECT LAST_INSERT_ID() as id")["id"] @staticmethod def update_credentials(account_id: int, api_key: str = None, api_secret: str = None, use_testnet: bool = None): from security.crypto import encrypt_str # 延迟导入 fields = [] params = [] if api_key is not None: fields.append("api_key_enc = %s") params.append(encrypt_str(api_key)) if api_secret is not None: fields.append("api_secret_enc = %s") params.append(encrypt_str(api_secret)) if use_testnet is not None: fields.append("use_testnet = %s") params.append(bool(use_testnet)) if not fields: return params.append(int(account_id)) db.execute_update(f"UPDATE accounts SET {', '.join(fields)} WHERE id = %s", tuple(params)) @staticmethod def get_credentials(account_id: int): """ 返回 (api_key, api_secret, use_testnet);密文字段会自动解密。 若未配置 master key 且库里是明文,仍可工作(但不安全)。 """ row = Account.get(account_id) if not row: return "", "", False try: from security.crypto import decrypt_str api_key = decrypt_str(row.get("api_key_enc") or "") api_secret = decrypt_str(row.get("api_secret_enc") or "") except Exception: # 兼容:无 cryptography 或未配 master key 时,先按明文兜底 api_key = row.get("api_key_enc") or "" api_secret = row.get("api_secret_enc") or "" use_testnet = bool(row.get("use_testnet") or False) return api_key, api_secret, use_testnet class User: """登录用户(管理员/普通用户)""" @staticmethod def get_by_username(username: str): return db.execute_one("SELECT * FROM users WHERE username = %s", (str(username),)) @staticmethod def get_by_id(user_id: int): return db.execute_one("SELECT * FROM users WHERE id = %s", (int(user_id),)) @staticmethod def list_all(): return db.execute_query("SELECT id, username, role, status, created_at, updated_at FROM users ORDER BY id ASC") @staticmethod def create(username: str, password_hash: str, role: str = "user", status: str = "active"): db.execute_update( "INSERT INTO users (username, password_hash, role, status) VALUES (%s, %s, %s, %s)", (username, password_hash, role, status), ) return db.execute_one("SELECT LAST_INSERT_ID() as id")["id"] @staticmethod def set_password(user_id: int, password_hash: str): db.execute_update("UPDATE users SET password_hash = %s WHERE id = %s", (password_hash, int(user_id))) @staticmethod def set_status(user_id: int, status: str): db.execute_update("UPDATE users SET status = %s WHERE id = %s", (status, int(user_id))) @staticmethod def set_role(user_id: int, role: str): db.execute_update("UPDATE users SET role = %s WHERE id = %s", (role, int(user_id))) class UserAccountMembership: """用户-交易账号授权关系""" @staticmethod def add(user_id: int, account_id: int, role: str = "viewer"): db.execute_update( """INSERT INTO user_account_memberships (user_id, account_id, role) VALUES (%s, %s, %s) ON DUPLICATE KEY UPDATE role = VALUES(role)""", (int(user_id), int(account_id), role), ) @staticmethod def remove(user_id: int, account_id: int): db.execute_update( "DELETE FROM user_account_memberships WHERE user_id = %s AND account_id = %s", (int(user_id), int(account_id)), ) @staticmethod def list_for_user(user_id: int): return db.execute_query( "SELECT * FROM user_account_memberships WHERE user_id = %s ORDER BY account_id ASC", (int(user_id),), ) @staticmethod def list_for_account(account_id: int): return db.execute_query( "SELECT * FROM user_account_memberships WHERE account_id = %s ORDER BY user_id ASC", (int(account_id),), ) @staticmethod def has_access(user_id: int, account_id: int) -> bool: row = db.execute_one( "SELECT 1 as ok FROM user_account_memberships WHERE user_id = %s AND account_id = %s", (int(user_id), int(account_id)), ) return bool(row) @staticmethod def get_role(user_id: int, account_id: int) -> str: row = db.execute_one( "SELECT role FROM user_account_memberships WHERE user_id = %s AND account_id = %s", (int(user_id), int(account_id)), ) return (row.get("role") if isinstance(row, dict) else None) or "" class TradingConfig: """交易配置模型""" @staticmethod def get_all(account_id: int = None): """获取所有配置""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trading_config", "account_id"): return db.execute_query( "SELECT * FROM trading_config WHERE account_id = %s ORDER BY category, config_key", (aid,), ) return db.execute_query("SELECT * FROM trading_config ORDER BY category, config_key") @staticmethod def get(key, account_id: int = None): """获取单个配置""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trading_config", "account_id"): return db.execute_one( "SELECT * FROM trading_config WHERE account_id = %s AND config_key = %s", (aid, key), ) return db.execute_one("SELECT * FROM trading_config WHERE config_key = %s", (key,)) @staticmethod def set(key, value, config_type, category, description=None, account_id: int = None): """设置配置""" value_str = TradingConfig._convert_to_string(value, config_type) aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trading_config", "account_id"): db.execute_update( """INSERT INTO trading_config (account_id, config_key, config_value, config_type, category, description) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE config_value = VALUES(config_value), config_type = VALUES(config_type), category = VALUES(category), description = VALUES(description), updated_at = CURRENT_TIMESTAMP""", (aid, key, value_str, config_type, category, description), ) else: db.execute_update( """INSERT INTO trading_config (config_key, config_value, config_type, category, description) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE config_value = VALUES(config_value), config_type = VALUES(config_type), category = VALUES(category), description = VALUES(description), updated_at = CURRENT_TIMESTAMP""", (key, value_str, config_type, category, description), ) @staticmethod def get_value(key, default=None, account_id: int = None): """获取配置值(自动转换类型)""" result = TradingConfig.get(key, account_id=account_id) if result: return TradingConfig._convert_value(result['config_value'], result['config_type']) return default @staticmethod def _convert_value(value, config_type): """转换配置值类型""" if config_type == 'number': try: return float(value) if '.' in str(value) else int(value) except: return 0 elif config_type == 'boolean': return str(value).lower() in ('true', '1', 'yes', 'on') elif config_type == 'json': try: return json.loads(value) except: return {} return value @staticmethod def _convert_to_string(value, config_type): """转换值为字符串""" if config_type == 'json': return json.dumps(value, ensure_ascii=False) return str(value) class Trade: """交易记录模型""" @staticmethod def create( symbol, side, quantity, entry_price, leverage=10, entry_reason=None, entry_order_id=None, stop_loss_price=None, take_profit_price=None, take_profit_1=None, take_profit_2=None, atr=None, notional_usdt=None, margin_usdt=None, account_id: int = None, ): """创建交易记录(使用北京时间) Args: symbol: 交易对 side: 方向 quantity: 数量 entry_price: 入场价 leverage: 杠杆 entry_reason: 入场原因 entry_order_id: 币安开仓订单号(可选,用于对账) stop_loss_price: 实际使用的止损价格(考虑了ATR等动态计算) take_profit_price: 实际使用的止盈价格(考虑了ATR等动态计算) take_profit_1: 第一目标止盈价(可选) take_profit_2: 第二目标止盈价(可选) atr: 开仓时使用的ATR值(可选) notional_usdt: 名义下单量(USDT,可选) margin_usdt: 保证金(USDT,可选) """ entry_time = get_beijing_time() # 自动计算 notional/margin(若调用方没传) try: if notional_usdt is None and quantity is not None and entry_price is not None: notional_usdt = float(quantity) * float(entry_price) except Exception: pass try: if margin_usdt is None and notional_usdt is not None: lv = float(leverage) if leverage else 0 margin_usdt = (float(notional_usdt) / lv) if lv and lv > 0 else float(notional_usdt) except Exception: pass def _has_column(col: str) -> bool: try: db.execute_one(f"SELECT {col} FROM trades LIMIT 1") return True except Exception: return False # 动态构建 INSERT(兼容不同schema) columns = ["symbol", "side", "quantity", "entry_price", "leverage", "entry_reason", "status", "entry_time"] values = [symbol, side, quantity, entry_price, leverage, entry_reason, "open", entry_time] if _has_column("account_id"): columns.insert(0, "account_id") values.insert(0, int(account_id or DEFAULT_ACCOUNT_ID)) if _has_column("entry_order_id"): columns.append("entry_order_id") values.append(entry_order_id) if _has_column("notional_usdt"): columns.append("notional_usdt") values.append(notional_usdt) if _has_column("margin_usdt"): columns.append("margin_usdt") values.append(margin_usdt) if _has_column("atr"): columns.append("atr") values.append(atr) if _has_column("stop_loss_price"): columns.append("stop_loss_price") values.append(stop_loss_price) if _has_column("take_profit_price"): columns.append("take_profit_price") values.append(take_profit_price) if _has_column("take_profit_1"): columns.append("take_profit_1") values.append(take_profit_1) if _has_column("take_profit_2"): columns.append("take_profit_2") values.append(take_profit_2) placeholders = ", ".join(["%s"] * len(columns)) sql = f"INSERT INTO trades ({', '.join(columns)}) VALUES ({placeholders})" db.execute_update(sql, tuple(values)) return db.execute_one("SELECT LAST_INSERT_ID() as id")['id'] @staticmethod def update_exit( trade_id, exit_price, exit_reason, pnl, pnl_percent, exit_order_id=None, strategy_type=None, duration_minutes=None, exit_time_ts=None, ): """更新平仓信息(使用北京时间) Args: trade_id: 交易记录ID exit_price: 出场价 exit_reason: 平仓原因 pnl: 盈亏 pnl_percent: 盈亏百分比 exit_order_id: 币安平仓订单号(可选,用于对账) 注意:如果 exit_order_id 已存在且属于其他交易记录,会跳过更新 exit_order_id 以避免唯一约束冲突 """ # exit_time_ts: 允许外部传入“真实成交时间”(Unix秒)以便统计持仓时长更准确 try: exit_time = int(exit_time_ts) if exit_time_ts is not None else get_beijing_time() except Exception: exit_time = get_beijing_time() # 如果提供了 exit_order_id,先检查是否已被其他交易记录使用 if exit_order_id is not None: try: existing_trade = Trade.get_by_exit_order_id(exit_order_id) if existing_trade: if existing_trade['id'] == trade_id: # exit_order_id 属于当前交易记录:允许继续更新(比如补写 exit_reason / exit_time / duration) # 不需要提前 return logger.debug( f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已存在,将继续更新其他字段" ) else: # 如果 exit_order_id 已被其他交易记录使用,记录警告但不更新 exit_order_id logger.warning( f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用," f"跳过更新 exit_order_id,只更新其他字段" ) # 只更新其他字段,不更新 exit_order_id update_fields = [ "exit_price = %s", "exit_time = %s", "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'" ] update_values = [exit_price, exit_time, exit_reason, pnl, pnl_percent] if strategy_type is not None: update_fields.append("strategy_type = %s") update_values.append(strategy_type) if duration_minutes is not None: update_fields.append("duration_minutes = %s") update_values.append(duration_minutes) update_values.append(trade_id) db.execute_update( f"UPDATE trades SET {', '.join(update_fields)} WHERE id = %s", tuple(update_values) ) return except Exception as e: # 如果查询失败,记录警告但继续正常更新 logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新") # 正常更新(包括 exit_order_id) try: update_fields = [ "exit_price = %s", "exit_time = %s", "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'", "exit_order_id = %s" ] update_values = [exit_price, exit_time, exit_reason, pnl, pnl_percent, exit_order_id] if strategy_type is not None: update_fields.append("strategy_type = %s") update_values.append(strategy_type) if duration_minutes is not None: update_fields.append("duration_minutes = %s") update_values.append(duration_minutes) update_values.append(trade_id) db.execute_update( f"UPDATE trades SET {', '.join(update_fields)} WHERE id = %s", tuple(update_values) ) except Exception as e: # 如果更新失败(可能是唯一约束冲突),尝试不更新 exit_order_id error_str = str(e) if "Duplicate entry" in error_str and "exit_order_id" in error_str: logger.warning( f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突," f"尝试不更新 exit_order_id" ) # 只更新其他字段,不更新 exit_order_id update_fields = [ "exit_price = %s", "exit_time = %s", "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'" ] update_values = [exit_price, exit_time, exit_reason, pnl, pnl_percent] if strategy_type is not None: update_fields.append("strategy_type = %s") update_values.append(strategy_type) if duration_minutes is not None: update_fields.append("duration_minutes = %s") update_values.append(duration_minutes) update_values.append(trade_id) db.execute_update( f"UPDATE trades SET {', '.join(update_fields)} WHERE id = %s", tuple(update_values) ) else: # 其他错误,重新抛出 raise @staticmethod def get_by_entry_order_id(entry_order_id): """根据开仓订单号获取交易记录""" return db.execute_one( "SELECT * FROM trades WHERE entry_order_id = %s", (entry_order_id,) ) @staticmethod def get_by_exit_order_id(exit_order_id): """根据平仓订单号获取交易记录""" return db.execute_one( "SELECT * FROM trades WHERE exit_order_id = %s", (exit_order_id,) ) @staticmethod def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None, trade_type=None, exit_reason=None, account_id: int = None): """获取交易记录 Args: start_timestamp: 开始时间(Unix时间戳秒数,可选) end_timestamp: 结束时间(Unix时间戳秒数,可选) symbol: 交易对(可选) status: 状态(可选) trade_type: 交易类型(可选) exit_reason: 平仓原因(可选) """ query = "SELECT * FROM trades WHERE 1=1" params = [] # 多账号隔离(兼容旧schema) try: if _table_has_column("trades", "account_id"): query += " AND account_id = %s" params.append(int(account_id or DEFAULT_ACCOUNT_ID)) except Exception: pass if start_timestamp is not None: query += " AND created_at >= %s" params.append(start_timestamp) if end_timestamp is not None: query += " AND created_at <= %s" params.append(end_timestamp) if symbol: query += " AND symbol = %s" params.append(symbol) if status: query += " AND status = %s" params.append(status) if trade_type: query += " AND side = %s" params.append(trade_type) if exit_reason: query += " AND exit_reason = %s" params.append(exit_reason) # 按平仓时间倒序排序,如果没有平仓时间则按入场时间倒序 # query += " ORDER BY COALESCE(exit_time, entry_time) DESC, entry_time DESC" query += " ORDER BY id DESC" logger.info(f"查询交易记录: {query}, {params}") result = db.execute_query(query, params) return result @staticmethod def get_by_symbol(symbol, status='open', account_id: int = None): """根据交易对获取持仓""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trades", "account_id"): return db.execute_query( "SELECT * FROM trades WHERE account_id = %s AND symbol = %s AND status = %s", (aid, symbol, status), ) return db.execute_query( "SELECT * FROM trades WHERE symbol = %s AND status = %s", (symbol, status), ) class AccountSnapshot: """账户快照模型""" @staticmethod def create(total_balance, available_balance, total_position_value, total_pnl, open_positions, account_id: int = None): """创建账户快照(使用北京时间)""" snapshot_time = get_beijing_time() if _table_has_column("account_snapshots", "account_id"): db.execute_update( """INSERT INTO account_snapshots (account_id, total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) VALUES (%s, %s, %s, %s, %s, %s, %s)""", (int(account_id or DEFAULT_ACCOUNT_ID), total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time), ) else: db.execute_update( """INSERT INTO account_snapshots (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) VALUES (%s, %s, %s, %s, %s, %s)""", (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time), ) @staticmethod def get_recent(days=7, account_id: int = None): """获取最近的快照""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("account_snapshots", "account_id"): return db.execute_query( """SELECT * FROM account_snapshots WHERE account_id = %s AND snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY) ORDER BY snapshot_time DESC""", (aid, days), ) return db.execute_query( """SELECT * FROM account_snapshots WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY) ORDER BY snapshot_time DESC""", (days,), ) class MarketScan: """市场扫描记录模型""" @staticmethod def create(symbols_scanned, symbols_found, top_symbols, scan_duration): """创建扫描记录(使用北京时间)""" scan_time = get_beijing_time() db.execute_update( """INSERT INTO market_scans (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time) VALUES (%s, %s, %s, %s, %s)""", (symbols_scanned, symbols_found, json.dumps(top_symbols), scan_duration, scan_time) ) @staticmethod def get_recent(limit=100): """获取最近的扫描记录""" return db.execute_query( "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s", (limit,) ) class TradingSignal: """交易信号模型""" @staticmethod def create(symbol, signal_direction, signal_strength, signal_reason, rsi=None, macd_histogram=None, market_regime=None): """创建交易信号(使用北京时间)""" signal_time = get_beijing_time() db.execute_update( """INSERT INTO trading_signals (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime, signal_time) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime, signal_time) ) @staticmethod def mark_executed(signal_id): """标记信号已执行""" db.execute_update( "UPDATE trading_signals SET executed = TRUE WHERE id = %s", (signal_id,) ) @staticmethod def get_recent(limit=100): """获取最近的信号""" return db.execute_query( "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s", (limit,) ) class TradeRecommendation: """推荐交易对模型""" @staticmethod def create( symbol, direction, current_price, change_percent, recommendation_reason, signal_strength, market_regime=None, trend_4h=None, rsi=None, macd_histogram=None, bollinger_upper=None, bollinger_middle=None, bollinger_lower=None, ema20=None, ema50=None, ema20_4h=None, atr=None, suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None, suggested_position_percent=None, suggested_leverage=10, order_type='LIMIT', suggested_limit_price=None, volume_24h=None, volatility=None, notes=None, user_guide=None, recommendation_category=None, risk_level=None, expected_hold_time=None, trading_tutorial=None, max_hold_days=3 ): """创建推荐记录(使用北京时间)""" recommendation_time = get_beijing_time() # 默认24小时后过期 expires_at = recommendation_time + timedelta(hours=24) # 检查字段是否存在(兼容旧数据库schema) try: db.execute_one("SELECT order_type FROM trade_recommendations LIMIT 1") has_order_fields = True except: has_order_fields = False # 检查是否有 user_guide 字段 try: db.execute_one("SELECT user_guide FROM trade_recommendations LIMIT 1") has_user_guide_fields = True except: has_user_guide_fields = False if has_user_guide_fields: # 包含所有新字段(user_guide等) db.execute_update( """INSERT INTO trade_recommendations (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes, user_guide, recommendation_category, risk_level, expected_hold_time, trading_tutorial, max_hold_days) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes, user_guide, recommendation_category, risk_level, expected_hold_time, trading_tutorial, max_hold_days) ) elif has_order_fields: # 只有 order_type 字段,没有 user_guide 字段 db.execute_update( """INSERT INTO trade_recommendations (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes) ) else: # 兼容旧schema db.execute_update( """INSERT INTO trade_recommendations (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, volume_24h, volatility, expires_at, notes) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, volume_24h, volatility, expires_at, notes) ) return db.execute_one("SELECT LAST_INSERT_ID() as id")['id'] @staticmethod def mark_executed(recommendation_id, trade_id=None, execution_result='success'): """标记推荐已执行""" executed_at = get_beijing_time() db.execute_update( """UPDATE trade_recommendations SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s WHERE id = %s""", (executed_at, execution_result, trade_id, recommendation_id) ) @staticmethod def mark_expired(recommendation_id): """标记推荐已过期""" db.execute_update( "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s", (recommendation_id,) ) @staticmethod def mark_cancelled(recommendation_id, notes=None): """标记推荐已取消""" db.execute_update( "UPDATE trade_recommendations SET status = 'cancelled', notes = %s WHERE id = %s", (notes, recommendation_id) ) @staticmethod def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None): """获取推荐记录""" query = "SELECT * FROM trade_recommendations WHERE 1=1" params = [] if status: query += " AND status = %s" params.append(status) if direction: query += " AND direction = %s" params.append(direction) if start_date: query += " AND recommendation_time >= %s" params.append(start_date) if end_date: query += " AND recommendation_time <= %s" params.append(end_date) query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s" params.append(limit) return db.execute_query(query, params) @staticmethod def get_active(): """获取当前有效的推荐(未过期、未执行、未取消) 同一交易对只返回最新的推荐(去重) """ return db.execute_query( """SELECT t1.* FROM trade_recommendations t1 INNER JOIN ( SELECT symbol, MAX(recommendation_time) as max_time FROM trade_recommendations WHERE status = 'active' AND (expires_at IS NULL OR expires_at > NOW()) GROUP BY symbol ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time WHERE t1.status = 'active' AND (t1.expires_at IS NULL OR t1.expires_at > NOW()) ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC""" ) @staticmethod def get_by_id(recommendation_id): """根据ID获取推荐""" return db.execute_one( "SELECT * FROM trade_recommendations WHERE id = %s", (recommendation_id,) ) @staticmethod def get_by_symbol(symbol, limit=10): """根据交易对获取推荐记录""" return db.execute_query( """SELECT * FROM trade_recommendations WHERE symbol = %s ORDER BY recommendation_time DESC LIMIT %s""", (symbol, limit) )