213 lines
6.8 KiB
Python
213 lines
6.8 KiB
Python
"""
|
|
数据库模型定义
|
|
"""
|
|
from database.connection import db
|
|
from datetime import datetime
|
|
import json
|
|
import logging
|
|
|
|
# Logger named after this module, so callers can tune its level/handlers.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TradingConfig:
    """Trading configuration model backed by the ``trading_config`` table.

    Each row stores its value as a string in ``config_value``. A ``config_type``
    tag ('number', 'boolean', 'json'; anything else is returned as-is) controls
    how the string is converted back on read.
    """

    @staticmethod
    def get_all():
        """Return every configuration row, ordered by category, then key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw row for ``key`` as produced by ``db.execute_one``.

        Presumably ``None`` when the key is absent (``get_value`` treats a falsy
        result as "missing") — confirm against ``database.connection``.
        """
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert or update (upsert) a configuration entry.

        ``value`` is serialized with ``_convert_to_string`` before storage. On a
        duplicate key every column except the key is overwritten and
        ``updated_at`` is refreshed. This assumes ``config_key`` carries a
        UNIQUE/PRIMARY index — TODO confirm in the schema.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            """INSERT INTO trading_config
                (config_key, config_value, config_type, category, description)
                VALUES (%s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                config_value = VALUES(config_value),
                config_type = VALUES(config_type),
                category = VALUES(category),
                description = VALUES(description),
                updated_at = CURRENT_TIMESTAMP""",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the value for ``key`` converted according to its ``config_type``.

        Returns ``default`` only when no row is found. If a row exists but its
        value cannot be parsed, the type-specific fallback from
        ``_convert_value`` (``0`` or ``{}``) is returned instead.
        """
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored config value to a Python object.

        - 'number': ``int`` when the text is an integer literal, otherwise a
          finite ``float`` (so '2.5', '1e-3' and '5E2' all parse). Unparseable,
          NaN or infinite input yields ``0``.
        - 'boolean': ``True`` for 'true'/'1'/'yes'/'on' (case-insensitive),
          otherwise ``False``.
        - 'json': the decoded JSON document, or ``{}`` if decoding fails.
        - any other type: ``value`` unchanged.
        """
        if config_type == 'number':
            import math

            text = str(value)
            # Try int first so integer literals keep their int type. The old
            # "contains '.'" heuristic sent '1e-3' to int() and returned 0.
            try:
                return int(text)
            except ValueError:
                pass
            try:
                number = float(text)
            except ValueError:
                return 0
            # int() used to reject 'nan'/'inf' (-> 0). Keep that, so a
            # non-finite number never leaks into trading parameters.
            return number if math.isfinite(number) else 0
        elif config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        elif config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):
                # ValueError covers json.JSONDecodeError; TypeError covers
                # non-string input such as None.
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage.

        'json' values are dumped with ``ensure_ascii=False`` so non-ASCII text
        is stored verbatim. Everything else goes through ``str()``.
        """
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
|
|
|
|
|
|
class Trade:
    """Model for rows in the ``trades`` table."""

    @staticmethod
    def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None):
        """Insert a new trade with status 'open' and return its generated id.

        The id is read back in a second query via ``LAST_INSERT_ID()``.
        """
        # NOTE(review): LAST_INSERT_ID() is per-connection in MySQL. This only
        # returns the right id if db.execute_update and db.execute_one share the
        # same connection with no other INSERT in between — confirm in
        # database.connection.
        insert_sql = """INSERT INTO trades
            (symbol, side, quantity, entry_price, leverage, entry_reason, status)
            VALUES (%s, %s, %s, %s, %s, %s, 'open')"""
        row_values = (symbol, side, quantity, entry_price, leverage, entry_reason)
        db.execute_update(insert_sql, row_values)
        inserted = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return inserted['id']

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent):
        """Close trade ``trade_id``.

        Records exit price, reason and P&L, stamps ``exit_time`` with the
        database's CURRENT_TIMESTAMP, and sets status to 'closed'.
        """
        close_sql = """UPDATE trades
            SET exit_price = %s, exit_time = CURRENT_TIMESTAMP,
            exit_reason = %s, pnl = %s, pnl_percent = %s, status = 'closed'
            WHERE id = %s"""
        db.execute_update(
            close_sql,
            (exit_price, exit_reason, pnl, pnl_percent, trade_id),
        )

    @staticmethod
    def get_all(start_date=None, end_date=None, symbol=None, status=None):
        """Return trades matching the given filters, newest ``entry_time`` first.

        A filter is applied only when its argument is truthy. The date bounds
        are inclusive comparisons on ``entry_time``. If ``end_date`` is a bare
        date, trades later that same day are presumably excluded — verify what
        callers pass.
        """
        filters = (
            ("entry_time >= %s", start_date),
            ("entry_time <= %s", end_date),
            ("symbol = %s", symbol),
            ("status = %s", status),
        )
        pieces = ["SELECT * FROM trades WHERE 1=1"]
        args = []
        for condition, bound in filters:
            if not bound:
                continue
            pieces.append(" AND " + condition)
            args.append(bound)
        pieces.append(" ORDER BY entry_time DESC")
        return db.execute_query("".join(pieces), args)

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return trades for ``symbol`` that have ``status`` (default 'open')."""
        lookup_sql = "SELECT * FROM trades WHERE symbol = %s AND status = %s"
        return db.execute_query(lookup_sql, (symbol, status))
|
|
|
|
|
|
class AccountSnapshot:
    """Model for rows in the ``account_snapshots`` table."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert one account snapshot row with the given balance figures."""
        insert_sql = """INSERT INTO account_snapshots
            (total_balance, available_balance, total_position_value, total_pnl, open_positions)
            VALUES (%s, %s, %s, %s, %s)"""
        snapshot = (
            total_balance,
            available_balance,
            total_position_value,
            total_pnl,
            open_positions,
        )
        db.execute_update(insert_sql, snapshot)

    @staticmethod
    def get_recent(days=7):
        """Return snapshots from the last ``days`` days, newest first.

        The window is measured from the database server's NOW().
        """
        window_sql = """SELECT * FROM account_snapshots
            WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
            ORDER BY snapshot_time DESC"""
        return db.execute_query(window_sql, (days,))
|
|
|
|
|
|
class MarketScan:
    """Model for rows in the ``market_scans`` table."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Insert one scan record.

        ``top_symbols`` is serialized with ``json.dumps`` (default settings).
        """
        encoded_top = json.dumps(top_symbols)
        insert_sql = """INSERT INTO market_scans
            (symbols_scanned, symbols_found, top_symbols, scan_duration)
            VALUES (%s, %s, %s, %s)"""
        db.execute_update(
            insert_sql,
            (symbols_scanned, symbols_found, encoded_top, scan_duration),
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scan records, most recent ``scan_time`` first."""
        recent_sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(recent_sql, (limit,))
|
|
|
|
|
|
class TradingSignal:
    """Model for rows in the ``trading_signals`` table."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert one trading signal.

        The indicator fields are optional. Returns ``None``: unlike
        ``Trade.create`` it does not read back the new row id.
        """
        insert_sql = """INSERT INTO trading_signals
            (symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime)
            VALUES (%s, %s, %s, %s, %s, %s, %s)"""
        signal_row = (
            symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime,
        )
        db.execute_update(insert_sql, signal_row)

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed`` to TRUE for the signal with id ``signal_id``."""
        flag_sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(flag_sql, (signal_id,))

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, most recent ``signal_time`` first."""
        recent_sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(recent_sql, (limit,))
|