auto_trade_sys/backend/database/models.py
薇薇安 fad8a1d6fd a
2026-01-23 20:35:11 +08:00

974 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库模型定义
"""
from database.connection import db
from datetime import datetime, timezone, timedelta
import json
import logging
import os
# Beijing time zone (UTC+8).
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current time as a Unix timestamp in whole seconds.

    The instant is read through a ``BEIJING_TZ``-aware ``datetime`` and then
    converted with ``timestamp()``, so the returned value is an ``int`` of
    epoch seconds.
    """
    now = datetime.now(tz=BEIJING_TZ)
    return int(now.timestamp())
logger = logging.getLogger(__name__)


def _resolve_default_account_id() -> int:
    """Pick the default account id for this process.

    - trading_system runs multiple processes; each process can select its own
      account id through the ``ATS_ACCOUNT_ID`` environment variable.
    - The backend falls back to account 1 when nothing is configured
      (single-account compatibility).

    ``ATS_ACCOUNT_ID``, ``ACCOUNT_ID`` and ``ATS_DEFAULT_ACCOUNT_ID`` are
    consulted in that order; blank or non-integer values are skipped.
    """
    candidates = ("ATS_ACCOUNT_ID", "ACCOUNT_ID", "ATS_DEFAULT_ACCOUNT_ID")
    for env_name in candidates:
        raw = (os.getenv(env_name, "") or "").strip()
        if not raw:
            continue
        try:
            return int(raw)
        except Exception:
            # Not an integer: fall through to the next variable.
            pass
    return 1


DEFAULT_ACCOUNT_ID = _resolve_default_account_id()
def _table_has_column(table: str, col: str) -> bool:
    """Report whether ``SELECT <col> FROM <table> LIMIT 1`` runs without error.

    Any exception raised by the probe query is treated as "column missing".
    ``table`` and ``col`` are interpolated into the SQL text directly, so
    they must be trusted identifiers.
    """
    probe = f"SELECT {col} FROM {table} LIMIT 1"
    try:
        db.execute_one(probe)
    except Exception:
        return False
    return True
class Account:
    """Account model (multi-account support).

    API key/secret are meant to be stored encrypted in the ``accounts``
    table rather than in ``trading_config``.
    """

    @staticmethod
    def get(account_id: int):
        """Fetch one ``accounts`` row by id and log the outcome of the lookup."""
        log = logging.getLogger(__name__)
        log.info(f"Account.get called with account_id={account_id}")
        row = db.execute_one("SELECT * FROM accounts WHERE id = %s", (int(account_id),))
        if not row:
            log.warning(f"Account.get: account_id={account_id} not found in database")
            return row
        log.info(f"Account.get: found account_id={account_id}, name={row.get('name', 'N/A')}, status={row.get('status', 'N/A')}")
        return row

    @staticmethod
    def list_all():
        """Return id, name, status and timestamps of every account, ordered by id."""
        sql = "SELECT id, name, status, created_at, updated_at FROM accounts ORDER BY id ASC"
        return db.execute_query(sql)

    @staticmethod
    def create(name: str, api_key: str = "", api_secret: str = "", use_testnet: bool = False, status: str = "active"):
        """Insert an account with encrypted credentials.

        Returns the value reported by ``SELECT LAST_INSERT_ID()``.
        """
        # Imported lazily so a missing crypto dependency does not break module import.
        from security.crypto import encrypt_str

        row_values = (
            name,
            status,
            encrypt_str(api_key or ""),
            encrypt_str(api_secret or ""),
            bool(use_testnet),
        )
        db.execute_update(
            """INSERT INTO accounts (name, status, api_key_enc, api_secret_enc, use_testnet)
               VALUES (%s, %s, %s, %s, %s)""",
            row_values,
        )
        inserted = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return inserted["id"]

    @staticmethod
    def update_credentials(account_id: int, api_key: str = None, api_secret: str = None, use_testnet: bool = None):
        """Update only the credential fields that were passed (not ``None``).

        Does nothing when every optional argument is ``None``.
        """
        # Imported lazily, same as in ``create``.
        from security.crypto import encrypt_str

        changes = []
        if api_key is not None:
            changes.append(("api_key_enc = %s", encrypt_str(api_key)))
        if api_secret is not None:
            changes.append(("api_secret_enc = %s", encrypt_str(api_secret)))
        if use_testnet is not None:
            changes.append(("use_testnet = %s", bool(use_testnet)))
        if not changes:
            return

        set_clause = ", ".join(clause for clause, _ in changes)
        params = [value for _, value in changes]
        params.append(int(account_id))
        db.execute_update(f"UPDATE accounts SET {set_clause} WHERE id = %s", tuple(params))

    @staticmethod
    def get_credentials(account_id: int):
        """Return ``(api_key, api_secret, use_testnet, status)`` for an account.

        Encrypted columns are decrypted through ``security.crypto.decrypt_str``.
        An unknown account yields ``("", "", False, "disabled")``.
        """
        log = logging.getLogger(__name__)
        log.info(f"Account.get_credentials called with account_id={account_id}")
        row = Account.get(account_id)
        if not row:
            log.warning(f"Account.get_credentials: account_id={account_id} not found in database")
            return "", "", False, "disabled"

        try:
            from security.crypto import decrypt_str

            status = row.get("status") or "active"
            api_key = decrypt_str(row.get("api_key_enc") or "")
            api_secret = decrypt_str(row.get("api_secret_enc") or "")
        except Exception:
            # Compatibility path for when decryption is unavailable or fails
            # (e.g. no cryptography package or no master key configured):
            # - plaintext stored in the column is passed through as-is;
            # - an "enc:v1:" ciphertext cannot be decrypted and must never be
            #   used as a key, so it is replaced by an empty string.
            # NOTE(review): status is forced to "disabled" here even when the
            # stored values are plaintext - confirm that is intended.
            def _plaintext_or_blank(stored):
                text = str(stored or "")
                return "" if text.startswith("enc:v1:") else text

            status = "disabled"
            api_key = _plaintext_or_blank(row.get("api_key_enc"))
            api_secret = _plaintext_or_blank(row.get("api_secret_enc"))

        use_testnet = bool(row.get("use_testnet") or False)
        return api_key, api_secret, use_testnet, status
class User:
    """Login user (administrator / regular user)."""

    @staticmethod
    def get_by_username(username: str):
        """Return the ``users`` row whose username matches, as given by ``db.execute_one``."""
        params = (str(username),)
        return db.execute_one("SELECT * FROM users WHERE username = %s", params)

    @staticmethod
    def get_by_id(user_id: int):
        """Return the ``users`` row with the given id, as given by ``db.execute_one``."""
        params = (int(user_id),)
        return db.execute_one("SELECT * FROM users WHERE id = %s", params)

    @staticmethod
    def list_all():
        """List every user (without the password hash), ordered by id."""
        sql = "SELECT id, username, role, status, created_at, updated_at FROM users ORDER BY id ASC"
        return db.execute_query(sql)

    @staticmethod
    def create(username: str, password_hash: str, role: str = "user", status: str = "active"):
        """Insert a user and return the value reported by ``SELECT LAST_INSERT_ID()``."""
        db.execute_update(
            "INSERT INTO users (username, password_hash, role, status) VALUES (%s, %s, %s, %s)",
            (username, password_hash, role, status),
        )
        inserted = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return inserted["id"]

    @staticmethod
    def set_password(user_id: int, password_hash: str):
        """Overwrite the stored password hash of one user."""
        params = (password_hash, int(user_id))
        db.execute_update("UPDATE users SET password_hash = %s WHERE id = %s", params)

    @staticmethod
    def set_status(user_id: int, status: str):
        """Overwrite the status of one user."""
        params = (status, int(user_id))
        db.execute_update("UPDATE users SET status = %s WHERE id = %s", params)

    @staticmethod
    def set_role(user_id: int, role: str):
        """Overwrite the role of one user."""
        params = (role, int(user_id))
        db.execute_update("UPDATE users SET role = %s WHERE id = %s", params)
class UserAccountMembership:
    """Authorization links between login users and trading accounts."""

    @staticmethod
    def add(user_id: int, account_id: int, role: str = "viewer"):
        """Grant a user a role on an account; an existing link gets its role replaced."""
        params = (int(user_id), int(account_id), role)
        db.execute_update(
            """INSERT INTO user_account_memberships (user_id, account_id, role)
               VALUES (%s, %s, %s)
               ON DUPLICATE KEY UPDATE role = VALUES(role)""",
            params,
        )

    @staticmethod
    def remove(user_id: int, account_id: int):
        """Delete the link between a user and an account."""
        params = (int(user_id), int(account_id))
        db.execute_update(
            "DELETE FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            params,
        )

    @staticmethod
    def list_for_user(user_id: int):
        """Return all memberships of one user, ordered by account id."""
        params = (int(user_id),)
        return db.execute_query(
            "SELECT * FROM user_account_memberships WHERE user_id = %s ORDER BY account_id ASC",
            params,
        )

    @staticmethod
    def list_for_account(account_id: int):
        """Return all memberships of one account, ordered by user id."""
        params = (int(account_id),)
        return db.execute_query(
            "SELECT * FROM user_account_memberships WHERE account_id = %s ORDER BY user_id ASC",
            params,
        )

    @staticmethod
    def has_access(user_id: int, account_id: int) -> bool:
        """Return True when a membership row exists for the (user, account) pair."""
        match = db.execute_one(
            "SELECT 1 as ok FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            (int(user_id), int(account_id)),
        )
        return True if match else False

    @staticmethod
    def get_role(user_id: int, account_id: int) -> str:
        """Return the user's role on the account, or ``""`` when there is none."""
        match = db.execute_one(
            "SELECT role FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            (int(user_id), int(account_id)),
        )
        if not isinstance(match, dict):
            return ""
        return match.get("role") or ""
class TradingConfig:
    """Trading configuration model (key/value settings, per account when supported)."""

    @staticmethod
    def get_all(account_id: int = None):
        """Return every configuration row, ordered by category then key.

        Rows are filtered by account when ``trading_config`` has an
        ``account_id`` column; a falsy ``account_id`` means
        ``DEFAULT_ACCOUNT_ID``.
        """
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            return db.execute_query(
                "SELECT * FROM trading_config WHERE account_id = %s ORDER BY category, config_key",
                (aid,),
            )
        return db.execute_query("SELECT * FROM trading_config ORDER BY category, config_key")

    @staticmethod
    def get(key, account_id: int = None):
        """Return the raw configuration row for ``key`` (see ``get_value`` for typed access)."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            return db.execute_one(
                "SELECT * FROM trading_config WHERE account_id = %s AND config_key = %s",
                (aid, key),
            )
        return db.execute_one("SELECT * FROM trading_config WHERE config_key = %s", (key,))

    @staticmethod
    def set(key, value, config_type, category, description=None, account_id: int = None):
        """Insert or update one configuration entry.

        ``value`` is serialized with ``_convert_to_string``. The
        ``account_id`` column is written only when the table has it.
        """
        value_str = TradingConfig._convert_to_string(value, config_type)
        aid = int(account_id or DEFAULT_ACCOUNT_ID)

        columns = ["config_key", "config_value", "config_type", "category", "description"]
        params = [key, value_str, config_type, category, description]
        if _table_has_column("trading_config", "account_id"):
            columns.insert(0, "account_id")
            params.insert(0, aid)

        column_list = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        db.execute_update(
            f"""INSERT INTO trading_config
                ({column_list})
                VALUES ({placeholders})
                ON DUPLICATE KEY UPDATE
                    config_value = VALUES(config_value),
                    config_type = VALUES(config_type),
                    category = VALUES(category),
                    description = VALUES(description),
                    updated_at = CURRENT_TIMESTAMP""",
            tuple(params),
        )

    @staticmethod
    def get_value(key, default=None, account_id: int = None):
        """Return the configuration value converted to its declared type.

        ``default`` is returned when no row exists for ``key``.
        """
        result = TradingConfig.get(key, account_id=account_id)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored value according to ``config_type``.

        - ``'number'``: ``float`` when the text contains a ``.``, otherwise
          ``int``; text that is not a plain integer (for example ``"1e-05"``,
          which is how ``str()`` renders small floats) falls back to
          ``float``. Unparseable or non-finite input yields ``0``.
        - ``'boolean'``: True for ``true/1/yes/on`` (case-insensitive).
        - ``'json'``: parsed JSON, or ``{}`` when parsing fails.
        - anything else: ``value`` unchanged.
        """
        import math  # local import: only the numeric fallback needs it

        if config_type == 'number':
            try:
                if '.' in str(value):
                    return float(value)
                try:
                    return int(value)
                except (TypeError, ValueError):
                    # Scientific notation has no '.', yet is a valid float.
                    parsed = float(value)
                    return parsed if math.isfinite(parsed) else 0
            except (TypeError, ValueError, OverflowError):
                return 0
        if config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        if config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize a value for storage: JSON text for ``'json'``, ``str(value)`` otherwise."""
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
class GlobalStrategyConfig:
    """Global strategy configuration model (account-independent, administrator only)."""

    @staticmethod
    def get_all():
        """Return every global configuration row, or ``[]`` when the table probe fails."""
        if _table_has_column("global_strategy_config", "config_key"):
            return db.execute_query(
                "SELECT * FROM global_strategy_config ORDER BY category, config_key"
            )
        return []

    @staticmethod
    def get(key):
        """Return the raw row for ``key``, or ``None`` when the table probe fails."""
        if _table_has_column("global_strategy_config", "config_key"):
            return db.execute_one(
                "SELECT * FROM global_strategy_config WHERE config_key = %s",
                (key,)
            )
        return None

    @staticmethod
    def set(key, value, config_type, category, description=None, updated_by=None):
        """Insert or update one global configuration entry.

        When the ``global_strategy_config`` probe fails, the value is written
        to ``trading_config`` for account 1 instead (legacy compatibility).
        """
        table_ready = _table_has_column("global_strategy_config", "config_key")
        if not table_ready:
            # Table missing: fall back to trading_config for older deployments.
            return TradingConfig.set(key, value, config_type, category, description, account_id=1)

        serialized = TradingConfig._convert_to_string(value, config_type)
        row_values = (key, serialized, config_type, category, description, updated_by)
        db.execute_update(
            """INSERT INTO global_strategy_config
               (config_key, config_value, config_type, category, description, updated_by)
               VALUES (%s, %s, %s, %s, %s, %s)
               ON DUPLICATE KEY UPDATE
                   config_value = VALUES(config_value),
                   config_type = VALUES(config_type),
                   category = VALUES(category),
                   description = VALUES(description),
                   updated_by = VALUES(updated_by),
                   updated_at = CURRENT_TIMESTAMP""",
            row_values,
        )

    @staticmethod
    def get_value(key, default=None):
        """Return the global configuration value converted to its declared type.

        ``default`` is returned when no row is found.
        """
        entry = GlobalStrategyConfig.get(key)
        if not entry:
            return default
        return GlobalStrategyConfig._convert_value(entry['config_value'], entry['config_type'])

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored value by delegating to ``TradingConfig._convert_value``."""
        return TradingConfig._convert_value(value, config_type)

    @staticmethod
    def delete(key):
        """Delete one global configuration entry; no-op when the table probe fails."""
        if _table_has_column("global_strategy_config", "config_key"):
            db.execute_update(
                "DELETE FROM global_strategy_config WHERE config_key = %s",
                (key,)
            )
class Trade:
    """Trade record model (``trades`` table)."""

    @staticmethod
    def create(
        symbol,
        side,
        quantity,
        entry_price,
        leverage=10,
        entry_reason=None,
        entry_order_id=None,
        stop_loss_price=None,
        take_profit_price=None,
        take_profit_1=None,
        take_profit_2=None,
        atr=None,
        notional_usdt=None,
        margin_usdt=None,
        account_id: int = None,
    ):
        """Insert an open trade and return the value of ``LAST_INSERT_ID()``.

        ``entry_time`` is taken from ``get_beijing_time()``. Optional columns
        are written only when the ``trades`` table has them, so the method
        works against older schemas.

        Args:
            symbol: Trading pair.
            side: Direction.
            quantity: Position size.
            entry_price: Entry price.
            leverage: Leverage.
            entry_reason: Reason for opening the position.
            entry_order_id: Exchange order id of the opening order (optional,
                used for reconciliation).
            stop_loss_price: Stop-loss price actually used.
            take_profit_price: Take-profit price actually used.
            take_profit_1: First take-profit target (optional).
            take_profit_2: Second take-profit target (optional).
            atr: ATR value used when opening (optional).
            notional_usdt: Notional size in USDT (optional; defaults to
                ``quantity * entry_price``).
            margin_usdt: Margin in USDT (optional; defaults to
                ``notional_usdt / leverage``, or the notional itself when
                leverage is missing or not positive).
            account_id: Owning account; falsy means ``DEFAULT_ACCOUNT_ID``.
        """
        entry_time = get_beijing_time()

        # Derive notional / margin when the caller did not pass them. This is
        # best-effort: a value that cannot be converted leaves the field None.
        try:
            if notional_usdt is None and quantity is not None and entry_price is not None:
                notional_usdt = float(quantity) * float(entry_price)
        except Exception as exc:
            logger.debug("Trade.create: could not derive notional_usdt: %s", exc)
        try:
            if margin_usdt is None and notional_usdt is not None:
                lv = float(leverage) if leverage else 0
                margin_usdt = (float(notional_usdt) / lv) if lv and lv > 0 else float(notional_usdt)
        except Exception as exc:
            logger.debug("Trade.create: could not derive margin_usdt: %s", exc)

        # Build the INSERT dynamically to stay compatible with different schemas.
        columns = ["symbol", "side", "quantity", "entry_price", "leverage", "entry_reason", "status", "entry_time"]
        values = [symbol, side, quantity, entry_price, leverage, entry_reason, "open", entry_time]
        if _table_has_column("trades", "account_id"):
            columns.insert(0, "account_id")
            values.insert(0, int(account_id or DEFAULT_ACCOUNT_ID))

        optional_columns = (
            ("entry_order_id", entry_order_id),
            ("notional_usdt", notional_usdt),
            ("margin_usdt", margin_usdt),
            ("atr", atr),
            ("stop_loss_price", stop_loss_price),
            ("take_profit_price", take_profit_price),
            ("take_profit_1", take_profit_1),
            ("take_profit_2", take_profit_2),
        )
        for column, value in optional_columns:
            if _table_has_column("trades", column):
                columns.append(column)
                values.append(value)

        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO trades ({', '.join(columns)}) VALUES ({placeholders})"
        db.execute_update(sql, tuple(values))
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def _apply_exit_update(
        trade_id,
        exit_price,
        exit_time,
        exit_reason,
        pnl,
        pnl_percent,
        strategy_type,
        duration_minutes,
        exit_order_id=None,
        include_order_id=False,
    ):
        """Run the UPDATE that marks a trade as closed.

        ``exit_order_id`` is written only when ``include_order_id`` is True;
        ``strategy_type`` and ``duration_minutes`` only when not ``None``.
        """
        assignments = [
            "exit_price = %s", "exit_time = %s",
            "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'",
        ]
        params = [exit_price, exit_time, exit_reason, pnl, pnl_percent]
        if include_order_id:
            assignments.append("exit_order_id = %s")
            params.append(exit_order_id)
        if strategy_type is not None:
            assignments.append("strategy_type = %s")
            params.append(strategy_type)
        if duration_minutes is not None:
            assignments.append("duration_minutes = %s")
            params.append(duration_minutes)
        params.append(trade_id)
        db.execute_update(
            f"UPDATE trades SET {', '.join(assignments)} WHERE id = %s",
            tuple(params)
        )

    @staticmethod
    def update_exit(
        trade_id,
        exit_price,
        exit_reason,
        pnl,
        pnl_percent,
        exit_order_id=None,
        strategy_type=None,
        duration_minutes=None,
        exit_time_ts=None,
    ):
        """Record the exit of a trade and mark it closed.

        Args:
            trade_id: Id of the trade record.
            exit_price: Exit price.
            exit_reason: Reason for closing.
            pnl: Profit and loss.
            pnl_percent: Profit and loss in percent.
            exit_order_id: Exchange order id of the closing order (optional,
                used for reconciliation). If it is already attached to a
                different trade record, the other fields are still updated
                but ``exit_order_id`` is left untouched to avoid a unique
                constraint violation.
            strategy_type: Written only when not ``None``.
            duration_minutes: Written only when not ``None``.
            exit_time_ts: Actual fill time in Unix seconds; when omitted or
                not convertible to ``int``, ``get_beijing_time()`` is used.

        Raises:
            Exception: Any update error other than a duplicate-entry error
                mentioning ``exit_order_id`` is re-raised.
        """
        # Let callers pass the real fill time so holding duration stays accurate.
        try:
            exit_time = int(exit_time_ts) if exit_time_ts is not None else get_beijing_time()
        except Exception:
            exit_time = get_beijing_time()

        common = (
            trade_id, exit_price, exit_time, exit_reason, pnl, pnl_percent,
            strategy_type, duration_minutes,
        )

        # When an exit_order_id is supplied, first check whether another
        # trade record already uses it.
        if exit_order_id is not None:
            try:
                existing_trade = Trade.get_by_exit_order_id(exit_order_id)
                if existing_trade:
                    if existing_trade['id'] == trade_id:
                        # Already attached to this trade: keep going so the
                        # remaining fields can still be filled in.
                        logger.debug(
                            f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已存在,将继续更新其他字段"
                        )
                    else:
                        # Attached to a different trade: warn and update
                        # everything except exit_order_id.
                        logger.warning(
                            f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用,"
                            f"跳过更新 exit_order_id只更新其他字段"
                        )
                        Trade._apply_exit_update(*common)
                        return
            except Exception as e:
                # The check (or the reduced update) failed: warn and fall
                # through to the normal update.
                logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新")

        # Normal update, including exit_order_id.
        try:
            Trade._apply_exit_update(*common, exit_order_id=exit_order_id, include_order_id=True)
        except Exception as e:
            # A unique-constraint conflict on exit_order_id is retried
            # without that column; everything else propagates.
            error_str = str(e)
            if "Duplicate entry" in error_str and "exit_order_id" in error_str:
                logger.warning(
                    f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突,"
                    f"尝试不更新 exit_order_id"
                )
                Trade._apply_exit_update(*common)
            else:
                raise

    @staticmethod
    def get_by_entry_order_id(entry_order_id):
        """Return the trade record with the given opening order id."""
        return db.execute_one(
            "SELECT * FROM trades WHERE entry_order_id = %s",
            (entry_order_id,)
        )

    @staticmethod
    def get_by_exit_order_id(exit_order_id):
        """Return the trade record with the given closing order id."""
        return db.execute_one(
            "SELECT * FROM trades WHERE exit_order_id = %s",
            (exit_order_id,)
        )

    @staticmethod
    def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None, trade_type=None, exit_reason=None, account_id: int = None):
        """Return trade records matching the given filters, newest id first.

        Args:
            start_timestamp: Lower bound, Unix seconds (optional).
            end_timestamp: Upper bound, Unix seconds (optional).
            symbol: Trading pair (optional).
            status: Status (optional).
            trade_type: Trade type; matched against the ``side`` column (optional).
            exit_reason: Reason for closing (optional).
            account_id: Account filter, applied when ``trades`` has an
                ``account_id`` column; falsy means ``DEFAULT_ACCOUNT_ID``.
        """
        query = "SELECT * FROM trades WHERE 1=1"
        params = []

        # Multi-account isolation (skipped on legacy schemas). An invalid
        # account_id raises here, as in get_by_symbol, instead of leaving a
        # placeholder without a matching parameter.
        if _table_has_column("trades", "account_id"):
            query += " AND account_id = %s"
            params.append(int(account_id or DEFAULT_ACCOUNT_ID))

        # NOTE(review): the bounds are documented as Unix seconds but are
        # compared against created_at - confirm the column type matches.
        if start_timestamp is not None:
            query += " AND created_at >= %s"
            params.append(start_timestamp)
        if end_timestamp is not None:
            query += " AND created_at <= %s"
            params.append(end_timestamp)
        if symbol:
            query += " AND symbol = %s"
            params.append(symbol)
        if status:
            query += " AND status = %s"
            params.append(status)
        if trade_type:
            query += " AND side = %s"
            params.append(trade_type)
        if exit_reason:
            query += " AND exit_reason = %s"
            params.append(exit_reason)

        query += " ORDER BY id DESC"
        logger.info(f"查询交易记录: {query}, {params}")
        result = db.execute_query(query, params)
        return result

    @staticmethod
    def get_by_symbol(symbol, status='open', account_id: int = None):
        """Return trades for a symbol with the given status (open positions by default)."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trades", "account_id"):
            return db.execute_query(
                "SELECT * FROM trades WHERE account_id = %s AND symbol = %s AND status = %s",
                (aid, symbol, status),
            )
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            (symbol, status),
        )
class AccountSnapshot:
    """Account snapshot model (``account_snapshots`` table)."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions, account_id: int = None):
        """Insert one snapshot stamped with ``get_beijing_time()``.

        The ``account_id`` column is written only when the table has it; a
        falsy ``account_id`` means ``DEFAULT_ACCOUNT_ID``.
        """
        snapshot_time = get_beijing_time()
        metrics = (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)

        if not _table_has_column("account_snapshots", "account_id"):
            db.execute_update(
                """INSERT INTO account_snapshots
                   (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                metrics,
            )
            return

        owner = int(account_id or DEFAULT_ACCOUNT_ID)
        db.execute_update(
            """INSERT INTO account_snapshots
               (account_id, total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
               VALUES (%s, %s, %s, %s, %s, %s, %s)""",
            (owner,) + metrics,
        )

    @staticmethod
    def get_recent(days=7, account_id: int = None):
        """Return snapshots from the last ``days`` days, newest first.

        NOTE(review): ``create`` stores ``snapshot_time`` as Unix seconds
        while this query compares it with ``DATE_SUB(NOW(), ...)`` - confirm
        the column type makes both sides comparable.
        """
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        scoped = _table_has_column("account_snapshots", "account_id")
        if scoped:
            return db.execute_query(
                """SELECT * FROM account_snapshots
                   WHERE account_id = %s AND snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
                   ORDER BY snapshot_time DESC""",
                (aid, days),
            )
        return db.execute_query(
            """SELECT * FROM account_snapshots
               WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
               ORDER BY snapshot_time DESC""",
            (days,),
        )
class MarketScan:
    """Market scan record model (``market_scans`` table)."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Insert one scan record stamped with ``get_beijing_time()``.

        ``top_symbols`` is stored as JSON text.
        """
        scan_time = get_beijing_time()
        serialized_top = json.dumps(top_symbols)
        row_values = (symbols_scanned, symbols_found, serialized_top, scan_duration, scan_time)
        db.execute_update(
            """INSERT INTO market_scans
               (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)
               VALUES (%s, %s, %s, %s, %s)""",
            row_values
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scan records, newest first."""
        sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
class TradingSignal:
    """Trading signal model (``trading_signals`` table)."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert one signal stamped with ``get_beijing_time()``."""
        signal_time = get_beijing_time()
        row_values = (
            symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime, signal_time,
        )
        db.execute_update(
            """INSERT INTO trading_signals
               (symbol, signal_direction, signal_strength, signal_reason,
                rsi, macd_histogram, market_regime, signal_time)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
            row_values
        )

    @staticmethod
    def mark_executed(signal_id):
        """Flag a signal as executed."""
        sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(sql, (signal_id,))

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, newest first."""
        sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
class TradeRecommendation:
    """Recommended trading pair model (``trade_recommendations`` table)."""

    # Recommendations expire 24 hours after creation by default.
    _DEFAULT_TTL = timedelta(hours=24)

    @staticmethod
    def _default_expiry(recommendation_time):
        """Return the default expiry for a recommendation created at ``recommendation_time``.

        A ``datetime`` gets the TTL added as a ``timedelta``; any other value
        is treated as a numeric Unix timestamp and gets the TTL added in
        seconds, so the result is always in the same unit as the input.
        """
        ttl = TradeRecommendation._DEFAULT_TTL
        if isinstance(recommendation_time, datetime):
            return recommendation_time + ttl
        return recommendation_time + int(ttl.total_seconds())

    @staticmethod
    def create(
        symbol, direction, current_price, change_percent, recommendation_reason,
        signal_strength, market_regime=None, trend_4h=None,
        rsi=None, macd_histogram=None,
        bollinger_upper=None, bollinger_middle=None, bollinger_lower=None,
        ema20=None, ema50=None, ema20_4h=None, atr=None,
        suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None,
        suggested_position_percent=None, suggested_leverage=10,
        order_type='LIMIT', suggested_limit_price=None,
        volume_24h=None, volatility=None, notes=None,
        user_guide=None, recommendation_category=None, risk_level=None,
        expected_hold_time=None, trading_tutorial=None, max_hold_days=3
    ):
        """Insert a recommendation and return the value of ``LAST_INSERT_ID()``.

        ``recommendation_time`` comes from ``get_beijing_time()`` and
        ``expires_at`` is 24 hours later, in the same unit. The order fields
        (``order_type``, ``suggested_limit_price``) and the user-guide fields
        (``user_guide`` ... ``max_hold_days``) are written only when the table
        has them, for compatibility with older schemas.
        """
        recommendation_time = get_beijing_time()
        # get_beijing_time() returns an int, so adding a timedelta directly
        # would raise TypeError; the helper keeps both values in one unit.
        # NOTE(review): get_active compares expires_at with NOW() - confirm
        # the column type agrees with the value written here.
        expires_at = TradeRecommendation._default_expiry(recommendation_time)

        # Probe optional column groups (older schemas lack them).
        has_order_fields = _table_has_column("trade_recommendations", "order_type")
        has_user_guide_fields = _table_has_column("trade_recommendations", "user_guide")

        fields = [
            ("symbol", symbol),
            ("direction", direction),
            ("recommendation_time", recommendation_time),
            ("current_price", current_price),
            ("change_percent", change_percent),
            ("recommendation_reason", recommendation_reason),
            ("signal_strength", signal_strength),
            ("market_regime", market_regime),
            ("trend_4h", trend_4h),
            ("rsi", rsi),
            ("macd_histogram", macd_histogram),
            ("bollinger_upper", bollinger_upper),
            ("bollinger_middle", bollinger_middle),
            ("bollinger_lower", bollinger_lower),
            ("ema20", ema20),
            ("ema50", ema50),
            ("ema20_4h", ema20_4h),
            ("atr", atr),
            ("suggested_stop_loss", suggested_stop_loss),
            ("suggested_take_profit_1", suggested_take_profit_1),
            ("suggested_take_profit_2", suggested_take_profit_2),
            ("suggested_position_percent", suggested_position_percent),
            ("suggested_leverage", suggested_leverage),
        ]
        # A schema with the user-guide fields is treated as also having the
        # order fields, matching the previous behavior.
        if has_user_guide_fields or has_order_fields:
            fields += [
                ("order_type", order_type),
                ("suggested_limit_price", suggested_limit_price),
            ]
        fields += [
            ("volume_24h", volume_24h),
            ("volatility", volatility),
            ("expires_at", expires_at),
            ("notes", notes),
        ]
        if has_user_guide_fields:
            fields += [
                ("user_guide", user_guide),
                ("recommendation_category", recommendation_category),
                ("risk_level", risk_level),
                ("expected_hold_time", expected_hold_time),
                ("trading_tutorial", trading_tutorial),
                ("max_hold_days", max_hold_days),
            ]

        column_list = ", ".join(name for name, _ in fields)
        placeholders = ", ".join(["%s"] * len(fields))
        db.execute_update(
            f"INSERT INTO trade_recommendations ({column_list}) VALUES ({placeholders})",
            tuple(value for _, value in fields),
        )
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
        """Mark a recommendation as executed, stamped with ``get_beijing_time()``."""
        executed_at = get_beijing_time()
        db.execute_update(
            """UPDATE trade_recommendations
               SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
               WHERE id = %s""",
            (executed_at, execution_result, trade_id, recommendation_id)
        )

    @staticmethod
    def mark_expired(recommendation_id):
        """Mark a recommendation as expired."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def mark_cancelled(recommendation_id, notes=None):
        """Mark a recommendation as cancelled and overwrite its notes."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'cancelled', notes = %s WHERE id = %s",
            (notes, recommendation_id)
        )

    @staticmethod
    def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
        """Return recommendations matching the filters.

        Results are ordered by newest ``recommendation_time`` first, then by
        strongest signal, and capped at ``limit`` rows. Falsy filters are
        ignored.
        """
        query = "SELECT * FROM trade_recommendations WHERE 1=1"
        params = []
        if status:
            query += " AND status = %s"
            params.append(status)
        if direction:
            query += " AND direction = %s"
            params.append(direction)
        if start_date:
            query += " AND recommendation_time >= %s"
            params.append(start_date)
        if end_date:
            query += " AND recommendation_time <= %s"
            params.append(end_date)
        query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
        params.append(limit)
        return db.execute_query(query, params)

    @staticmethod
    def get_active():
        """Return currently valid recommendations (status ``active`` and not expired).

        Only the most recent recommendation per symbol is returned.
        """
        return db.execute_query(
            """SELECT t1.* FROM trade_recommendations t1
               INNER JOIN (
                   SELECT symbol, MAX(recommendation_time) as max_time
                   FROM trade_recommendations
                   WHERE status = 'active' AND (expires_at IS NULL OR expires_at > NOW())
                   GROUP BY symbol
               ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time
               WHERE t1.status = 'active' AND (t1.expires_at IS NULL OR t1.expires_at > NOW())
               ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC"""
        )

    @staticmethod
    def get_by_id(recommendation_id):
        """Return the recommendation with the given id."""
        return db.execute_one(
            "SELECT * FROM trade_recommendations WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def get_by_symbol(symbol, limit=10):
        """Return up to ``limit`` recommendations for a symbol, newest first."""
        return db.execute_query(
            """SELECT * FROM trade_recommendations
               WHERE symbol = %s
               ORDER BY recommendation_time DESC LIMIT %s""",
            (symbol, limit)
        )