893 lines
36 KiB
Python
893 lines
36 KiB
Python
"""
|
||
数据库模型定义
|
||
"""
|
||
from database.connection import db
|
||
from datetime import datetime, timezone, timedelta
|
||
import json
|
||
import logging
|
||
import os
|
||
|
||
# Beijing time zone (UTC+8).
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current moment as an integer Unix timestamp (seconds).

    The clock is read in ``BEIJING_TZ`` (UTC+8), but a Unix timestamp counts
    seconds since the epoch regardless of zone, so the returned value does not
    depend on the zone chosen here. Fractional seconds are truncated.
    """
    now_in_beijing = datetime.now(tz=BEIJING_TZ)
    return int(now_in_beijing.timestamp())


logger = logging.getLogger(__name__)
|
||
|
||
def _resolve_default_account_id() -> int:
    """Resolve the default account id from the environment.

    - trading_system (multi-process): each process can pick its own account
      through ATS_ACCOUNT_ID.
    - backend: when no account_id is supplied, fall back to 1 (single-account
      compatibility).

    ATS_ACCOUNT_ID, ACCOUNT_ID and ATS_DEFAULT_ACCOUNT_ID are consulted in that
    order; blank values and values that are not integers are skipped.
    """
    env_names = ("ATS_ACCOUNT_ID", "ACCOUNT_ID", "ATS_DEFAULT_ACCOUNT_ID")
    stripped_values = ((os.getenv(name, "") or "").strip() for name in env_names)
    for candidate in filter(None, stripped_values):
        try:
            return int(candidate)
        except ValueError:
            # Not an integer: try the next variable.
            continue
    return 1


DEFAULT_ACCOUNT_ID = _resolve_default_account_id()
|
||
|
||
|
||
# (table, column) pairs already confirmed to exist during this process.
_KNOWN_COLUMNS = set()


def _table_has_column(table: str, col: str) -> bool:
    """Return True if ``table`` has a column named ``col``.

    The check runs ``SELECT <col> FROM <table> LIMIT 1``. Any error is taken
    to mean "column missing".

    Positive answers are cached for the life of the process. This means:
      * callers do not pay an extra query on every call;
      * once a column has been seen, a transient DB error cannot flip the
        answer to False. A False answer would send callers down their legacy
        branch, which does not filter by ``account_id``.
    Negative answers are not cached, so a column added later by a migration
    is still detected.

    Both names are interpolated into SQL. Anything that is not a plain
    identifier therefore returns False without touching the database.
    """
    if not (
        isinstance(table, str)
        and isinstance(col, str)
        and table.isidentifier()
        and col.isidentifier()
    ):
        return False
    key = (table, col)
    if key in _KNOWN_COLUMNS:
        return True
    try:
        db.execute_one(f"SELECT {col} FROM {table} LIMIT 1")
    except Exception:
        return False
    _KNOWN_COLUMNS.add(key)
    return True
|
||
|
||
|
||
class Account:
    """Trading account model (multi-account).

    API key/secret are meant to be stored encrypted in the ``accounts`` table
    rather than in ``trading_config``.
    """

    @staticmethod
    def get(account_id: int):
        """Return the full ``accounts`` row whose id is ``account_id``."""
        lookup = (int(account_id),)
        return db.execute_one("SELECT * FROM accounts WHERE id = %s", lookup)

    @staticmethod
    def list_all():
        """Return id, name, status and timestamps of every account, ordered by id."""
        sql = "SELECT id, name, status, created_at, updated_at FROM accounts ORDER BY id ASC"
        return db.execute_query(sql)

    @staticmethod
    def create(name: str, api_key: str = "", api_secret: str = "", use_testnet: bool = False, status: str = "active"):
        """Insert an account with encrypted credentials and return its new id."""
        # Imported lazily so this module still loads when the crypto dependency is absent.
        from security.crypto import encrypt_str

        encrypted_key = encrypt_str(api_key or "")
        encrypted_secret = encrypt_str(api_secret or "")
        row = (name, status, encrypted_key, encrypted_secret, bool(use_testnet))
        db.execute_update(
            """INSERT INTO accounts (name, status, api_key_enc, api_secret_enc, use_testnet)
               VALUES (%s, %s, %s, %s, %s)""",
            row,
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection in MySQL; confirm that db
        # runs this on the same connection as the INSERT.
        inserted = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return inserted["id"]

    @staticmethod
    def update_credentials(account_id: int, api_key: str = None, api_secret: str = None, use_testnet: bool = None):
        """Update only the credential fields that are not None; no-op if all are None."""
        from security.crypto import encrypt_str  # lazy import, see create()

        assignments = []
        if api_key is not None:
            assignments.append(("api_key_enc", encrypt_str(api_key)))
        if api_secret is not None:
            assignments.append(("api_secret_enc", encrypt_str(api_secret)))
        if use_testnet is not None:
            assignments.append(("use_testnet", bool(use_testnet)))
        if not assignments:
            return

        set_clause = ", ".join(f"{column} = %s" for column, _ in assignments)
        params = tuple(value for _, value in assignments) + (int(account_id),)
        db.execute_update(f"UPDATE accounts SET {set_clause} WHERE id = %s", params)

    @staticmethod
    def get_credentials(account_id: int):
        """Return ``(api_key, api_secret, use_testnet)`` with stored values decrypted.

        Returns ``("", "", False)`` when the account row is missing. If
        decryption cannot run (crypto module missing, no master key, or any
        other error), values stored as plaintext are returned as-is, while
        ``enc:v1:`` ciphertext becomes "" so it is never mistaken for a real
        key. Plaintext storage still works but is not secure.
        """
        row = Account.get(account_id)
        if not row:
            return "", "", False

        stored_key = row.get("api_key_enc") or ""
        stored_secret = row.get("api_secret_enc") or ""
        try:
            from security.crypto import decrypt_str

            api_key = decrypt_str(stored_key)
            api_secret = decrypt_str(stored_secret)
        except Exception:
            def _plaintext_or_blank(raw):
                text = str(raw)
                return "" if text.startswith("enc:v1:") else text

            api_key = _plaintext_or_blank(stored_key)
            api_secret = _plaintext_or_blank(stored_secret)

        use_testnet = bool(row.get("use_testnet") or False)
        return api_key, api_secret, use_testnet
|
||
|
||
|
||
class User:
    """Login users (administrators and regular users) stored in ``users``."""

    @staticmethod
    def get_by_username(username: str):
        """Fetch the full ``users`` row for ``username``."""
        lookup = (str(username),)
        return db.execute_one("SELECT * FROM users WHERE username = %s", lookup)

    @staticmethod
    def get_by_id(user_id: int):
        """Fetch the full ``users`` row for ``user_id``."""
        lookup = (int(user_id),)
        return db.execute_one("SELECT * FROM users WHERE id = %s", lookup)

    @staticmethod
    def list_all():
        """List all users (without password hashes), ordered by id."""
        sql = "SELECT id, username, role, status, created_at, updated_at FROM users ORDER BY id ASC"
        return db.execute_query(sql)

    @staticmethod
    def create(username: str, password_hash: str, role: str = "user", status: str = "active"):
        """Insert a user and return the new id."""
        row = (username, password_hash, role, status)
        db.execute_update(
            "INSERT INTO users (username, password_hash, role, status) VALUES (%s, %s, %s, %s)",
            row,
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection; confirm db reuses the INSERT's connection.
        created = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return created["id"]

    @staticmethod
    def _update_column(user_id: int, column: str, value):
        """Set one column of a user row; ``column`` is always a literal from this class."""
        db.execute_update(f"UPDATE users SET {column} = %s WHERE id = %s", (value, int(user_id)))

    @staticmethod
    def set_password(user_id: int, password_hash: str):
        """Replace the stored password hash."""
        User._update_column(user_id, "password_hash", password_hash)

    @staticmethod
    def set_status(user_id: int, status: str):
        """Change the user's status."""
        User._update_column(user_id, "status", status)

    @staticmethod
    def set_role(user_id: int, role: str):
        """Change the user's role."""
        User._update_column(user_id, "role", role)
|
||
|
||
|
||
class UserAccountMembership:
    """Authorization links between login users and trading accounts."""

    @staticmethod
    def add(user_id: int, account_id: int, role: str = "viewer"):
        """Grant ``user_id`` access to ``account_id``; an existing link gets its role overwritten."""
        link = (int(user_id), int(account_id), role)
        db.execute_update(
            """INSERT INTO user_account_memberships (user_id, account_id, role)
               VALUES (%s, %s, %s)
               ON DUPLICATE KEY UPDATE role = VALUES(role)""",
            link,
        )

    @staticmethod
    def remove(user_id: int, account_id: int):
        """Delete the link between ``user_id`` and ``account_id``."""
        pair = (int(user_id), int(account_id))
        db.execute_update(
            "DELETE FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            pair,
        )

    @staticmethod
    def list_for_user(user_id: int):
        """All membership rows of one user, ordered by account id."""
        sql = "SELECT * FROM user_account_memberships WHERE user_id = %s ORDER BY account_id ASC"
        return db.execute_query(sql, (int(user_id),))

    @staticmethod
    def list_for_account(account_id: int):
        """All membership rows of one account, ordered by user id."""
        sql = "SELECT * FROM user_account_memberships WHERE account_id = %s ORDER BY user_id ASC"
        return db.execute_query(sql, (int(account_id),))

    @staticmethod
    def has_access(user_id: int, account_id: int) -> bool:
        """True when a membership row exists for the pair."""
        pair = (int(user_id), int(account_id))
        found = db.execute_one(
            "SELECT 1 as ok FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            pair,
        )
        return True if found else False

    @staticmethod
    def get_role(user_id: int, account_id: int) -> str:
        """Role of the user on the account, or "" when there is no membership."""
        pair = (int(user_id), int(account_id))
        found = db.execute_one(
            "SELECT role FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            pair,
        )
        if not isinstance(found, dict):
            return ""
        return found.get("role") or ""
|
||
|
||
|
||
class TradingConfig:
    """Trading configuration stored as typed key/value rows in ``trading_config``.

    When the table has an ``account_id`` column, every read and write is scoped
    to one account (``DEFAULT_ACCOUNT_ID`` when none is given). Otherwise the
    legacy single-account schema is used.
    """

    @staticmethod
    def get_all(account_id: int = None):
        """Return all config rows of the account, ordered by category and key."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            return db.execute_query(
                "SELECT * FROM trading_config WHERE account_id = %s ORDER BY category, config_key",
                (aid,),
            )
        return db.execute_query("SELECT * FROM trading_config ORDER BY category, config_key")

    @staticmethod
    def get(key, account_id: int = None):
        """Return the raw row for ``key`` (value still stored as a string)."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            return db.execute_one(
                "SELECT * FROM trading_config WHERE account_id = %s AND config_key = %s",
                (aid, key),
            )
        return db.execute_one("SELECT * FROM trading_config WHERE config_key = %s", (key,))

    @staticmethod
    def set(key, value, config_type, category, description=None, account_id: int = None):
        """Insert or update (upsert) one config entry; ``value`` is serialized per ``config_type``."""
        value_str = TradingConfig._convert_to_string(value, config_type)
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            db.execute_update(
                """INSERT INTO trading_config
                   (account_id, config_key, config_value, config_type, category, description)
                   VALUES (%s, %s, %s, %s, %s, %s)
                   ON DUPLICATE KEY UPDATE
                   config_value = VALUES(config_value),
                   config_type = VALUES(config_type),
                   category = VALUES(category),
                   description = VALUES(description),
                   updated_at = CURRENT_TIMESTAMP""",
                (aid, key, value_str, config_type, category, description),
            )
        else:
            db.execute_update(
                """INSERT INTO trading_config
                   (config_key, config_value, config_type, category, description)
                   VALUES (%s, %s, %s, %s, %s)
                   ON DUPLICATE KEY UPDATE
                   config_value = VALUES(config_value),
                   config_type = VALUES(config_type),
                   category = VALUES(category),
                   description = VALUES(description),
                   updated_at = CURRENT_TIMESTAMP""",
                (key, value_str, config_type, category, description),
            )

    @staticmethod
    def get_value(key, default=None, account_id: int = None):
        """Return the typed value of ``key``, or ``default`` when no row exists."""
        result = TradingConfig.get(key, account_id=account_id)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _parse_number(value):
        """Parse a stored 'number' value into an int or a float.

        Integral text yields an int. Other text that ``float()`` accepts yields
        a float; this includes scientific notation such as '1e-05', which is
        what ``str()`` produces for small floats written by ``set()``.
        Unparseable or non-finite values yield 0.
        """
        import math  # only needed here

        if isinstance(value, bool):
            return int(value)
        if isinstance(value, (bytes, bytearray)):
            value = value.decode("utf-8", errors="replace")
        text = str(value).strip()
        try:
            return int(text)
        except ValueError:
            pass
        try:
            number = float(text)
        except ValueError:
            return 0
        return number if math.isfinite(number) else 0

    @staticmethod
    def _convert_value(value, config_type):
        """Convert a stored string to the Python type named by ``config_type``.

        'number' -> int/float (0 on failure); 'boolean' -> True for
        true/1/yes/on (case-insensitive); 'json' -> parsed object ({} on
        failure); any other type -> ``value`` unchanged.
        """
        if config_type == 'number':
            return TradingConfig._parse_number(value)
        if config_type == 'boolean':
            return str(value).lower() in ('true', '1', 'yes', 'on')
        if config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize ``value`` for storage: JSON for 'json' (non-ASCII kept), ``str()`` otherwise."""
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
|
||
|
||
|
||
class Trade:
    """Trade record model backed by the ``trades`` table.

    Optional columns (account_id, entry_order_id, notional_usdt, ...) are
    probed at runtime, so the same code works against older schemas.
    """

    @staticmethod
    def create(
        symbol,
        side,
        quantity,
        entry_price,
        leverage=10,
        entry_reason=None,
        entry_order_id=None,
        stop_loss_price=None,
        take_profit_price=None,
        take_profit_1=None,
        take_profit_2=None,
        atr=None,
        notional_usdt=None,
        margin_usdt=None,
        account_id: int = None,
    ):
        """Insert an open trade (entry_time = get_beijing_time()) and return its id.

        Args:
            symbol: trading pair
            side: direction
            quantity: quantity
            entry_price: entry price
            leverage: leverage
            entry_reason: reason for entry
            entry_order_id: Binance entry order id (optional, for reconciliation)
            stop_loss_price: stop-loss price actually used (e.g. ATR-adjusted)
            take_profit_price: take-profit price actually used (e.g. ATR-adjusted)
            take_profit_1: first take-profit target (optional)
            take_profit_2: second take-profit target (optional)
            atr: ATR value used at entry (optional)
            notional_usdt: notional order size in USDT (optional; derived as
                quantity * entry_price when omitted)
            margin_usdt: margin in USDT (optional; derived as notional / leverage
                when omitted, or equal to notional if leverage is not positive)
            account_id: owning account (defaults to DEFAULT_ACCOUNT_ID)

        Optional columns are only written when the ``trades`` table has them.
        """
        entry_time = get_beijing_time()

        # Best-effort derivation of notional/margin when the caller did not pass them.
        try:
            if notional_usdt is None and quantity is not None and entry_price is not None:
                notional_usdt = float(quantity) * float(entry_price)
        except (TypeError, ValueError, ArithmeticError):
            pass
        try:
            if margin_usdt is None and notional_usdt is not None:
                lv = float(leverage) if leverage else 0
                margin_usdt = (float(notional_usdt) / lv) if lv and lv > 0 else float(notional_usdt)
        except (TypeError, ValueError, ArithmeticError):
            pass

        # Build the INSERT dynamically so it matches whatever columns the schema has.
        columns = ["symbol", "side", "quantity", "entry_price", "leverage", "entry_reason", "status", "entry_time"]
        values = [symbol, side, quantity, entry_price, leverage, entry_reason, "open", entry_time]

        if _table_has_column("trades", "account_id"):
            columns.insert(0, "account_id")
            values.insert(0, int(account_id or DEFAULT_ACCOUNT_ID))

        optional_columns = (
            ("entry_order_id", entry_order_id),
            ("notional_usdt", notional_usdt),
            ("margin_usdt", margin_usdt),
            ("atr", atr),
            ("stop_loss_price", stop_loss_price),
            ("take_profit_price", take_profit_price),
            ("take_profit_1", take_profit_1),
            ("take_profit_2", take_profit_2),
        )
        for column, value in optional_columns:
            if _table_has_column("trades", column):
                columns.append(column)
                values.append(value)

        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO trades ({', '.join(columns)}) VALUES ({placeholders})"
        db.execute_update(sql, tuple(values))
        # NOTE(review): LAST_INSERT_ID() is per-connection; confirm db reuses the INSERT's connection.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def update_exit(
        trade_id,
        exit_price,
        exit_reason,
        pnl,
        pnl_percent,
        exit_order_id=None,
        strategy_type=None,
        duration_minutes=None,
        exit_time_ts=None,
    ):
        """Mark a trade as closed and record its exit data.

        Args:
            trade_id: trade record id
            exit_price: exit price
            exit_reason: reason for closing
            pnl: profit and loss
            pnl_percent: profit and loss in percent
            exit_order_id: Binance exit order id (optional, for reconciliation)
            strategy_type: written only when not None
            duration_minutes: written only when not None
            exit_time_ts: actual fill time in Unix seconds. Falls back to
                get_beijing_time() when missing or not convertible to int.

        exit_order_id is skipped (the other fields are still written) when:
          * it already belongs to a different trade record; or
          * the UPDATE fails with a "Duplicate entry" error that mentions
            exit_order_id.
        This avoids violating the unique constraint. Any other database error
        is re-raised.
        """
        try:
            exit_time = int(exit_time_ts) if exit_time_ts is not None else get_beijing_time()
        except (TypeError, ValueError, OverflowError):
            exit_time = get_beijing_time()

        def _write(include_order_id):
            # Single UPDATE builder shared by all paths below.
            fields = [
                "exit_price = %s", "exit_time = %s",
                "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'",
            ]
            params = [exit_price, exit_time, exit_reason, pnl, pnl_percent]
            if include_order_id:
                fields.append("exit_order_id = %s")
                params.append(exit_order_id)
            if strategy_type is not None:
                fields.append("strategy_type = %s")
                params.append(strategy_type)
            if duration_minutes is not None:
                fields.append("duration_minutes = %s")
                params.append(duration_minutes)
            params.append(trade_id)
            db.execute_update(
                f"UPDATE trades SET {', '.join(fields)} WHERE id = %s",
                tuple(params),
            )

        # If an exit_order_id was supplied, first check whether another trade already uses it.
        if exit_order_id is not None:
            try:
                existing_trade = Trade.get_by_exit_order_id(exit_order_id)
                if existing_trade:
                    owner_id = existing_trade['id']
                    # Compare as strings too, so trade_id passed as "12" matches id 12.
                    if owner_id == trade_id or str(owner_id) == str(trade_id):
                        # Same trade: keep going so exit_reason / exit_time / duration can be backfilled.
                        logger.debug(
                            f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已存在,将继续更新其他字段"
                        )
                    else:
                        logger.warning(
                            f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用,"
                            f"跳过更新 exit_order_id,只更新其他字段"
                        )
                        _write(include_order_id=False)
                        return
            except Exception as e:
                # Lookup (or the update above) failed: log it and fall through to the normal update.
                logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新")

        # Normal update, including exit_order_id (written as NULL when it is None).
        try:
            _write(include_order_id=True)
        except Exception as e:
            error_str = str(e)
            if "Duplicate entry" in error_str and "exit_order_id" in error_str:
                logger.warning(
                    f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突,"
                    f"尝试不更新 exit_order_id"
                )
                _write(include_order_id=False)
            else:
                raise

    @staticmethod
    def get_by_entry_order_id(entry_order_id):
        """Return the trade row with the given entry (open) order id."""
        return db.execute_one(
            "SELECT * FROM trades WHERE entry_order_id = %s",
            (entry_order_id,)
        )

    @staticmethod
    def get_by_exit_order_id(exit_order_id):
        """Return the trade row with the given exit (close) order id."""
        return db.execute_one(
            "SELECT * FROM trades WHERE exit_order_id = %s",
            (exit_order_id,)
        )

    @staticmethod
    def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None, trade_type=None, exit_reason=None, account_id: int = None):
        """Return trade records matching the given filters, newest id first.

        Args:
            start_timestamp: lower bound (inclusive) on ``created_at`` (optional)
            end_timestamp: upper bound (inclusive) on ``created_at`` (optional)
            symbol: trading pair (optional)
            status: status (optional)
            trade_type: matched against ``side`` (optional)
            exit_reason: exit reason (optional)
            account_id: account to scope to when the schema has ``account_id``

        NOTE(review): the timestamps are described as Unix seconds, but they
        are compared with ``created_at``, while this module writes entry/exit
        times as Unix seconds. Confirm that ``created_at``'s type matches.

        Raises:
            ValueError: if ``account_id`` is not an integer. Previously this
            error was swallowed and the account filter silently dropped, which
            returned every account's trades.
        """
        query = "SELECT * FROM trades WHERE 1=1"
        params = []

        # Multi-account isolation (the legacy schema has no account_id column).
        if _table_has_column("trades", "account_id"):
            query += " AND account_id = %s"
            params.append(int(account_id or DEFAULT_ACCOUNT_ID))

        filters = (
            (" AND created_at >= %s", start_timestamp, start_timestamp is not None),
            (" AND created_at <= %s", end_timestamp, end_timestamp is not None),
            (" AND symbol = %s", symbol, bool(symbol)),
            (" AND status = %s", status, bool(status)),
            (" AND side = %s", trade_type, bool(trade_type)),
            (" AND exit_reason = %s", exit_reason, bool(exit_reason)),
        )
        for clause, value, active in filters:
            if active:
                query += clause
                params.append(value)

        # Newest records first.
        query += " ORDER BY id DESC"
        logger.info("查询交易记录: %s, %s", query, params)
        result = db.execute_query(query, params)
        return result

    @staticmethod
    def get_by_symbol(symbol, status='open', account_id: int = None):
        """Return the account's trades for ``symbol`` with the given status (default 'open')."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trades", "account_id"):
            return db.execute_query(
                "SELECT * FROM trades WHERE account_id = %s AND symbol = %s AND status = %s",
                (aid, symbol, status),
            )
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            (symbol, status),
        )
|
||
|
||
|
||
class AccountSnapshot:
    """Periodic account balance snapshots (``account_snapshots`` table)."""

    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions, account_id: int = None):
        """Insert a snapshot stamped with ``get_beijing_time()`` (integer Unix seconds)."""
        snapshot_time = get_beijing_time()
        if _table_has_column("account_snapshots", "account_id"):
            db.execute_update(
                """INSERT INTO account_snapshots
                   (account_id, total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
                   VALUES (%s, %s, %s, %s, %s, %s, %s)""",
                (int(account_id or DEFAULT_ACCOUNT_ID), total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time),
            )
        else:
            db.execute_update(
                """INSERT INTO account_snapshots
                   (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time),
            )

    @staticmethod
    def get_recent(days=7, account_id: int = None):
        """Return snapshots from the last ``days`` days, newest first.

        ``create()`` stores ``snapshot_time`` as integer Unix seconds, so the
        cutoff is computed in the same unit. The previous query compared the
        column against ``DATE_SUB(NOW(), INTERVAL n DAY)``, a DATETIME, which
        does not match values written in Unix seconds.
        NOTE(review): assumes snapshot_time is a numeric column, as implied by
        create(); confirm against the schema.
        """
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        cutoff = get_beijing_time() - int(float(days) * 86400)
        if _table_has_column("account_snapshots", "account_id"):
            return db.execute_query(
                """SELECT * FROM account_snapshots
                   WHERE account_id = %s AND snapshot_time >= %s
                   ORDER BY snapshot_time DESC""",
                (aid, cutoff),
            )
        return db.execute_query(
            """SELECT * FROM account_snapshots
               WHERE snapshot_time >= %s
               ORDER BY snapshot_time DESC""",
            (cutoff,),
        )
|
||
|
||
|
||
class MarketScan:
    """Market scan history (``market_scans`` table)."""

    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Record a scan; ``top_symbols`` is stored JSON-encoded, stamped with get_beijing_time()."""
        scanned_at = get_beijing_time()
        encoded_top = json.dumps(top_symbols)
        row = (symbols_scanned, symbols_found, encoded_top, scan_duration, scanned_at)
        db.execute_update(
            """INSERT INTO market_scans
               (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)
               VALUES (%s, %s, %s, %s, %s)""",
            row,
        )

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scans, most recent first."""
        sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
|
||
|
||
|
||
class TradingSignal:
    """Generated trading signals (``trading_signals`` table)."""

    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Record a signal stamped with get_beijing_time(); returns nothing."""
        signalled_at = get_beijing_time()
        row = (
            symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime, signalled_at,
        )
        db.execute_update(
            """INSERT INTO trading_signals
               (symbol, signal_direction, signal_strength, signal_reason,
                rsi, macd_histogram, market_regime, signal_time)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
            row,
        )

    @staticmethod
    def mark_executed(signal_id):
        """Flag the signal as executed."""
        sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(sql, (signal_id,))

    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, most recent first."""
        sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
|
||
|
||
|
||
class TradeRecommendation:
    """Recommended trading pairs (``trade_recommendations`` table)."""

    @staticmethod
    def create(
        symbol, direction, current_price, change_percent, recommendation_reason,
        signal_strength, market_regime=None, trend_4h=None,
        rsi=None, macd_histogram=None,
        bollinger_upper=None, bollinger_middle=None, bollinger_lower=None,
        ema20=None, ema50=None, ema20_4h=None, atr=None,
        suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None,
        suggested_position_percent=None, suggested_leverage=10,
        order_type='LIMIT', suggested_limit_price=None,
        volume_24h=None, volatility=None, notes=None,
        user_guide=None, recommendation_category=None, risk_level=None,
        expected_hold_time=None, trading_tutorial=None, max_hold_days=3
    ):
        """Insert a recommendation and return its id.

        ``recommendation_time`` is ``get_beijing_time()`` (integer Unix
        seconds). ``expires_at`` is 24 hours later, in the same unit; the
        previous code added a ``timedelta`` to that int, which always raised
        TypeError.

        Columns are chosen to match the schema:
          * ``order_type`` / ``suggested_limit_price`` are written when the
            table has ``order_type`` or ``user_guide``;
          * the user-guide columns are written only when it has ``user_guide``.

        NOTE(review): assumes expires_at is a numeric column like
        recommendation_time; confirm against the schema.
        """
        recommendation_time = get_beijing_time()
        # Expires 24 hours after creation (same unit as recommendation_time).
        expires_at = recommendation_time + int(timedelta(hours=24).total_seconds())

        # Probe optional columns (compatibility with older schemas).
        has_order_fields = _table_has_column("trade_recommendations", "order_type")
        has_user_guide_fields = _table_has_column("trade_recommendations", "user_guide")

        fields = [
            ("symbol", symbol),
            ("direction", direction),
            ("recommendation_time", recommendation_time),
            ("current_price", current_price),
            ("change_percent", change_percent),
            ("recommendation_reason", recommendation_reason),
            ("signal_strength", signal_strength),
            ("market_regime", market_regime),
            ("trend_4h", trend_4h),
            ("rsi", rsi),
            ("macd_histogram", macd_histogram),
            ("bollinger_upper", bollinger_upper),
            ("bollinger_middle", bollinger_middle),
            ("bollinger_lower", bollinger_lower),
            ("ema20", ema20),
            ("ema50", ema50),
            ("ema20_4h", ema20_4h),
            ("atr", atr),
            ("suggested_stop_loss", suggested_stop_loss),
            ("suggested_take_profit_1", suggested_take_profit_1),
            ("suggested_take_profit_2", suggested_take_profit_2),
            ("suggested_position_percent", suggested_position_percent),
            ("suggested_leverage", suggested_leverage),
        ]
        if has_user_guide_fields or has_order_fields:
            fields += [
                ("order_type", order_type),
                ("suggested_limit_price", suggested_limit_price),
            ]
        fields += [
            ("volume_24h", volume_24h),
            ("volatility", volatility),
            ("expires_at", expires_at),
            ("notes", notes),
        ]
        if has_user_guide_fields:
            fields += [
                ("user_guide", user_guide),
                ("recommendation_category", recommendation_category),
                ("risk_level", risk_level),
                ("expected_hold_time", expected_hold_time),
                ("trading_tutorial", trading_tutorial),
                ("max_hold_days", max_hold_days),
            ]

        # Column list and placeholder count come from the same list, so they always agree.
        column_sql = ", ".join(name for name, _ in fields)
        placeholders = ", ".join(["%s"] * len(fields))
        db.execute_update(
            f"INSERT INTO trade_recommendations ({column_sql}) VALUES ({placeholders})",
            tuple(value for _, value in fields),
        )
        # NOTE(review): LAST_INSERT_ID() is per-connection; confirm db reuses the INSERT's connection.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
        """Mark a recommendation executed, recording time, result and resulting trade id."""
        executed_at = get_beijing_time()
        db.execute_update(
            """UPDATE trade_recommendations
               SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
               WHERE id = %s""",
            (executed_at, execution_result, trade_id, recommendation_id)
        )

    @staticmethod
    def mark_expired(recommendation_id):
        """Mark a recommendation as expired."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def mark_cancelled(recommendation_id, notes=None):
        """Mark a recommendation as cancelled; ``notes`` overwrites the notes column."""
        db.execute_update(
            "UPDATE trade_recommendations SET status = 'cancelled', notes = %s WHERE id = %s",
            (notes, recommendation_id)
        )

    @staticmethod
    def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
        """Return recommendations matching the filters, newest then strongest first.

        ``start_date`` / ``end_date`` bound ``recommendation_time`` (inclusive).
        """
        query = "SELECT * FROM trade_recommendations WHERE 1=1"
        params = []

        if status:
            query += " AND status = %s"
            params.append(status)
        if direction:
            query += " AND direction = %s"
            params.append(direction)
        if start_date:
            query += " AND recommendation_time >= %s"
            params.append(start_date)
        if end_date:
            query += " AND recommendation_time <= %s"
            params.append(end_date)

        query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
        params.append(limit)

        return db.execute_query(query, params)

    @staticmethod
    def get_active():
        """Return currently valid recommendations (status 'active', not expired).

        Only the latest recommendation per symbol is returned (deduplicated).
        NOTE(review): expires_at is compared with NOW(). create() writes
        expires_at as Unix seconds, so if the column is numeric this should
        probably compare against UNIX_TIMESTAMP(); confirm the schema.
        """
        return db.execute_query(
            """SELECT t1.* FROM trade_recommendations t1
               INNER JOIN (
                   SELECT symbol, MAX(recommendation_time) as max_time
                   FROM trade_recommendations
                   WHERE status = 'active' AND (expires_at IS NULL OR expires_at > NOW())
                   GROUP BY symbol
               ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time
               WHERE t1.status = 'active' AND (t1.expires_at IS NULL OR t1.expires_at > NOW())
               ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC"""
        )

    @staticmethod
    def get_by_id(recommendation_id):
        """Return the recommendation with the given id."""
        return db.execute_one(
            "SELECT * FROM trade_recommendations WHERE id = %s",
            (recommendation_id,)
        )

    @staticmethod
    def get_by_symbol(symbol, limit=10):
        """Return up to ``limit`` recommendations for ``symbol``, newest first."""
        return db.execute_query(
            """SELECT * FROM trade_recommendations
               WHERE symbol = %s
               ORDER BY recommendation_time DESC LIMIT %s""",
            (symbol, limit)
        )
|