auto_trade_sys/backend/database/models.py
薇薇安 b08d97b442 a
2026-01-15 11:34:53 +08:00

349 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库模型定义
"""
from database.connection import db
from datetime import datetime, timezone, timedelta
import json
import logging
# Fixed UTC+8 offset ("Beijing time") used for every timestamp these models write.
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current Beijing time (UTC+8) as a *naive* ``datetime``.

    The wall-clock fields are those of UTC+8, but ``tzinfo`` is stripped, so
    the result carries no offset information of its own.
    """
    aware_now = datetime.now(tz=BEIJING_TZ)
    naive_now = aware_now.replace(tzinfo=None)
    return naive_now


logger = logging.getLogger(__name__)
class TradingConfig:
    """Trading configuration model backed by the ``trading_config`` table.

    Each value is stored as a string in ``config_value`` together with a
    ``config_type`` tag (``'number'``, ``'boolean'``, ``'json'``; any other
    tag is treated as a plain string). :meth:`get_value` converts the stored
    string back into a Python value.
    """

    @staticmethod
    def get_all():
        """Return every configuration row, ordered by category then key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw row for ``key`` (presumably ``None`` when absent)."""
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert or update (upsert) one configuration entry.

        Args:
            key: Configuration key (``config_key``).
            value: Python value; serialized via :meth:`_convert_to_string`.
            config_type: Type tag stored alongside the value.
            category: Grouping category.
            description: Optional human-readable description.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        # NOTE(review): on update, ``description = VALUES(description)`` writes
        # NULL when description is None, replacing any existing description.
        # Confirm this is intended for callers that only change the value.
        db.execute_update(
            """INSERT INTO trading_config
            (config_key, config_value, config_type, category, description)
            VALUES (%s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            config_value = VALUES(config_value),
            config_type = VALUES(config_type),
            category = VALUES(category),
            description = VALUES(description),
            updated_at = CURRENT_TIMESTAMP""",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the typed value for ``key``, or ``default`` if no row is found."""
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored config value back to a Python value.

        - ``'number'``: an ``int`` when the text is an integer literal,
          otherwise a ``float``. This includes exponent forms such as
          ``'1e-05'``, which is what ``str()`` produces for small floats.
          Unparseable values yield ``0``.
        - ``'boolean'``: ``True`` for ``'true'``, ``'1'``, ``'yes'``, ``'on'``
          (case-insensitive), otherwise ``False``.
        - ``'json'``: the decoded JSON value; unparseable values yield ``{}``.
        - any other type: ``value`` unchanged.
        """
        if config_type == 'number':
            # Parse the string form, so a float/Decimal handed back by the
            # driver is not silently truncated by int().
            text = str(value)
            try:
                return int(text)
            except ValueError:
                pass
            try:
                return float(text)
            except ValueError:
                logging.getLogger(__name__).warning(
                    "Invalid number config value %r; falling back to 0", value
                )
                return 0
        if config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        if config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):  # JSONDecodeError is a ValueError
                logging.getLogger(__name__).warning(
                    "Invalid json config value %r; falling back to {}", value
                )
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage.

        JSON values are dumped with ``ensure_ascii=False``, so non-ASCII text is
        stored readably. Every other type uses ``str()``.
        """
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
class Trade:
    """Trade-record model over the ``trades`` table.

    Entry and exit timestamps are produced by ``get_beijing_time()`` (naive
    UTC+8 wall-clock values).
    """

    _SQL_INSERT = (
        "INSERT INTO trades\n"
        "(symbol, side, quantity, entry_price, leverage, entry_reason, status, entry_time)\n"
        "VALUES (%s, %s, %s, %s, %s, %s, 'open', %s)"
    )
    _SQL_CLOSE = (
        "UPDATE trades\n"
        "SET exit_price = %s, exit_time = %s,\n"
        "exit_reason = %s, pnl = %s, pnl_percent = %s, status = 'closed'\n"
        "WHERE id = %s"
    )
    _SQL_LAST_ID = "SELECT LAST_INSERT_ID() as id"
    _SQL_BY_SYMBOL = "SELECT * FROM trades WHERE symbol = %s AND status = %s"

    @staticmethod
    def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None):
        """Insert a new ``'open'`` trade stamped with Beijing time and return its id."""
        opened_at = get_beijing_time()
        row = (symbol, side, quantity, entry_price, leverage, entry_reason, opened_at)
        db.execute_update(Trade._SQL_INSERT, row)
        # NOTE(review): MySQL LAST_INSERT_ID() is per-connection. This is only
        # correct if ``db`` runs both statements on the same connection; confirm
        # in database.connection.
        inserted = db.execute_one(Trade._SQL_LAST_ID)
        return inserted['id']

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent):
        """Close trade ``trade_id``.

        Records the exit data, stamps ``exit_time`` with Beijing time and sets
        the status to ``'closed'``.
        """
        closed_at = get_beijing_time()
        db.execute_update(
            Trade._SQL_CLOSE,
            (exit_price, closed_at, exit_reason, pnl, pnl_percent, trade_id),
        )

    @staticmethod
    def get_all(start_date=None, end_date=None, symbol=None, status=None, trade_type=None, exit_reason=None):
        """Return trades matching every truthy filter, newest entry first.

        ``trade_type`` filters on the ``side`` column. Falsy filter values are
        ignored.
        """
        # NOTE(review): end_date is compared with ``<=`` against entry_time. If
        # callers pass a bare date, trades later that same day are excluded.
        candidate_filters = (
            (" AND entry_time >= %s", start_date),
            (" AND entry_time <= %s", end_date),
            (" AND symbol = %s", symbol),
            (" AND status = %s", status),
            (" AND side = %s", trade_type),
            (" AND exit_reason = %s", exit_reason),
        )
        applied = [(clause, arg) for clause, arg in candidate_filters if arg]
        sql = (
            "SELECT * FROM trades WHERE 1=1"
            + "".join(clause for clause, _ in applied)
            + " ORDER BY entry_time DESC"
        )
        return db.execute_query(sql, [arg for _, arg in applied])

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return trades for ``symbol`` whose status equals ``status`` (default ``'open'``)."""
        return db.execute_query(Trade._SQL_BY_SYMBOL, (symbol, status))
class AccountSnapshot:
    """Account-snapshot model over the ``account_snapshots`` table."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert one snapshot stamped with the current Beijing time (UTC+8)."""
        snapshot_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO account_snapshots
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
            VALUES (%s, %s, %s, %s, %s, %s)""",
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
        )

    @staticmethod
    def get_recent(days=7):
        """Return snapshots from the last ``days`` days, newest first.

        The cutoff is computed with ``get_beijing_time()``, the same clock
        :meth:`create` stamps rows with. It deliberately does not use MySQL
        ``NOW()``, which is evaluated in the database session time zone and
        may differ from the stored UTC+8 wall-clock values.

        Args:
            days: Look-back window in days. Numeric strings are accepted, as
                they were by the previous ``INTERVAL %s DAY`` query.
        """
        cutoff = get_beijing_time() - timedelta(days=float(days))
        return db.execute_query(
            """SELECT * FROM account_snapshots
            WHERE snapshot_time >= %s
            ORDER BY snapshot_time DESC""",
            (cutoff,)
        )
class MarketScan:
    """Market-scan history over the ``market_scans`` table."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Persist one scan result, stamped with the current Beijing time (UTC+8).

        ``top_symbols`` is stored as a JSON string (``json.dumps`` defaults).
        """
        insert_sql = (
            "INSERT INTO market_scans\n"
            "(symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)\n"
            "VALUES (%s, %s, %s, %s, %s)"
        )
        scanned_at = get_beijing_time()
        encoded_top = json.dumps(top_symbols)
        db.execute_update(
            insert_sql,
            (symbols_scanned, symbols_found, encoded_top, scan_duration, scanned_at),
        )

    @staticmethod
    def get_recent(limit=100):
        """Return at most ``limit`` scans, most recent first."""
        select_sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(select_sql, (limit,))
class TradingSignal:
    """Trading-signal model over the ``trading_signals`` table."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert a signal stamped with the current Beijing time (UTC+8).

        Returns:
            The new row's id (read via ``LAST_INSERT_ID()``), so the caller can
            later pass it to :meth:`mark_executed`.
        """
        signal_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO trading_signals
            (symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime, signal_time)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
            (symbol, signal_direction, signal_strength, signal_reason,
             rsi, macd_histogram, market_regime, signal_time)
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection in MySQL. This
        # assumes ``db`` runs both statements on the same connection.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed = TRUE`` on signal ``signal_id``."""
        db.execute_update(
            "UPDATE trading_signals SET executed = TRUE WHERE id = %s",
            (signal_id,)
        )

    @staticmethod
    def get_recent(limit=100):
        """Return at most ``limit`` signals, most recent first."""
        return db.execute_query(
            "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s",
            (limit,)
        )
class TradeRecommendation:
    """Trade-recommendation model over the ``trade_recommendations`` table.

    This class writes the statuses ``'executed'``, ``'expired'`` and
    ``'cancelled'``. :meth:`get_active` reads rows whose status is
    ``'active'``, which is presumably the column default on insert; confirm
    against the schema.
    """

    @staticmethod
    def create(
        symbol, direction, current_price, change_percent, recommendation_reason,
        signal_strength, market_regime=None, trend_4h=None,
        rsi=None, macd_histogram=None,
        bollinger_upper=None, bollinger_middle=None, bollinger_lower=None,
        ema20=None, ema50=None, ema20_4h=None, atr=None,
        suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None,
        suggested_position_percent=None, suggested_leverage=10,
        volume_24h=None, volatility=None, notes=None
    ):
        """Insert a recommendation stamped with Beijing time and return its id.

        ``recommendation_time`` is the current Beijing time (UTC+8), and
        ``expires_at`` is set 24 hours later. All indicator and suggestion
        arguments are stored as given (``None`` becomes NULL).
        """
        recommendation_time = get_beijing_time()
        # Recommendations expire 24 hours after creation.
        expires_at = recommendation_time + timedelta(hours=24)
        # Column names and values are listed side by side, so the column list,
        # the placeholder count and the parameter tuple cannot drift apart.
        fields = (
            ('symbol', symbol),
            ('direction', direction),
            ('recommendation_time', recommendation_time),
            ('current_price', current_price),
            ('change_percent', change_percent),
            ('recommendation_reason', recommendation_reason),
            ('signal_strength', signal_strength),
            ('market_regime', market_regime),
            ('trend_4h', trend_4h),
            ('rsi', rsi),
            ('macd_histogram', macd_histogram),
            ('bollinger_upper', bollinger_upper),
            ('bollinger_middle', bollinger_middle),
            ('bollinger_lower', bollinger_lower),
            ('ema20', ema20),
            ('ema50', ema50),
            ('ema20_4h', ema20_4h),
            ('atr', atr),
            ('suggested_stop_loss', suggested_stop_loss),
            ('suggested_take_profit_1', suggested_take_profit_1),
            ('suggested_take_profit_2', suggested_take_profit_2),
            ('suggested_position_percent', suggested_position_percent),
            ('suggested_leverage', suggested_leverage),
            ('volume_24h', volume_24h),
            ('volatility', volatility),
            ('expires_at', expires_at),
            ('notes', notes),
        )
        # Column names are fixed literals above (never caller input), so
        # building the SQL text from them is safe. Values stay parameterized.
        columns = ", ".join(name for name, _ in fields)
        placeholders = ", ".join(["%s"] * len(fields))
        db.execute_update(
            "INSERT INTO trade_recommendations (" + columns + ") VALUES (" + placeholders + ")",
            tuple(value for _, value in fields)
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection in MySQL. This
        # assumes ``db`` runs both statements on the same connection.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
        """Mark a recommendation as executed.

        Sets status ``'executed'`` and records the Beijing-time execution
        timestamp, the result and the linked trade id.
        """
        executed_at = get_beijing_time()
        db.execute_update(
            """UPDATE trade_recommendations
            SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
            WHERE id = %s""",
            (executed_at, execution_result, trade_id, recommendation_id)
        )

    @staticmethod
    def mark_expired(recommendation_id):
        """Set status ``'expired'`` on one recommendation."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def mark_cancelled(recommendation_id, notes=None):
        """Set status ``'cancelled'`` on one recommendation.

        Args:
            recommendation_id: Row id.
            notes: When given, replaces the stored notes. When ``None``, the
                existing notes are kept rather than being overwritten with NULL.
        """
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'cancelled', notes = COALESCE(%s, notes) WHERE id = %s",
            (notes, recommendation_id)
        )

    @staticmethod
    def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
        """Return up to ``limit`` recommendations matching every truthy filter.

        Rows are ordered newest first, with signal strength as the tie-breaker.
        """
        query = "SELECT * FROM trade_recommendations WHERE 1=1"
        params = []
        if status:
            query += " AND status = %s"
            params.append(status)
        if direction:
            query += " AND direction = %s"
            params.append(direction)
        if start_date:
            query += " AND recommendation_time >= %s"
            params.append(start_date)
        if end_date:
            # NOTE(review): ``<=`` excludes the rest of the day if callers pass a bare date.
            query += " AND recommendation_time <= %s"
            params.append(end_date)
        query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
        params.append(limit)
        return db.execute_query(query, params)

    @staticmethod
    def get_active():
        """Return recommendations with status ``'active'`` that have not expired.

        "Now" is taken from ``get_beijing_time()``, the same clock used to
        compute ``expires_at`` in :meth:`create`. MySQL ``NOW()`` is not used
        because it follows the database session time zone, which may not be
        UTC+8.
        """
        now = get_beijing_time()
        return db.execute_query(
            """SELECT * FROM trade_recommendations
            WHERE status = 'active' AND (expires_at IS NULL OR expires_at > %s)
            ORDER BY signal_strength DESC, recommendation_time DESC""",
            (now,)
        )

    @staticmethod
    def get_by_id(recommendation_id):
        """Return the recommendation row with the given id."""
        return db.execute_one(
            "SELECT * FROM trade_recommendations WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def get_by_symbol(symbol, limit=10):
        """Return up to ``limit`` recommendations for ``symbol``, newest first."""
        return db.execute_query(
            """SELECT * FROM trade_recommendations
            WHERE symbol = %s
            ORDER BY recommendation_time DESC LIMIT %s""",
            (symbol, limit)
        )