auto_trade_sys/backend/database/models.py
薇薇安 8a89592cb5 a
2026-01-13 17:30:59 +08:00

213 lines
6.8 KiB
Python

"""
数据库模型定义
"""
from database.connection import db
from datetime import datetime
import json
import logging
logger = logging.getLogger(__name__)
class TradingConfig:
    """Data-access helpers for the ``trading_config`` table.

    All methods are static and go through the module-level ``db`` object.
    A value is persisted as a string next to a ``config_type`` tag; the tags
    'number', 'boolean' and 'json' get dedicated conversion when read through
    :meth:`get_value`, any other tag is returned as stored.
    """

    @staticmethod
    def get_all():
        """Return all configuration rows ordered by category, then key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw row whose ``config_key`` equals ``key``.

        The result is whatever ``db.execute_one`` yields; :meth:`get_value`
        treats a falsy result as "no such key".
        """
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert a configuration entry, or overwrite it if the key exists.

        Args:
            key: Value for the ``config_key`` column.
            value: Value to store; serialized by :meth:`_convert_to_string`.
            config_type: Type tag stored with the value and used on read.
            category: Value for the ``category`` column.
            description: Optional value for the ``description`` column.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            "INSERT INTO trading_config\n"
            "(config_key, config_value, config_type, category, description)\n"
            "VALUES (%s, %s, %s, %s, %s)\n"
            "ON DUPLICATE KEY UPDATE\n"
            "config_value = VALUES(config_value),\n"
            "config_type = VALUES(config_type),\n"
            "category = VALUES(category),\n"
            "description = VALUES(description),\n"
            "updated_at = CURRENT_TIMESTAMP",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the stored value for ``key`` converted to its declared type.

        Args:
            key: Configuration key to look up.
            default: Returned when :meth:`get` yields a falsy result.
        """
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(
                result['config_value'], result['config_type']
            )
        return default

    @staticmethod
    def _parse_number(value):
        """Parse ``value`` as an int when possible, otherwise as a float.

        Raises:
            TypeError, ValueError, OverflowError: if ``value`` cannot be
                read as a finite number.
        """
        import math

        if '.' in str(value):
            return float(value)
        try:
            return int(value)
        except ValueError:
            # Text without a '.' can still be a valid float, e.g. "1e3".
            number = float(value)
        if not math.isfinite(number):
            # "inf" / "nan" were never accepted here; keep rejecting them.
            raise ValueError(f"non-finite number: {value!r}")
        return number

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored value according to its ``config_type`` tag.

        Returns:
            * 'number'  -> int or float, or ``0`` if the value is not numeric.
            * 'boolean' -> ``True`` when the lower-cased text is one of
              'true', '1', 'yes', 'on'; otherwise ``False``.
            * 'json'    -> the decoded object, or ``{}`` if decoding fails.
            * anything else -> ``value`` unchanged.

        Conversion failures keep the historical fallback values but are
        logged so that a malformed entry does not go unnoticed.
        """
        if config_type == 'number':
            try:
                return TradingConfig._parse_number(value)
            except (TypeError, ValueError, OverflowError):
                logging.getLogger(__name__).warning(
                    "Invalid numeric config value %r; falling back to 0", value
                )
                return 0
        elif config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        elif config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError, RecursionError):
                logging.getLogger(__name__).warning(
                    "Invalid JSON config value %r; falling back to {}", value
                )
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage.

        'json' values are dumped with ``ensure_ascii=False`` so non-ASCII
        text is stored verbatim; every other type uses ``str(value)``.
        """
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
class Trade:
    """Static helpers for reading and writing rows of the ``trades`` table."""

    @staticmethod
    def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None):
        """Insert a trade with status 'open' and return its id.

        The id is read back with ``SELECT LAST_INSERT_ID()`` in a second
        ``db`` call.

        NOTE(review): LAST_INSERT_ID() is scoped to a connection, so this
        presumably relies on ``db`` running both statements on the same
        connection -- confirm against the ``db`` implementation.
        """
        insert_sql = (
            "INSERT INTO trades\n"
            "(symbol, side, quantity, entry_price, leverage, entry_reason, status)\n"
            "VALUES (%s, %s, %s, %s, %s, %s, 'open')"
        )
        row_values = (symbol, side, quantity, entry_price, leverage, entry_reason)
        db.execute_update(insert_sql, row_values)
        id_row = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return id_row['id']

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent):
        """Record exit details for a trade and set its status to 'closed'.

        ``exit_time`` is set by the database to ``CURRENT_TIMESTAMP``.
        """
        close_sql = (
            "UPDATE trades\n"
            "SET exit_price = %s, exit_time = CURRENT_TIMESTAMP,\n"
            "exit_reason = %s, pnl = %s, pnl_percent = %s, status = 'closed'\n"
            "WHERE id = %s"
        )
        close_values = (exit_price, exit_reason, pnl, pnl_percent, trade_id)
        db.execute_update(close_sql, close_values)

    @staticmethod
    def get_all(start_date=None, end_date=None, symbol=None, status=None):
        """Return trades, newest ``entry_time`` first, with optional filters.

        Each filter is applied only when its argument is truthy:
        ``start_date`` / ``end_date`` bound ``entry_time`` (inclusive),
        ``symbol`` and ``status`` are matched exactly.
        """
        optional_filters = (
            (" AND entry_time >= %s", start_date),
            (" AND entry_time <= %s", end_date),
            (" AND symbol = %s", symbol),
            (" AND status = %s", status),
        )
        active = [(clause, arg) for clause, arg in optional_filters if arg]

        query = "SELECT * FROM trades WHERE 1=1"
        query += "".join(clause for clause, _ in active)
        query += " ORDER BY entry_time DESC"
        params = [arg for _, arg in active]
        return db.execute_query(query, params)

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return the trades for ``symbol`` that have the given ``status``."""
        lookup_sql = "SELECT * FROM trades WHERE symbol = %s AND status = %s"
        return db.execute_query(lookup_sql, (symbol, status))
class AccountSnapshot:
    """Static helpers for the ``account_snapshots`` table."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert one snapshot row; each argument fills the same-named column."""
        insert_sql = (
            "INSERT INTO account_snapshots\n"
            "(total_balance, available_balance, total_position_value, total_pnl, open_positions)\n"
            "VALUES (%s, %s, %s, %s, %s)"
        )
        snapshot_values = (
            total_balance,
            available_balance,
            total_position_value,
            total_pnl,
            open_positions,
        )
        db.execute_update(insert_sql, snapshot_values)

    @staticmethod
    def get_recent(days=7):
        """Return snapshots taken within the last ``days`` days, newest first.

        The cut-off is computed by the database from ``NOW()``.
        """
        recent_sql = (
            "SELECT * FROM account_snapshots\n"
            "WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY)\n"
            "ORDER BY snapshot_time DESC"
        )
        return db.execute_query(recent_sql, (days,))
class MarketScan:
    """Static helpers for the ``market_scans`` table."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Insert one scan record.

        ``top_symbols`` is serialized with ``json.dumps`` before it is
        stored; the other arguments are passed to the database as given.
        """
        insert_sql = (
            "INSERT INTO market_scans\n"
            "(symbols_scanned, symbols_found, top_symbols, scan_duration)\n"
            "VALUES (%s, %s, %s, %s)"
        )
        encoded_top = json.dumps(top_symbols)
        db.execute_update(
            insert_sql,
            (symbols_scanned, symbols_found, encoded_top, scan_duration),
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scan records, newest ``scan_time`` first."""
        recent_sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(recent_sql, (limit,))
class TradingSignal:
    """Static helpers for the ``trading_signals`` table."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert one signal row; each argument fills the same-named column.

        ``rsi``, ``macd_histogram`` and ``market_regime`` are optional and
        default to ``None``.
        """
        insert_sql = (
            "INSERT INTO trading_signals\n"
            "(symbol, signal_direction, signal_strength, signal_reason,\n"
            "rsi, macd_histogram, market_regime)\n"
            "VALUES (%s, %s, %s, %s, %s, %s, %s)"
        )
        signal_values = (
            symbol,
            signal_direction,
            signal_strength,
            signal_reason,
            rsi,
            macd_histogram,
            market_regime,
        )
        db.execute_update(insert_sql, signal_values)

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed = TRUE`` on the signal with the given id."""
        flag_sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(flag_sql, (signal_id,))

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, newest ``signal_time`` first."""
        recent_sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(recent_sql, (limit,))