""" 数据库模型定义 """ from database.connection import db from datetime import datetime import json import logging logger = logging.getLogger(__name__) class TradingConfig: """交易配置模型""" @staticmethod def get_all(): """获取所有配置""" return db.execute_query( "SELECT * FROM trading_config ORDER BY category, config_key" ) @staticmethod def get(key): """获取单个配置""" return db.execute_one( "SELECT * FROM trading_config WHERE config_key = %s", (key,) ) @staticmethod def set(key, value, config_type, category, description=None): """设置配置""" value_str = TradingConfig._convert_to_string(value, config_type) db.execute_update( """INSERT INTO trading_config (config_key, config_value, config_type, category, description) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE config_value = VALUES(config_value), config_type = VALUES(config_type), category = VALUES(category), description = VALUES(description), updated_at = CURRENT_TIMESTAMP""", (key, value_str, config_type, category, description) ) @staticmethod def get_value(key, default=None): """获取配置值(自动转换类型)""" result = TradingConfig.get(key) if result: return TradingConfig._convert_value(result['config_value'], result['config_type']) return default @staticmethod def _convert_value(value, config_type): """转换配置值类型""" if config_type == 'number': try: return float(value) if '.' in str(value) else int(value) except: return 0 elif config_type == 'boolean': return str(value).lower() in ('true', '1', 'yes', 'on') elif config_type == 'json': try: return json.loads(value) except: return {} return value @staticmethod def _convert_to_string(value, config_type): """转换值为字符串""" if config_type == 'json': return json.dumps(value, ensure_ascii=False) return str(value) class Trade: """交易记录模型""" @staticmethod def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None): """创建交易记录""" db.execute_update( """INSERT INTO trades (symbol, side, quantity, entry_price, leverage, entry_reason, status) VALUES (%s, %s, %s, %s, %s, %s, 'open')""", (symbol, side, quantity, entry_price, leverage, entry_reason) ) return db.execute_one("SELECT LAST_INSERT_ID() as id")['id'] @staticmethod def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent): """更新平仓信息""" db.execute_update( """UPDATE trades SET exit_price = %s, exit_time = CURRENT_TIMESTAMP, exit_reason = %s, pnl = %s, pnl_percent = %s, status = 'closed' WHERE id = %s""", (exit_price, exit_reason, pnl, pnl_percent, trade_id) ) @staticmethod def get_all(start_date=None, end_date=None, symbol=None, status=None): """获取交易记录""" query = "SELECT * FROM trades WHERE 1=1" params = [] if start_date: query += " AND entry_time >= %s" params.append(start_date) if end_date: query += " AND entry_time <= %s" params.append(end_date) if symbol: query += " AND symbol = %s" params.append(symbol) if status: query += " AND status = %s" params.append(status) query += " ORDER BY entry_time DESC" return db.execute_query(query, params) @staticmethod def get_by_symbol(symbol, status='open'): """根据交易对获取持仓""" return db.execute_query( "SELECT * FROM trades WHERE symbol = %s AND status = %s", (symbol, status) ) class AccountSnapshot: """账户快照模型""" @staticmethod def create(total_balance, available_balance, total_position_value, total_pnl, open_positions): """创建账户快照""" db.execute_update( """INSERT INTO account_snapshots (total_balance, available_balance, total_position_value, total_pnl, open_positions) VALUES (%s, %s, %s, %s, %s)""", (total_balance, available_balance, total_position_value, total_pnl, open_positions) ) @staticmethod def get_recent(days=7): """获取最近的快照""" return db.execute_query( """SELECT * FROM account_snapshots WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY) ORDER BY snapshot_time DESC""", (days,) ) class MarketScan: """市场扫描记录模型""" @staticmethod def create(symbols_scanned, symbols_found, top_symbols, scan_duration): """创建扫描记录""" db.execute_update( """INSERT INTO market_scans (symbols_scanned, symbols_found, top_symbols, scan_duration) VALUES (%s, %s, %s, %s)""", (symbols_scanned, symbols_found, json.dumps(top_symbols), scan_duration) ) @staticmethod def get_recent(limit=100): """获取最近的扫描记录""" return db.execute_query( "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s", (limit,) ) class TradingSignal: """交易信号模型""" @staticmethod def create(symbol, signal_direction, signal_strength, signal_reason, rsi=None, macd_histogram=None, market_regime=None): """创建交易信号""" db.execute_update( """INSERT INTO trading_signals (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime) VALUES (%s, %s, %s, %s, %s, %s, %s)""", (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime) ) @staticmethod def mark_executed(signal_id): """标记信号已执行""" db.execute_update( "UPDATE trading_signals SET executed = TRUE WHERE id = %s", (signal_id,) ) @staticmethod def get_recent(limit=100): """获取最近的信号""" return db.execute_query( "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s", (limit,) )