357 lines
13 KiB
Python
357 lines
13 KiB
Python
"""
|
||
数据库模型定义
|
||
"""
|
||
from database.connection import db
|
||
from datetime import datetime, timezone, timedelta
|
||
import json
|
||
import logging
|
||
|
||
# Fixed UTC+8 offset (Beijing / China Standard Time); every timestamp this
# module writes is produced by get_beijing_time() below.
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current Beijing (UTC+8) wall-clock time as a naive datetime.

    The timezone is computed explicitly via BEIJING_TZ and then stripped, so
    the result carries no tzinfo and must be interpreted as UTC+8 local time.
    """
    aware_now = datetime.now(tz=BEIJING_TZ)
    return aware_now.replace(tzinfo=None)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TradingConfig:
    """Trading configuration model backed by the ``trading_config`` table.

    Each value is stored as a string in ``config_value`` next to a
    ``config_type`` tag ('number', 'boolean', 'json', or anything else for a
    plain string) that decides how ``get_value`` converts it back.
    """

    @staticmethod
    def get_all():
        """Return every configuration row, ordered by category then key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw configuration row for ``key``.

        Presumably returns a falsy value (e.g. None) when the key does not
        exist, since ``get_value`` relies on that — confirm against ``db``.
        """
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert or update the configuration entry ``key``.

        ``value`` is serialized with ``_convert_to_string`` before storage.
        NOTE(review): the upsert relies on ``config_key`` having a UNIQUE or
        PRIMARY index, which ON DUPLICATE KEY UPDATE requires — confirm schema.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            """INSERT INTO trading_config
               (config_key, config_value, config_type, category, description)
               VALUES (%s, %s, %s, %s, %s)
               ON DUPLICATE KEY UPDATE
               config_value = VALUES(config_value),
               config_type = VALUES(config_type),
               category = VALUES(category),
               description = VALUES(description),
               updated_at = CURRENT_TIMESTAMP""",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the typed value for ``key``, or ``default`` if it is missing."""
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored value back to the type named by ``config_type``.

        - 'number': int when the text is an integer literal, otherwise float
          ('5' -> 5, '0.5' -> 0.5, '1e3' -> 1000.0); unparseable -> 0.
        - 'boolean': True for 'true'/'1'/'yes'/'on' (case-insensitive), else False.
        - 'json': decoded with ``json.loads``; undecodable -> {}.
        - any other type: returned unchanged.

        Parse failures are logged as warnings and replaced by the fallback
        value instead of raising, so one bad row cannot break callers.
        """
        if config_type == 'number':
            text = str(value)
            # int first keeps integer settings as ints; float then accepts
            # decimals and exponent forms such as '1e3' (a "contains a dot"
            # check would misroute those to int() and yield 0).
            for parse in (int, float):
                try:
                    return parse(text)
                except ValueError:
                    continue
            logger.warning("Invalid number config value %r; using 0", value)
            return 0
        if config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        if config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):
                # json.JSONDecodeError subclasses ValueError; TypeError covers None/non-str input.
                logger.warning("Invalid json config value %r; using {}", value)
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage in ``config_value``.

        'json' values are JSON-encoded with non-ASCII characters kept as-is;
        everything else goes through ``str()`` (so True is stored as 'True',
        which the 'boolean' branch of ``_convert_value`` reads back as True).
        """
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
|
||
|
||
|
||
class Trade:
    """Trade record model backed by the ``trades`` table."""

    @staticmethod
    def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None):
        """Insert an 'open' trade stamped with the current Beijing time; return its id.

        The id is read back with ``SELECT LAST_INSERT_ID()``, which MySQL
        scopes to the current connection.
        NOTE(review): this assumes ``db`` issues both statements on the same
        connection — confirm against database.connection.
        """
        opened_at = get_beijing_time()
        insert_params = (symbol, side, quantity, entry_price, leverage, entry_reason, opened_at)
        db.execute_update(
            """INSERT INTO trades
               (symbol, side, quantity, entry_price, leverage, entry_reason, status, entry_time)
               VALUES (%s, %s, %s, %s, %s, %s, 'open', %s)""",
            insert_params
        )
        last_row = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return last_row['id']

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent):
        """Close trade ``trade_id``.

        Records the exit price, reason and PnL, stamps ``exit_time`` with the
        current Beijing time and sets status to 'closed'.
        """
        closed_at = get_beijing_time()
        update_params = (exit_price, closed_at, exit_reason, pnl, pnl_percent, trade_id)
        db.execute_update(
            """UPDATE trades
               SET exit_price = %s, exit_time = %s,
               exit_reason = %s, pnl = %s, pnl_percent = %s, status = 'closed'
               WHERE id = %s""",
            update_params
        )

    @staticmethod
    def get_all(start_date=None, end_date=None, symbol=None, status=None, trade_type=None, exit_reason=None):
        """Return trades matching every truthy filter, newest entry first.

        ``start_date``/``end_date`` bound ``entry_time`` inclusively and
        ``trade_type`` filters on the ``side`` column; falsy filters are ignored.
        """
        # (SQL condition, bound value) pairs, applied in this fixed order.
        criteria = (
            ("entry_time >= %s", start_date),
            ("entry_time <= %s", end_date),
            ("symbol = %s", symbol),
            ("status = %s", status),
            ("side = %s", trade_type),
            ("exit_reason = %s", exit_reason),
        )
        applied = [(clause, value) for clause, value in criteria if value]

        query = "SELECT * FROM trades WHERE 1=1" + "".join(" AND " + clause for clause, _ in applied)
        query += " ORDER BY entry_time DESC"
        params = [value for _, value in applied]
        return db.execute_query(query, params)

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return all trades for ``symbol`` with the given status (default 'open')."""
        lookup = (symbol, status)
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            lookup
        )
|
||
|
||
|
||
class AccountSnapshot:
    """Account snapshot model backed by the ``account_snapshots`` table."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert a snapshot stamped with the current Beijing time (naive, UTC+8)."""
        snapshot_time = get_beijing_time()
        db.execute_update(
            """INSERT INTO account_snapshots
               (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
               VALUES (%s, %s, %s, %s, %s, %s)""",
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
        )

    @staticmethod
    def get_recent(days=7):
        """Return snapshots taken within the last ``days`` days, newest first.

        ``snapshot_time`` is written with ``get_beijing_time()``, so the cutoff
        is computed on that same clock in Python. The database's ``NOW()`` was
        used before, but it follows the MySQL session time zone and skews the
        window whenever that zone is not UTC+8.
        """
        # float() keeps accepting numeric strings such as '7', which the SQL
        # INTERVAL expression previously tolerated.
        cutoff = get_beijing_time() - timedelta(days=float(days))
        return db.execute_query(
            """SELECT * FROM account_snapshots
               WHERE snapshot_time >= %s
               ORDER BY snapshot_time DESC""",
            (cutoff,)
        )
|
||
|
||
|
||
class MarketScan:
    """Market-scan history model backed by the ``market_scans`` table."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Record one market scan, stamped with the current Beijing time.

        ``top_symbols`` is serialized with ``json.dumps`` before being stored.
        """
        scanned_at = get_beijing_time()
        serialized_top = json.dumps(top_symbols)
        row = (symbols_scanned, symbols_found, serialized_top, scan_duration, scanned_at)
        db.execute_update(
            """INSERT INTO market_scans
               (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)
               VALUES (%s, %s, %s, %s, %s)""",
            row
        )

    @staticmethod
    def get_recent(limit=100):
        """Return at most ``limit`` scans, newest first."""
        sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
|
||
|
||
|
||
class TradingSignal:
    """Trading-signal model backed by the ``trading_signals`` table."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Store a trading signal, stamped with the current Beijing time.

        The indicator fields (``rsi``, ``macd_histogram``, ``market_regime``)
        are optional and stored as NULL when omitted.
        """
        emitted_at = get_beijing_time()
        row = (
            symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime, emitted_at,
        )
        db.execute_update(
            """INSERT INTO trading_signals
               (symbol, signal_direction, signal_strength, signal_reason,
               rsi, macd_histogram, market_regime, signal_time)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
            row
        )

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed = TRUE`` on the signal with id ``signal_id``."""
        sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(sql, (signal_id,))

    @staticmethod
    def get_recent(limit=100):
        """Return at most ``limit`` signals, newest first."""
        sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
|
||
|
||
|
||
class TradeRecommendation:
    """Recommended-trade model backed by the ``trade_recommendations`` table.

    Status values used here: 'active', 'executed', 'expired', 'cancelled'.
    NOTE(review): ``create`` does not set ``status``; presumably the column
    defaults to 'active' — confirm against the schema.
    """

    # Column order for INSERTs; create() binds its values in exactly this order.
    _INSERT_COLUMNS = (
        'symbol', 'direction', 'recommendation_time', 'current_price', 'change_percent',
        'recommendation_reason', 'signal_strength', 'market_regime', 'trend_4h',
        'rsi', 'macd_histogram', 'bollinger_upper', 'bollinger_middle', 'bollinger_lower',
        'ema20', 'ema50', 'ema20_4h', 'atr',
        'suggested_stop_loss', 'suggested_take_profit_1', 'suggested_take_profit_2',
        'suggested_position_percent', 'suggested_leverage',
        'volume_24h', 'volatility', 'expires_at', 'notes',
    )
    # Placeholders are generated from the column list so the two cannot drift apart.
    _INSERT_SQL = "INSERT INTO trade_recommendations ({}) VALUES ({})".format(
        ", ".join(_INSERT_COLUMNS),
        ", ".join(["%s"] * len(_INSERT_COLUMNS)),
    )

    @staticmethod
    def create(
        symbol, direction, current_price, change_percent, recommendation_reason,
        signal_strength, market_regime=None, trend_4h=None,
        rsi=None, macd_histogram=None,
        bollinger_upper=None, bollinger_middle=None, bollinger_lower=None,
        ema20=None, ema50=None, ema20_4h=None, atr=None,
        suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None,
        suggested_position_percent=None, suggested_leverage=10,
        volume_24h=None, volatility=None, notes=None
    ):
        """Insert a recommendation stamped with the current Beijing time; return its id.

        ``expires_at`` is set to 24 hours after ``recommendation_time``.
        The id is read back with ``SELECT LAST_INSERT_ID()``, which MySQL
        scopes to the connection.
        NOTE(review): this assumes ``db`` runs both statements on the same
        connection — confirm.
        """
        recommendation_time = get_beijing_time()
        # Recommendations expire 24 hours after creation by default.
        expires_at = recommendation_time + timedelta(hours=24)

        # Must follow the order of _INSERT_COLUMNS.
        values = (
            symbol, direction, recommendation_time, current_price, change_percent,
            recommendation_reason, signal_strength, market_regime, trend_4h,
            rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower,
            ema20, ema50, ema20_4h, atr,
            suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2,
            suggested_position_percent, suggested_leverage,
            volume_24h, volatility, expires_at, notes,
        )
        db.execute_update(TradeRecommendation._INSERT_SQL, values)
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
        """Set status to 'executed'.

        Also records the Beijing-time ``executed_at``, the execution result
        and the linked trade id.
        """
        executed_at = get_beijing_time()
        db.execute_update(
            """UPDATE trade_recommendations
               SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
               WHERE id = %s""",
            (executed_at, execution_result, trade_id, recommendation_id)
        )

    @staticmethod
    def mark_expired(recommendation_id):
        """Set status to 'expired'."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def mark_cancelled(recommendation_id, notes=None):
        """Set status to 'cancelled'.

        ``notes`` replaces the stored notes only when given. With None the
        existing notes are kept (via COALESCE) instead of being wiped to NULL.
        """
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'cancelled', notes = COALESCE(%s, notes) WHERE id = %s",
            (notes, recommendation_id)
        )

    @staticmethod
    def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
        """Return up to ``limit`` recommendations matching every truthy filter.

        ``start_date``/``end_date`` bound ``recommendation_time`` inclusively.
        Results are ordered newest first, then by descending signal strength.
        """
        query = "SELECT * FROM trade_recommendations WHERE 1=1"
        params = []

        if status:
            query += " AND status = %s"
            params.append(status)
        if direction:
            query += " AND direction = %s"
            params.append(direction)
        if start_date:
            query += " AND recommendation_time >= %s"
            params.append(start_date)
        if end_date:
            query += " AND recommendation_time <= %s"
            params.append(end_date)

        query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
        params.append(limit)

        return db.execute_query(query, params)

    @staticmethod
    def get_active():
        """Return currently valid recommendations, one per symbol.

        "Valid" means status 'active' and not yet expired. For each symbol
        only the row with the latest ``recommendation_time`` is kept.

        ``expires_at`` is written on the ``get_beijing_time()`` clock, so the
        comparison uses that same clock, bound as a parameter. The database's
        ``NOW()`` follows the MySQL session time zone and could keep expired
        rows "active" for hours.
        NOTE(review): two active rows for one symbol with identical
        ``recommendation_time`` would both be returned.
        """
        now = get_beijing_time()
        return db.execute_query(
            """SELECT t1.* FROM trade_recommendations t1
               INNER JOIN (
                   SELECT symbol, MAX(recommendation_time) as max_time
                   FROM trade_recommendations
                   WHERE status = 'active' AND (expires_at IS NULL OR expires_at > %s)
                   GROUP BY symbol
               ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time
               WHERE t1.status = 'active' AND (t1.expires_at IS NULL OR t1.expires_at > %s)
               ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC""",
            (now, now)
        )

    @staticmethod
    def get_by_id(recommendation_id):
        """Return the recommendation row with id ``recommendation_id``."""
        return db.execute_one(
            "SELECT * FROM trade_recommendations WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def get_by_symbol(symbol, limit=10):
        """Return up to ``limit`` recommendations for ``symbol``, newest first."""
        return db.execute_query(
            """SELECT * FROM trade_recommendations
               WHERE symbol = %s
               ORDER BY recommendation_time DESC LIMIT %s""",
            (symbol, limit)
        )
|