auto_trade_sys/backend/database/models.py
薇薇安 09373b16ac a
2026-01-14 13:50:00 +08:00

225 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Database model definitions.

Each class wraps one database table with static helper methods that call the
shared ``db`` object imported from ``database.connection``. Rows created by
these helpers are timestamped with ``get_beijing_time()`` (naive UTC+8).
"""
from database.connection import db
from datetime import datetime, timezone, timedelta
import json
import logging
# Beijing time zone: a fixed UTC+8 offset.
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current Beijing (UTC+8) wall-clock time as a naive datetime.

    ``tzinfo`` is stripped from the result; the model helpers in this module
    pass it straight through as the value for their timestamp columns.
    """
    aware_now = datetime.now(tz=BEIJING_TZ)
    return aware_now.replace(tzinfo=None)


logger = logging.getLogger(__name__)
class TradingConfig:
    """Trading configuration model (table ``trading_config``).

    Every value is stored as a string together with a ``config_type`` tag;
    ``get_value`` converts it back using that tag ('number', 'boolean',
    'json'; any other type returns the stored value unchanged).
    """

    @staticmethod
    def get_all():
        """Return all config rows ordered by category, then config_key."""
        return db.execute_query(
            "SELECT * FROM trading_config ORDER BY category, config_key"
        )

    @staticmethod
    def get(key):
        """Return the raw config row for ``key``.

        ``get_value`` treats a falsy result as "key not present".
        """
        return db.execute_one(
            "SELECT * FROM trading_config WHERE config_key = %s",
            (key,)
        )

    @staticmethod
    def set(key, value, config_type, category, description=None):
        """Insert or update (upsert) the config entry ``key``.

        Args:
            key: Config key (``config_key`` column).
            value: Python value; stored via ``_convert_to_string``.
            config_type: Type tag later used by ``_convert_value``.
            category: Grouping column (``get_all`` orders by it).
            description: Optional human-readable description.

        NOTE(review): for an existing key, ``description = VALUES(description)``
        replaces the stored description with whatever is passed, including
        None - confirm that clearing it is intended. MySQL 8.0.20+ also
        deprecates ``VALUES()`` in this clause in favour of a row alias.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            "INSERT INTO trading_config\n"
            "(config_key, config_value, config_type, category, description)\n"
            "VALUES (%s, %s, %s, %s, %s)\n"
            "ON DUPLICATE KEY UPDATE\n"
            "config_value = VALUES(config_value),\n"
            "config_type = VALUES(config_type),\n"
            "category = VALUES(category),\n"
            "description = VALUES(description),\n"
            "updated_at = CURRENT_TIMESTAMP",
            (key, value_str, config_type, category, description)
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the typed value for ``key``, or ``default`` when it is absent."""
        result = TradingConfig.get(key)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored config value back to a Python value.

        Args:
            value: Raw ``config_value`` (normally the string written by ``set``).
            config_type: 'number', 'boolean' or 'json'; any other tag returns
                ``value`` unchanged.

        Returns:
            'number': ``int`` when the text has no '.' and parses as an integer,
                otherwise a finite ``float`` (this includes exponent forms such
                as '1e-05', which ``str()`` produces for small floats);
                unparsable or non-finite input yields 0 and logs a warning.
            'boolean': True for 'true'/'1'/'yes'/'on' (case-insensitive).
            'json': the decoded object, or {} (with a warning) if decoding fails.
        """
        if config_type == 'number':
            import math  # local: only needed for the finiteness check below

            text = str(value)
            if '.' not in text:
                try:
                    return int(text)
                except ValueError:
                    pass  # e.g. '1e-05' / '1e+16' - retry as float below
            try:
                number = float(text)
            except ValueError:
                number = None
            # Reject inf/nan as well as garbage: a trading parameter should never be non-finite.
            if number is not None and math.isfinite(number):
                return number
            logging.getLogger(__name__).warning(
                "Invalid number config value %r, falling back to 0", value
            )
            return 0
        elif config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        elif config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):
                # JSONDecodeError subclasses ValueError; TypeError covers None/non-str input.
                logging.getLogger(__name__).warning(
                    "Invalid json config value %r, falling back to {}", value
                )
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialise ``value`` for storage in ``config_value``.

        'json' values are JSON-encoded with non-ASCII characters kept as-is;
        everything else uses ``str()`` (so booleans become 'True'/'False',
        which ``_convert_value`` reads back case-insensitively).
        """
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
class Trade:
    """Trade record model (table ``trades``)."""

    @staticmethod
    def create(symbol, side, quantity, entry_price, leverage=10, entry_reason=None):
        """Insert a new trade with status 'open' and return its id.

        ``entry_time`` is stamped with ``get_beijing_time()``.
        """
        opened_at = get_beijing_time()
        insert_args = (symbol, side, quantity, entry_price, leverage, entry_reason, opened_at)
        db.execute_update(
            "INSERT INTO trades\n"
            "(symbol, side, quantity, entry_price, leverage, entry_reason, status, entry_time)\n"
            "VALUES (%s, %s, %s, %s, %s, %s, 'open', %s)",
            insert_args
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection in MySQL; this is only
        # reliable if db.execute_one reuses the INSERT's connection - confirm.
        id_row = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return id_row['id']

    @staticmethod
    def update_exit(trade_id, exit_price, exit_reason, pnl, pnl_percent):
        """Close trade ``trade_id``: store exit details and set status 'closed'.

        ``exit_time`` is stamped with ``get_beijing_time()``.
        """
        closed_at = get_beijing_time()
        db.execute_update(
            "UPDATE trades\n"
            "SET exit_price = %s, exit_time = %s,\n"
            "exit_reason = %s, pnl = %s, pnl_percent = %s, status = 'closed'\n"
            "WHERE id = %s",
            (exit_price, closed_at, exit_reason, pnl, pnl_percent, trade_id)
        )

    @staticmethod
    def get_all(start_date=None, end_date=None, symbol=None, status=None):
        """Return trades matching the given filters, newest ``entry_time`` first.

        Each filter is applied only when its argument is truthy. Both date
        bounds are inclusive and compare against ``entry_time``.
        NOTE(review): if ``end_date`` is a bare date, it presumably compares as
        midnight and excludes later trades that day - confirm with callers.
        """
        optional_filters = (
            ("entry_time >= %s", start_date),
            ("entry_time <= %s", end_date),
            ("symbol = %s", symbol),
            ("status = %s", status),
        )
        sql = "SELECT * FROM trades WHERE 1=1"
        bound = []
        for condition, arg in optional_filters:
            if not arg:
                continue
            sql += " AND " + condition
            bound.append(arg)
        sql += " ORDER BY entry_time DESC"
        return db.execute_query(sql, bound)

    @staticmethod
    def get_by_symbol(symbol, status='open'):
        """Return trades for ``symbol`` with the given status (default 'open')."""
        criteria = (symbol, status)
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            criteria
        )
class AccountSnapshot:
    """Account snapshot model (table ``account_snapshots``)."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions):
        """Insert one account snapshot stamped with ``get_beijing_time()``."""
        snapshot_time = get_beijing_time()
        db.execute_update(
            "INSERT INTO account_snapshots\n"
            "(total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)\n"
            "VALUES (%s, %s, %s, %s, %s, %s)",
            (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
        )

    @staticmethod
    def get_recent(days=7):
        """Return snapshots taken within the last ``days`` days, newest first.

        The cutoff is computed in Python from ``get_beijing_time()``, the same
        clock ``create`` uses for ``snapshot_time``. Comparing against MySQL
        ``NOW()`` would instead use the DB session time zone and shift the
        window by the offset difference whenever that zone is not UTC+8.

        Args:
            days: Look-back window in days; numbers or numeric strings accepted.
        """
        # float() keeps accepting numeric strings, which the SQL INTERVAL form tolerated.
        cutoff = get_beijing_time() - timedelta(days=float(days))
        return db.execute_query(
            "SELECT * FROM account_snapshots\n"
            "WHERE snapshot_time >= %s\n"
            "ORDER BY snapshot_time DESC",
            (cutoff,)
        )
class MarketScan:
    """Market scan record model (table ``market_scans``)."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Insert one scan result stamped with ``get_beijing_time()``.

        ``top_symbols`` is stored as a ``json.dumps`` string.
        """
        scanned_at = get_beijing_time()
        encoded_top = json.dumps(top_symbols)
        row_values = (symbols_scanned, symbols_found, encoded_top, scan_duration, scanned_at)
        db.execute_update(
            "INSERT INTO market_scans\n"
            "(symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)\n"
            "VALUES (%s, %s, %s, %s, %s)",
            row_values
        )

    @staticmethod
    def get_recent(limit=100):
        """Return at most ``limit`` scans, newest ``scan_time`` first."""
        sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
class TradingSignal:
    """Trading signal model (table ``trading_signals``)."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert one signal stamped with ``get_beijing_time()``.

        The indicator fields (``rsi``, ``macd_histogram``, ``market_regime``)
        are optional and stored as given (None when omitted).
        """
        emitted_at = get_beijing_time()
        row_values = (
            symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime, emitted_at,
        )
        db.execute_update(
            "INSERT INTO trading_signals\n"
            "(symbol, signal_direction, signal_strength, signal_reason,\n"
            "rsi, macd_histogram, market_regime, signal_time)\n"
            "VALUES (%s, %s, %s, %s, %s, %s, %s, %s)",
            row_values
        )

    @staticmethod
    def mark_executed(signal_id):
        """Set ``executed = TRUE`` on signal ``signal_id``."""
        sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(sql, (signal_id,))

    @staticmethod
    def get_recent(limit=100):
        """Return at most ``limit`` signals, newest ``signal_time`` first."""
        sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))