auto_trade_sys/backend/database/models.py
薇薇安 4ebfd21761 a
2026-01-17 12:28:37 +08:00

511 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库模型定义
"""
from database.connection import db
from datetime import datetime, timezone, timedelta
import json
import logging
# Fixed UTC+8 offset used for every timestamp this module writes ("Beijing time").
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current Beijing (UTC+8) wall-clock time as a naive datetime.

    The tzinfo is dropped after conversion, so callers receive a plain
    datetime whose fields are Beijing local time (presumably so it can be
    stored in naive DATETIME columns — confirm against the schema).
    """
    beijing_now = datetime.now(tz=BEIJING_TZ)
    naive_beijing_now = beijing_now.replace(tzinfo=None)
    return naive_beijing_now


logger = logging.getLogger(__name__)
class TradingConfig:
    """Trading configuration model backed by the ``trading_config`` table.

    Each value is stored as a string together with a ``config_type``
    ('number', 'boolean', 'json', or anything else for a plain string) and is
    converted back to a Python value by :meth:`get_value`.
    """

    @staticmethod
    def get_all():
        """Return every configuration row, ordered by category and then key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw configuration row for ``key`` (falsy when absent)."""
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert or overwrite a configuration entry (upsert on config_key).

        Args:
            key: Configuration key.
            value: Python value; serialized with :meth:`_convert_to_string`.
            config_type: 'number', 'boolean', 'json', or another string type.
            category: Grouping category.
            description: Optional human-readable description.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            """INSERT INTO trading_config
            (config_key, config_value, config_type, category, description)
            VALUES (%s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            config_value = VALUES(config_value),
            config_type = VALUES(config_type),
            category = VALUES(category),
            description = VALUES(description),
            updated_at = CURRENT_TIMESTAMP""",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the typed value for ``key``, or ``default`` if it is not stored."""
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _parse_number(value):
        """Parse stored number text: int if possible, else float; 0 if invalid or non-finite."""
        import math

        text = str(value).strip()
        try:
            return int(text)
        except ValueError:
            pass
        try:
            number = float(text)
        except ValueError:
            return 0
        # 'inf' / 'nan' have always mapped to 0 here; keep that so a
        # non-finite threshold cannot silently disable numeric comparisons.
        return number if math.isfinite(number) else 0

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored string back into a Python value.

        - 'number': int for integer text, otherwise float. This includes
          exponent forms such as '1e-05', which ``str()`` produces for small
          floats written by :meth:`set`. Unparseable or non-finite text gives 0.
        - 'boolean': True for 'true', '1', 'yes', 'on' (case-insensitive),
          otherwise False.
        - 'json': the decoded value, or ``{}`` if decoding fails.
        - any other type: ``value`` unchanged.
        """
        if config_type == 'number':
            return TradingConfig._parse_number(value)
        if config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        if config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):
                # ValueError covers json.JSONDecodeError; TypeError covers None/non-text input.
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage: JSON (non-ASCII kept) for 'json', else ``str()``."""
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
class Trade:
    """Trade record model backed by the ``trades`` table."""

    @staticmethod
    def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None, entry_order_id=None):
        """Create an open trade record stamped with the current Beijing time.

        Args:
            symbol: Trading pair.
            side: Direction.
            quantity: Quantity.
            entry_price: Entry price.
            leverage: Leverage (default 10).
            entry_reason: Reason for entry.
            entry_order_id: Binance entry order id (optional, used for reconciliation).

        Returns:
            The id of the inserted row, as reported by ``LAST_INSERT_ID()``.
        """
        entry_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO trades
            (symbol, side, quantity, entry_price, leverage, entry_reason, status, entry_time, entry_order_id)
            VALUES (%s, %s, %s, %s, %s, %s, 'open', %s, %s)""",
            (symbol, side, quantity, entry_price, leverage, entry_reason, entry_time, entry_order_id)
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection; this is only correct if
        # db.execute_one uses the same connection as the INSERT above — confirm
        # against database.connection (e.g. when a connection pool is used).
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def _write_exit(trade_id, exit_values, strategy_type, duration_minutes,
                    include_exit_order_id=False, exit_order_id=None):
        """Execute the UPDATE that marks ``trade_id`` as closed.

        Args:
            trade_id: Row id to update.
            exit_values: (exit_price, exit_time, exit_reason, pnl, pnl_percent).
            strategy_type: Written only when not None.
            duration_minutes: Written only when not None.
            include_exit_order_id: When true, also write ``exit_order_id``
                (a None value writes NULL).
            exit_order_id: Binance exit order id.
        """
        fields = [
            "exit_price = %s", "exit_time = %s",
            "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'",
        ]
        values = list(exit_values)
        if include_exit_order_id:
            fields.append("exit_order_id = %s")
            values.append(exit_order_id)
        if strategy_type is not None:
            fields.append("strategy_type = %s")
            values.append(strategy_type)
        if duration_minutes is not None:
            fields.append("duration_minutes = %s")
            values.append(duration_minutes)
        values.append(trade_id)
        # Column names come only from the fixed literals above; all values are parameterized.
        db.execute_update(
            f"UPDATE trades SET {', '.join(fields)} WHERE id = %s",
            tuple(values)
        )

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent, exit_order_id=None, strategy_type=None, duration_minutes=None):
        """Record the close of a trade, stamped with the current Beijing time.

        Args:
            trade_id: Trade record id.
            exit_price: Exit price.
            exit_reason: Reason for closing.
            pnl: Profit and loss.
            pnl_percent: Profit and loss percentage.
            exit_order_id: Binance exit order id (optional, used for reconciliation).
            strategy_type: Written only when not None.
            duration_minutes: Written only when not None.

        Behaviour around ``exit_order_id`` (which is presumably UNIQUE):
        - If it already belongs to this trade, nothing is updated.
        - If it belongs to another trade, or the UPDATE hits a duplicate-key
          error on it, the other exit fields are still written but
          exit_order_id is left untouched.
        - NOTE(review): when exit_order_id is None, the UPDATE writes NULL to
          the column, clearing any previously stored value.

        Raises:
            Any database error other than an exit_order_id duplicate-key error.
        """
        exit_time = get_beijing_time()
        exit_values = (exit_price, exit_time, exit_reason, pnl, pnl_percent)

        if exit_order_id is not None:
            # Only the lookup is guarded: a failed lookup degrades to the normal
            # path, but a failed UPDATE must not be mislogged as a lookup error.
            try:
                existing_trade = Trade.get_by_exit_order_id(exit_order_id)
            except Exception as e:
                logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新")
                existing_trade = None

            if existing_trade:
                if existing_trade['id'] == trade_id:
                    # Already recorded against this trade: treat as done.
                    logger.debug(f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已存在,跳过更新")
                    return
                logger.warning(
                    f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用,"
                    f"跳过更新 exit_order_id只更新其他字段"
                )
                Trade._write_exit(trade_id, exit_values, strategy_type, duration_minutes)
                return

        try:
            Trade._write_exit(
                trade_id, exit_values, strategy_type, duration_minutes,
                include_exit_order_id=True, exit_order_id=exit_order_id,
            )
        except Exception as e:
            # A concurrent writer may have claimed exit_order_id between the check
            # and the UPDATE; retry without that column in that case only.
            error_str = str(e)
            if "Duplicate entry" in error_str and "exit_order_id" in error_str:
                logger.warning(
                    f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突,"
                    f"尝试不更新 exit_order_id"
                )
                Trade._write_exit(trade_id, exit_values, strategy_type, duration_minutes)
            else:
                raise

    @staticmethod
    def get_by_entry_order_id(entry_order_id):
        """Return the trade whose entry_order_id matches (falsy when absent)."""
        return db.execute_one(
            "SELECT * FROM trades WHERE entry_order_id = %s",
            (entry_order_id,)
        )

    @staticmethod
    def get_by_exit_order_id(exit_order_id):
        """Return the trade whose exit_order_id matches (falsy when absent)."""
        return db.execute_one(
            "SELECT * FROM trades WHERE exit_order_id = %s",
            (exit_order_id,)
        )

    @staticmethod
    def get_all(start_date=None, end_date=None, symbol=None, status=None, trade_type=None, exit_reason=None):
        """Return trades matching every supplied (truthy) filter, newest id first.

        start_date/end_date bound ``created_at`` inclusively; ``trade_type``
        filters the ``side`` column.
        """
        query = "SELECT * FROM trades WHERE 1=1"
        params = []
        if start_date:
            query += " AND created_at >= %s"
            params.append(start_date)
        if end_date:
            query += " AND created_at <= %s"
            params.append(end_date)
        if symbol:
            query += " AND symbol = %s"
            params.append(symbol)
        if status:
            query += " AND status = %s"
            params.append(status)
        if trade_type:
            query += " AND side = %s"
            params.append(trade_type)
        if exit_reason:
            query += " AND exit_reason = %s"
            params.append(exit_reason)
        # Ordered by insertion id (newest first), not by exit/entry time.
        query += " ORDER BY id DESC"
        return db.execute_query(query, params)

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return trades for ``symbol`` with the given status (default 'open')."""
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            (symbol, status)
        )
class AccountSnapshot:
    """Account snapshot model backed by the ``account_snapshots`` table."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert a snapshot stamped with the current Beijing time (naive, UTC+8)."""
        snapshot_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO account_snapshots
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
            VALUES (%s, %s, %s, %s, %s, %s)""",
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
        )

    @staticmethod
    def get_recent(days=7):
        """Return snapshots taken within the last ``days`` days, newest first.

        The cutoff is computed with get_beijing_time(), the same clock that
        stamps snapshot_time in create(). Using the database's NOW() instead
        would mix clocks and shift the window whenever the DB session time
        zone is not UTC+8.
        """
        # float() also accepts numeric strings, as the former SQL INTERVAL did.
        cutoff = get_beijing_time() - timedelta(days=float(days))
        return db.execute_query(
            """SELECT * FROM account_snapshots
            WHERE snapshot_time >= %s
            ORDER BY snapshot_time DESC""",
            (cutoff,)
        )
class MarketScan:
    """Market-scan record model backed by the ``market_scans`` table."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Store one scan result stamped with the current Beijing time.

        ``top_symbols`` is serialized with ``json.dumps`` before it is written.
        """
        stamped_at = get_beijing_time()
        encoded_top = json.dumps(top_symbols)
        row = (symbols_scanned, symbols_found, encoded_top, scan_duration, stamped_at)
        insert_sql = (
            "INSERT INTO market_scans\n"
            "(symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)\n"
            "VALUES (%s, %s, %s, %s, %s)"
        )
        db.execute_update(insert_sql, row)

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scans, most recent scan_time first."""
        select_sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(select_sql, (limit,))
class TradingSignal:
    """Trading-signal model backed by the ``trading_signals`` table."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert a signal stamped with the current Beijing time.

        The indicator arguments (rsi, macd_histogram, market_regime) are
        optional and are stored as given, including None.
        """
        emitted_at = get_beijing_time()
        insert_sql = (
            "INSERT INTO trading_signals\n"
            "(symbol, signal_direction, signal_strength, signal_reason,\n"
            "rsi, macd_histogram, market_regime, signal_time)\n"
            "VALUES (%s, %s, %s, %s, %s, %s, %s, %s)"
        )
        row = (
            symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime, emitted_at,
        )
        db.execute_update(insert_sql, row)

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed = TRUE`` on the signal with id ``signal_id``."""
        update_sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(update_sql, (signal_id,))

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, most recent signal_time first."""
        select_sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(select_sql, (limit,))
class TradeRecommendation:
    """Trade recommendation model backed by the ``trade_recommendations`` table."""

    @staticmethod
    def _has_order_fields():
        """Return True if the table has the newer order_type column (probe query succeeds)."""
        try:
            db.execute_one("SELECT order_type FROM trade_recommendations LIMIT 1")
        except Exception:
            # Older schema without order_type / suggested_limit_price.
            return False
        return True

    @staticmethod
    def create(
        symbol, direction, current_price, change_percent, recommendation_reason,
        signal_strength, market_regime=None, trend_4h=None,
        rsi=None, macd_histogram=None,
        bollinger_upper=None, bollinger_middle=None, bollinger_lower=None,
        ema20=None, ema50=None, ema20_4h=None, atr=None,
        suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None,
        suggested_position_percent=None, suggested_leverage=10,
        order_type='LIMIT', suggested_limit_price=None,
        volume_24h=None, volatility=None, notes=None
    ):
        """Insert a recommendation stamped with the current Beijing time.

        The recommendation expires 24 hours after creation. ``order_type`` and
        ``suggested_limit_price`` are written only if the table has those
        columns (older schemas are still supported).

        Returns:
            The id of the inserted row, as reported by ``LAST_INSERT_ID()``.
        """
        recommendation_time = get_beijing_time()
        expires_at = recommendation_time + timedelta(hours=24)

        columns = [
            ("symbol", symbol),
            ("direction", direction),
            ("recommendation_time", recommendation_time),
            ("current_price", current_price),
            ("change_percent", change_percent),
            ("recommendation_reason", recommendation_reason),
            ("signal_strength", signal_strength),
            ("market_regime", market_regime),
            ("trend_4h", trend_4h),
            ("rsi", rsi),
            ("macd_histogram", macd_histogram),
            ("bollinger_upper", bollinger_upper),
            ("bollinger_middle", bollinger_middle),
            ("bollinger_lower", bollinger_lower),
            ("ema20", ema20),
            ("ema50", ema50),
            ("ema20_4h", ema20_4h),
            ("atr", atr),
            ("suggested_stop_loss", suggested_stop_loss),
            ("suggested_take_profit_1", suggested_take_profit_1),
            ("suggested_take_profit_2", suggested_take_profit_2),
            ("suggested_position_percent", suggested_position_percent),
            ("suggested_leverage", suggested_leverage),
        ]
        if TradeRecommendation._has_order_fields():
            columns += [
                ("order_type", order_type),
                ("suggested_limit_price", suggested_limit_price),
            ]
        columns += [
            ("volume_24h", volume_24h),
            ("volatility", volatility),
            ("expires_at", expires_at),
            ("notes", notes),
        ]

        # Column names are fixed literals above; every value is parameterized.
        column_sql = ", ".join(name for name, _ in columns)
        placeholder_sql = ", ".join(["%s"] * len(columns))
        db.execute_update(
            f"INSERT INTO trade_recommendations ({column_sql}) VALUES ({placeholder_sql})",
            tuple(value for _, value in columns)
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection; correct only if
        # db.execute_one shares the INSERT's connection — confirm.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
        """Mark a recommendation executed, recording time, result and resulting trade id."""
        executed_at = get_beijing_time()
        db.execute_update(
            """UPDATE trade_recommendations
            SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
            WHERE id = %s""",
            (executed_at, execution_result, trade_id, recommendation_id)
        )

    @staticmethod
    def mark_expired(recommendation_id):
        """Set the recommendation's status to 'expired'."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def mark_cancelled(recommendation_id, notes=None):
        """Set the recommendation's status to 'cancelled'.

        ``notes`` replaces the stored notes when given. When it is None, the
        existing notes are kept rather than being overwritten with NULL.
        """
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'cancelled', notes = COALESCE(%s, notes) WHERE id = %s",
            (notes, recommendation_id)
        )

    @staticmethod
    def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
        """Return up to ``limit`` recommendations matching the given filters.

        Results are ordered newest first, then by signal strength.
        start_date/end_date bound recommendation_time inclusively.
        """
        query = "SELECT * FROM trade_recommendations WHERE 1=1"
        params = []
        if status:
            query += " AND status = %s"
            params.append(status)
        if direction:
            query += " AND direction = %s"
            params.append(direction)
        if start_date:
            query += " AND recommendation_time >= %s"
            params.append(start_date)
        if end_date:
            query += " AND recommendation_time <= %s"
            params.append(end_date)
        query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
        params.append(limit)
        return db.execute_query(query, params)

    @staticmethod
    def get_active():
        """Return active, unexpired recommendations, keeping only the newest per symbol.

        Expiry is evaluated against get_beijing_time(), the same clock that
        wrote expires_at in create(), rather than the database's NOW().
        NOTE(review): two active rows for one symbol with identical
        recommendation_time would both be returned.
        """
        now = get_beijing_time()
        return db.execute_query(
            """SELECT t1.* FROM trade_recommendations t1
            INNER JOIN (
                SELECT symbol, MAX(recommendation_time) as max_time
                FROM trade_recommendations
                WHERE status = 'active' AND (expires_at IS NULL OR expires_at > %s)
                GROUP BY symbol
            ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time
            WHERE t1.status = 'active' AND (t1.expires_at IS NULL OR t1.expires_at > %s)
            ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC""",
            (now, now)
        )

    @staticmethod
    def get_by_id(recommendation_id):
        """Return the recommendation with the given id (falsy when absent)."""
        return db.execute_one(
            "SELECT * FROM trade_recommendations WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def get_by_symbol(symbol, limit=10):
        """Return up to ``limit`` recommendations for ``symbol``, newest first."""
        return db.execute_query(
            """SELECT * FROM trade_recommendations
            WHERE symbol = %s
            ORDER BY recommendation_time DESC LIMIT %s""",
            (symbol, limit)
        )