diff --git a/MULTI_USER_ARCHITECTURE.md b/MULTI_USER_ARCHITECTURE.md index a536fb2..7ef9a51 100644 --- a/MULTI_USER_ARCHITECTURE.md +++ b/MULTI_USER_ARCHITECTURE.md @@ -138,6 +138,45 @@ --- +## 面向“更多用户”的演进战略准备(从现在就要做对的几件事) + +你问“每账号一进程是不是终极方案”。我的建议是把它当作**长期默认架构**,并提前把“可演进”埋点做对,这样未来扩容不会推倒重来。 + +### 1)固定边界:所有“私有数据”必须天然带 account_id + +- **数据库**:`trades / trading_config / account_snapshots / positions(若有)` 都必须有 `account_id`,且所有查询默认按 `account_id` 过滤。 +- **Redis**:账号私有数据统一命名空间: + - `ats:cfg:{account_id}` 或 `trading_config:{account_id}`(一账号一份配置 hash) + - `ats:positions:{account_id}`、`ats:orders:pending:{account_id}` 等 +- **API**:所有与交易/配置/统计相关的接口都要支持 `account_id`(Header 或 Path),哪怕当前只有一个账号。 + +这一步一旦做对,未来从“多进程”演进到“多 worker/分布式”几乎不改数据层。 + +### 2)把“共享层”单独做成服务:推荐/行情永远不绑定账号 + +- 推荐:一份全局 snapshot(你已拆成独立推荐进程/服务),后面可水平扩容但要有锁。 +- 行情:建议尽早演进为全局 MarketDataService(单实例拉取 + Redis 分发),账号 worker 只消费缓存。 + +这一步是从 10~30 账号走向 100+ 账号的关键,否则会先撞 Binance IP 限频。 + +### 3)演进路线(从易到难,逐步替换,不做“重写”) + +1. **阶段A(现在)**:每账号一个进程(Supervisor) + - 最稳、隔离最好、上线快 +2. **阶段B(账号增多)**:引入“控制器 + worker”但仍可单机 + - 控制器负责:调度、限频预算、健康检查、任务重启 + - worker 负责:每账号决策/下单/同步(可仍按进程隔离) +3. **阶段C(规模更大)**:队列化/分布式(K8s/多机) + - 账号按 `account_id` 分片到不同节点(sharding) + - 共享服务(行情/推荐)做成单独部署,或按区域分片 + +### 4)安全策略提前统一:API Key/Secret 必须与“普通配置”分离 + +- 强烈建议:API Key/Secret 存 `accounts` 表,**加密存储**(服务端 master key 解密),前端永不回传 secret 明文。 +- 交易进程只拿到自己账号的解密结果(进程隔离的优势)。 + +--- + ## 风险提示与建议 - **安全**:API Key 必须加密存储;前端永远不返回明文 secret。 diff --git a/backend/api/auth_deps.py b/backend/api/auth_deps.py new file mode 100644 index 0000000..ce73947 --- /dev/null +++ b/backend/api/auth_deps.py @@ -0,0 +1,72 @@ +""" +FastAPI 依赖:解析 JWT、获取当前用户、校验 admin、校验 account_id 访问权 +""" + +from __future__ import annotations + +from fastapi import Header, HTTPException, Depends +from typing import Optional, Dict, Any +import os + +from api.auth_utils import jwt_decode +from database.models import User, UserAccountMembership + + +def _auth_enabled() -> bool: + v = (os.getenv("ATS_AUTH_ENABLED") or "true").strip().lower() + return v not in {"0", "false", "no"} + + +def get_current_user(authorization: Optional[str] = Header(None, alias="Authorization")) -> Dict[str, Any]: + if not _auth_enabled(): + # 未启用登录:视为超级管理员(兼容开发/灰度) + return {"id": 0, "username": "dev", "role": "admin", "status": "active"} + + if not authorization or not authorization.lower().startswith("bearer "): + raise HTTPException(status_code=401, detail="未登录") + token = authorization.split(" ", 1)[1].strip() + try: + payload = jwt_decode(token) + except Exception: + raise HTTPException(status_code=401, detail="登录已失效") + + sub = payload.get("sub") + try: + uid = int(sub) + except Exception: + raise HTTPException(status_code=401, detail="登录已失效") + + u = User.get_by_id(uid) + if not u: + raise HTTPException(status_code=401, detail="登录已失效") + if (u.get("status") or "active") != "active": + raise HTTPException(status_code=403, detail="用户已被禁用") + return {"id": int(u["id"]), "username": u.get("username") or "", "role": u.get("role") or "user", "status": u.get("status") or "active"} + + +def require_admin(user: Dict[str, Any]) -> Dict[str, Any]: + if (user.get("role") or "user") != "admin": + raise HTTPException(status_code=403, detail="需要管理员权限") + return user + + +def require_account_access(account_id: int, user: Dict[str, Any]) -> int: + aid = int(account_id or 1) + if (user.get("role") or "user") == "admin": + return aid + if UserAccountMembership.has_access(int(user["id"]), aid): + return aid + raise HTTPException(status_code=403, detail="无权访问该账号") + + +def get_admin_user(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]: + return require_admin(user) + + +def get_account_id( + x_account_id: Optional[int] = Header(None, alias="X-Account-Id"), + user: Dict[str, Any] = Depends(get_current_user), +) -> int: + aid = int(x_account_id or 1) + return require_account_access(aid, user) + diff --git a/backend/api/auth_utils.py b/backend/api/auth_utils.py new file mode 100644 index 0000000..63d6ace --- /dev/null +++ b/backend/api/auth_utils.py @@ -0,0 +1,75 @@ +""" +登录鉴权工具(JWT + 密码哈希) + +设计目标: +- 最小依赖:密码哈希用 pbkdf2_hmac(标准库) +- JWT 使用 python-jose(已加入 requirements) +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import os +import time +from typing import Any, Dict, Optional + +from jose import jwt # type: ignore + + +def _jwt_secret() -> str: + s = (os.getenv("ATS_JWT_SECRET") or os.getenv("JWT_SECRET") or "").strip() + if s: + return s + # 允许开发环境兜底,但线上务必配置 + return "dev-secret-change-me" + + +def jwt_encode(payload: Dict[str, Any], exp_sec: int = 3600) -> str: + now = int(time.time()) + body = dict(payload or {}) + body["iat"] = now + body["exp"] = now + int(exp_sec) + return jwt.encode(body, _jwt_secret(), algorithm="HS256") + + +def jwt_decode(token: str) -> Dict[str, Any]: + return jwt.decode(token, _jwt_secret(), algorithms=["HS256"]) + + +def _b64(b: bytes) -> str: + return base64.urlsafe_b64encode(b).decode("utf-8").rstrip("=") + + +def _b64d(s: str) -> bytes: + s = (s or "").strip() + s = s + ("=" * (-len(s) % 4)) + return base64.urlsafe_b64decode(s.encode("utf-8")) + + +def hash_password(password: str, iterations: int = 260_000) -> str: + """ + PBKDF2-SHA256:返回格式 + pbkdf2_sha256$$$ + """ + pw = (password or "").encode("utf-8") + salt = os.urandom(16) + dk = hashlib.pbkdf2_hmac("sha256", pw, salt, int(iterations)) + return f"pbkdf2_sha256${int(iterations)}${_b64(salt)}${_b64(dk)}" + + +def verify_password(password: str, password_hash: str) -> bool: + try: + s = str(password_hash or "") + if not s.startswith("pbkdf2_sha256$"): + return False + _, it_s, salt_b64, dk_b64 = s.split("$", 3) + it = int(it_s) + salt = _b64d(salt_b64) + dk0 = _b64d(dk_b64) + dk1 = hashlib.pbkdf2_hmac("sha256", (password or "").encode("utf-8"), salt, it) + return hmac.compare_digest(dk0, dk1) + except Exception: + return False + diff --git a/backend/api/main.py b/backend/api/main.py index 41db2a7..9ade67e 100644 --- a/backend/api/main.py +++ b/backend/api/main.py @@ -3,7 +3,7 @@ FastAPI应用主入口 """ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from api.routes import config, trades, stats, dashboard, account, recommendations, system +from api.routes import config, trades, stats, dashboard, account, recommendations, system, accounts, auth, admin import os import logging from pathlib import Path @@ -165,6 +165,42 @@ app = FastAPI( redirect_slashes=False # 禁用自动重定向,避免307重定向问题 ) +# 启动时:确保存在一个初始管理员(通过环境变量配置) +@app.on_event("startup") +async def _ensure_initial_admin(): + try: + import os + from database.models import User, UserAccountMembership + from api.auth_utils import hash_password + + username = (os.getenv("ATS_ADMIN_USERNAME") or "admin").strip() + password = (os.getenv("ATS_ADMIN_PASSWORD") or "").strip() + if not password: + # 不强制创建,避免你忘记改默认密码导致安全风险 + # 你可以设置 ATS_ADMIN_PASSWORD 后重启后端自动创建 + logger.warning("未设置 ATS_ADMIN_PASSWORD,跳过自动创建初始管理员") + return + + u = User.get_by_username(username) + if not u: + uid = User.create(username=username, password_hash=hash_password(password), role="admin", status="active") + # 默认给管理员绑定 account_id=1(default) + try: + UserAccountMembership.add(int(uid), 1, role="owner") + except Exception: + pass + logger.info(f"✓ 已创建初始管理员用户: {username} (id={uid})") + else: + # 若已存在但不是 admin,则提升为 admin(可注释掉更保守) + if (u.get("role") or "user") != "admin": + try: + User.set_role(int(u["id"]), "admin") + logger.warning(f"已将用户 {username} 提升为 admin") + except Exception: + pass + except Exception as e: + logger.warning(f"初始化管理员失败(可忽略): {e}") + # CORS配置(允许React前端访问) # 默认包含:本地开发端口、主前端域名、推荐查看器域名 cors_origins_str = os.getenv('CORS_ORIGINS', 'http://localhost:3000,http://localhost:3001,http://localhost:5173,http://as.deepx1.com,http://asapi.deepx1.com,http://r.deepx1.com,https://r.deepx1.com') @@ -183,6 +219,9 @@ app.add_middleware( # 注册路由 app.include_router(config.router, prefix="/api/config", tags=["配置管理"]) +app.include_router(auth.router, tags=["auth"]) +app.include_router(admin.router) +app.include_router(accounts.router, prefix="/api/accounts", tags=["账号管理"]) app.include_router(trades.router, prefix="/api/trades", tags=["交易记录"]) app.include_router(stats.router, prefix="/api/stats", tags=["统计分析"]) app.include_router(dashboard.router, prefix="/api/dashboard", tags=["仪表板"]) diff --git a/backend/api/routes/account.py b/backend/api/routes/account.py index abadcd4..50d28b6 100644 --- a/backend/api/routes/account.py +++ b/backend/api/routes/account.py @@ -1,7 +1,7 @@ """ 账户实时数据API - 从币安API获取实时账户和订单数据 """ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Header, Depends from fastapi import Query import sys from pathlib import Path @@ -13,21 +13,20 @@ sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root / 'backend')) sys.path.insert(0, str(project_root / 'trading_system')) -from database.models import TradingConfig +from database.models import TradingConfig, Account +from api.auth_deps import get_account_id logger = logging.getLogger(__name__) router = APIRouter() -async def _ensure_exchange_sltp_for_symbol(symbol: str): +async def _ensure_exchange_sltp_for_symbol(symbol: str, account_id: int = 1): """ 在币安侧补挂该 symbol 的止损/止盈保护单(STOP_MARKET + TAKE_PROFIT_MARKET)。 该接口用于“手动补挂”,不依赖 trading_system 的监控任务。 """ - # 从数据库读取API密钥 - api_key = TradingConfig.get_value('BINANCE_API_KEY') - api_secret = TradingConfig.get_value('BINANCE_API_SECRET') - use_testnet = TradingConfig.get_value('USE_TESTNET', False) + # 从 accounts 表读取账号私有API密钥 + api_key, api_secret, use_testnet = Account.get_credentials(int(account_id or 1)) if not api_key or not api_secret: raise HTTPException(status_code=400, detail="API密钥未配置") @@ -262,12 +261,12 @@ async def _ensure_exchange_sltp_for_symbol(symbol: str): @router.post("/positions/{symbol}/sltp/ensure") -async def ensure_position_sltp(symbol: str): +async def ensure_position_sltp(symbol: str, account_id: int = Depends(get_account_id)): """ 手动补挂该 symbol 的止盈止损保护单(币安侧可见)。 """ try: - return await _ensure_exchange_sltp_for_symbol(symbol) + return await _ensure_exchange_sltp_for_symbol(symbol, account_id=int(account_id)) except HTTPException: raise except Exception as e: @@ -277,14 +276,15 @@ async def ensure_position_sltp(symbol: str): @router.post("/positions/sltp/ensure-all") -async def ensure_all_positions_sltp(limit: int = Query(50, ge=1, le=200, description="最多处理多少个持仓symbol")): +async def ensure_all_positions_sltp( + limit: int = Query(50, ge=1, le=200, description="最多处理多少个持仓symbol"), + account_id: int = Depends(get_account_id), +): """ 批量补挂当前所有持仓的止盈止损保护单。 """ # 先拿当前持仓symbol列表 - api_key = TradingConfig.get_value('BINANCE_API_KEY') - api_secret = TradingConfig.get_value('BINANCE_API_SECRET') - use_testnet = TradingConfig.get_value('USE_TESTNET', False) + api_key, api_secret, use_testnet = Account.get_credentials(account_id) if not api_key or not api_secret: raise HTTPException(status_code=400, detail="API密钥未配置") @@ -308,7 +308,7 @@ async def ensure_all_positions_sltp(limit: int = Query(50, ge=1, le=200, descrip errors = [] for sym in symbols: try: - res = await _ensure_exchange_sltp_for_symbol(sym) + res = await _ensure_exchange_sltp_for_symbol(sym, account_id=account_id) results.append( { "symbol": sym, @@ -339,37 +339,33 @@ async def ensure_all_positions_sltp(limit: int = Query(50, ge=1, le=200, descrip } -async def get_realtime_account_data(): +async def get_realtime_account_data(account_id: int = 1): """从币安API实时获取账户数据""" logger.info("=" * 60) logger.info("开始获取实时账户数据") logger.info("=" * 60) try: - # 从数据库读取API密钥 - logger.info("步骤1: 从数据库读取API配置...") - api_key = TradingConfig.get_value('BINANCE_API_KEY') - api_secret = TradingConfig.get_value('BINANCE_API_SECRET') - use_testnet = TradingConfig.get_value('USE_TESTNET', False) + # 从 accounts 表读取账号私有API密钥 + logger.info(f"步骤1: 从accounts读取API配置... (account_id={account_id})") + api_key, api_secret, use_testnet = Account.get_credentials(account_id) logger.info(f" - API密钥存在: {bool(api_key)}") if api_key: logger.info(f" - API密钥长度: {len(api_key)} 字符") - logger.info(f" - API密钥前缀: {api_key[:10]}...") else: logger.warning(" - API密钥为空!") - - logger.info(f" - API密钥存在: {bool(api_secret)}") + + logger.info(f" - API密钥Secret存在: {bool(api_secret)}") if api_secret: - logger.info(f" - API密钥长度: {len(api_secret)} 字符") - logger.info(f" - API密钥前缀: {api_secret[:10]}...") + logger.info(f" - API密钥Secret长度: {len(api_secret)} 字符") else: - logger.warning(" - API密钥为空!") + logger.warning(" - API密钥Secret为空!") logger.info(f" - 使用测试网: {use_testnet}") if not api_key or not api_secret: - error_msg = "API密钥未配置,请在配置界面设置BINANCE_API_KEY和BINANCE_API_SECRET" + error_msg = "API密钥未配置,请在配置界面设置该账号的BINANCE_API_KEY和BINANCE_API_SECRET" logger.error(f" ✗ {error_msg}") raise HTTPException( status_code=400, @@ -555,20 +551,17 @@ async def get_realtime_account_data(): @router.get("/realtime") -async def get_realtime_account(): +async def get_realtime_account(account_id: int = Depends(get_account_id)): """获取实时账户数据""" - return await get_realtime_account_data() + return await get_realtime_account_data(account_id=account_id) @router.get("/positions") -async def get_realtime_positions(): +async def get_realtime_positions(account_id: int = Depends(get_account_id)): """获取实时持仓数据""" client = None try: - # 从数据库读取API密钥 - api_key = TradingConfig.get_value('BINANCE_API_KEY') - api_secret = TradingConfig.get_value('BINANCE_API_SECRET') - use_testnet = TradingConfig.get_value('USE_TESTNET', False) + api_key, api_secret, use_testnet = Account.get_credentials(account_id) logger.info(f"尝试获取实时持仓数据 (testnet={use_testnet})") @@ -734,17 +727,14 @@ async def get_realtime_positions(): @router.post("/positions/{symbol}/close") -async def close_position(symbol: str): +async def close_position(symbol: str, account_id: int = Depends(get_account_id)): """手动平仓指定交易对的持仓""" try: logger.info(f"=" * 60) logger.info(f"收到平仓请求: {symbol}") logger.info(f"=" * 60) - # 从数据库读取API密钥 - api_key = TradingConfig.get_value('BINANCE_API_KEY') - api_secret = TradingConfig.get_value('BINANCE_API_SECRET') - use_testnet = TradingConfig.get_value('USE_TESTNET', False) + api_key, api_secret, use_testnet = Account.get_credentials(account_id) if not api_key or not api_secret: error_msg = "API密钥未配置" @@ -981,7 +971,7 @@ async def close_position(symbol: str): fallback_exit_price = None # 更新数据库记录 - open_trades = Trade.get_by_symbol(symbol, status='open') + open_trades = Trade.get_by_symbol(symbol, status='open', account_id=account_id) if open_trades: # 对冲模式可能有多条 trade(BUY/LONG 和 SELL/SHORT),尽量按方向匹配订单更新 used_order_ids = set() @@ -1048,17 +1038,14 @@ async def close_position(symbol: str): @router.post("/positions/sync") -async def sync_positions(): +async def sync_positions(account_id: int = Depends(get_account_id)): """同步币安实际持仓状态与数据库状态""" try: logger.info("=" * 60) logger.info("收到持仓状态同步请求") logger.info("=" * 60) - # 从数据库读取API密钥 - api_key = TradingConfig.get_value('BINANCE_API_KEY') - api_secret = TradingConfig.get_value('BINANCE_API_SECRET') - use_testnet = TradingConfig.get_value('USE_TESTNET', False) + api_key, api_secret, use_testnet = Account.get_credentials(account_id) if not api_key or not api_secret: error_msg = "API密钥未配置" @@ -1077,11 +1064,7 @@ async def sync_positions(): from database.models import Trade # 创建客户端 - client = BinanceClient( - api_key=api_key, - api_secret=api_secret, - testnet=use_testnet - ) + client = BinanceClient(api_key=api_key, api_secret=api_secret, testnet=use_testnet) logger.info("连接币安API...") await client.connect() @@ -1095,7 +1078,7 @@ async def sync_positions(): logger.info(f" 持仓列表: {', '.join(binance_symbols)}") # 2. 获取数据库中状态为open的交易记录 - db_open_trades = Trade.get_all(status='open') + db_open_trades = Trade.get_all(status='open', account_id=account_id) db_open_symbols = {t['symbol'] for t in db_open_trades} logger.info(f"数据库open状态: {len(db_open_symbols)} 个") if db_open_symbols: diff --git a/backend/api/routes/accounts.py b/backend/api/routes/accounts.py new file mode 100644 index 0000000..db80099 --- /dev/null +++ b/backend/api/routes/accounts.py @@ -0,0 +1,162 @@ +""" +账号管理 API(多账号) + +说明: +- 这是“多账号第一步”的管理入口:创建/禁用/更新密钥 +- 交易/配置/统计接口通过 X-Account-Id 头来选择账号(默认 1) +""" + +from fastapi import APIRouter, HTTPException, Depends +from pydantic import BaseModel, Field +from typing import Optional, List, Dict, Any +import logging + +from database.models import Account, UserAccountMembership +from api.auth_deps import get_current_user, get_admin_user + +logger = logging.getLogger(__name__) +router = APIRouter() + + +class AccountCreate(BaseModel): + name: str = Field(..., min_length=1, max_length=100) + api_key: Optional[str] = "" + api_secret: Optional[str] = "" + use_testnet: bool = False + status: str = Field("active", pattern="^(active|disabled)$") + + +class AccountUpdate(BaseModel): + name: Optional[str] = Field(None, min_length=1, max_length=100) + status: Optional[str] = Field(None, pattern="^(active|disabled)$") + use_testnet: Optional[bool] = None + + +class AccountCredentialsUpdate(BaseModel): + api_key: Optional[str] = None + api_secret: Optional[str] = None + use_testnet: Optional[bool] = None + + +def _mask(s: str) -> str: + s = "" if s is None else str(s) + if not s: + return "" + if len(s) <= 8: + return "****" + return f"{s[:4]}...{s[-4:]}" + + +@router.get("") +@router.get("/") +async def list_accounts(user: Dict[str, Any] = Depends(get_current_user)) -> List[Dict[str, Any]]: + try: + is_admin = (user.get("role") or "user") == "admin" + + out: List[Dict[str, Any]] = [] + if is_admin: + rows = Account.list_all() + for r in rows or []: + aid = int(r.get("id")) + api_key, api_secret, use_testnet = Account.get_credentials(aid) + out.append( + { + "id": aid, + "name": r.get("name") or "", + "status": r.get("status") or "active", + "use_testnet": bool(use_testnet), + "has_api_key": bool(api_key), + "has_api_secret": bool(api_secret), + "api_key_masked": _mask(api_key), + } + ) + return out + + memberships = UserAccountMembership.list_for_user(int(user["id"])) + account_ids = [int(m.get("account_id")) for m in (memberships or []) if m.get("account_id") is not None] + for aid in account_ids: + r = Account.get(int(aid)) + if not r: + continue + # 普通用户:不返回密钥相关字段 + _, _, use_testnet = Account.get_credentials(int(aid)) + out.append( + { + "id": int(aid), + "name": r.get("name") or "", + "status": r.get("status") or "active", + "use_testnet": bool(use_testnet), + } + ) + return out + except Exception as e: + raise HTTPException(status_code=500, detail=f"获取账号列表失败: {e}") + + +@router.post("") +@router.post("/") +async def create_account(payload: AccountCreate, _admin: Dict[str, Any] = Depends(get_admin_user)): + try: + aid = Account.create( + name=payload.name, + api_key=payload.api_key or "", + api_secret=payload.api_secret or "", + use_testnet=bool(payload.use_testnet), + status=payload.status, + ) + return {"success": True, "id": int(aid), "message": "账号已创建"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"创建账号失败: {e}") + + +@router.put("/{account_id}") +async def update_account(account_id: int, payload: AccountUpdate, _admin: Dict[str, Any] = Depends(get_admin_user)): + try: + row = Account.get(int(account_id)) + if not row: + raise HTTPException(status_code=404, detail="账号不存在") + + # name/status + fields = [] + params = [] + if payload.name is not None: + fields.append("name = %s") + params.append(payload.name) + if payload.status is not None: + fields.append("status = %s") + params.append(payload.status) + if payload.use_testnet is not None: + fields.append("use_testnet = %s") + params.append(bool(payload.use_testnet)) + if fields: + params.append(int(account_id)) + from database.connection import db + + db.execute_update(f"UPDATE accounts SET {', '.join(fields)} WHERE id = %s", tuple(params)) + + return {"success": True, "message": "账号已更新"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"更新账号失败: {e}") + + +@router.put("/{account_id}/credentials") +async def update_credentials(account_id: int, payload: AccountCredentialsUpdate, _admin: Dict[str, Any] = Depends(get_admin_user)): + try: + row = Account.get(int(account_id)) + if not row: + raise HTTPException(status_code=404, detail="账号不存在") + + Account.update_credentials( + int(account_id), + api_key=payload.api_key, + api_secret=payload.api_secret, + use_testnet=payload.use_testnet, + ) + return {"success": True, "message": "账号密钥已更新(建议重启该账号交易进程)"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"更新账号密钥失败: {e}") + diff --git a/backend/api/routes/admin.py b/backend/api/routes/admin.py new file mode 100644 index 0000000..4ff4469 --- /dev/null +++ b/backend/api/routes/admin.py @@ -0,0 +1,125 @@ +""" +管理员接口:用户管理 / 授权管理 +""" + +from fastapi import APIRouter, HTTPException, Depends +from pydantic import BaseModel, Field +from typing import Optional, List, Dict, Any + +from api.auth_deps import get_admin_user +from api.auth_utils import hash_password +from database.models import User, UserAccountMembership, Account + + +router = APIRouter(prefix="/api/admin", tags=["admin"]) + + +class UserCreateReq(BaseModel): + username: str = Field(..., min_length=1, max_length=64) + password: str = Field(..., min_length=1, max_length=200) + role: str = Field("user", pattern="^(admin|user)$") + status: str = Field("active", pattern="^(active|disabled)$") + + +@router.get("/users") +async def list_users(_admin: Dict[str, Any] = Depends(get_admin_user)): + return User.list_all() + + +@router.post("/users") +async def create_user(payload: UserCreateReq, _admin: Dict[str, Any] = Depends(get_admin_user)): + exists = User.get_by_username(payload.username) + if exists: + raise HTTPException(status_code=400, detail="用户名已存在") + uid = User.create( + username=payload.username, + password_hash=hash_password(payload.password), + role=payload.role, + status=payload.status, + ) + return {"success": True, "id": int(uid)} + + +class UserPasswordReq(BaseModel): + password: str = Field(..., min_length=1, max_length=200) + + +@router.put("/users/{user_id}/password") +async def set_user_password(user_id: int, payload: UserPasswordReq, _admin: Dict[str, Any] = Depends(get_admin_user)): + u = User.get_by_id(int(user_id)) + if not u: + raise HTTPException(status_code=404, detail="用户不存在") + User.set_password(int(user_id), hash_password(payload.password)) + return {"success": True} + + +class UserRoleReq(BaseModel): + role: str = Field(..., pattern="^(admin|user)$") + + +@router.put("/users/{user_id}/role") +async def set_user_role(user_id: int, payload: UserRoleReq, _admin: Dict[str, Any] = Depends(get_admin_user)): + u = User.get_by_id(int(user_id)) + if not u: + raise HTTPException(status_code=404, detail="用户不存在") + User.set_role(int(user_id), payload.role) + return {"success": True} + + +class UserStatusReq(BaseModel): + status: str = Field(..., pattern="^(active|disabled)$") + + +@router.put("/users/{user_id}/status") +async def set_user_status(user_id: int, payload: UserStatusReq, _admin: Dict[str, Any] = Depends(get_admin_user)): + u = User.get_by_id(int(user_id)) + if not u: + raise HTTPException(status_code=404, detail="用户不存在") + User.set_status(int(user_id), payload.status) + return {"success": True} + + +@router.get("/users/{user_id}/accounts") +async def list_user_accounts(user_id: int, _admin: Dict[str, Any] = Depends(get_admin_user)): + u = User.get_by_id(int(user_id)) + if not u: + raise HTTPException(status_code=404, detail="用户不存在") + memberships = UserAccountMembership.list_for_user(int(user_id)) + # 追加账号名称(便于前端展示) + out = [] + for m in memberships or []: + aid = int(m.get("account_id")) + a = Account.get(aid) or {} + out.append( + { + "user_id": int(m.get("user_id")), + "account_id": aid, + "role": m.get("role") or "viewer", + "account_name": a.get("name") or "", + "account_status": a.get("status") or "", + } + ) + return out + + +class GrantReq(BaseModel): + role: str = Field("viewer", pattern="^(owner|viewer)$") + + +@router.put("/users/{user_id}/accounts/{account_id}") +async def grant_user_account(user_id: int, account_id: int, payload: GrantReq, _admin: Dict[str, Any] = Depends(get_admin_user)): + u = User.get_by_id(int(user_id)) + if not u: + raise HTTPException(status_code=404, detail="用户不存在") + a = Account.get(int(account_id)) + if not a: + raise HTTPException(status_code=404, detail="账号不存在") + UserAccountMembership.add(int(user_id), int(account_id), role=payload.role) + return {"success": True} + + +@router.delete("/users/{user_id}/accounts/{account_id}") +async def revoke_user_account(user_id: int, account_id: int, _admin: Dict[str, Any] = Depends(get_admin_user)): + UserAccountMembership.remove(int(user_id), int(account_id)) + return {"success": True} + diff --git a/backend/api/routes/auth.py b/backend/api/routes/auth.py new file mode 100644 index 0000000..a0e6b71 --- /dev/null +++ b/backend/api/routes/auth.py @@ -0,0 +1,71 @@ +""" +登录鉴权 API(JWT) +""" + +from fastapi import APIRouter, HTTPException, Depends +from pydantic import BaseModel, Field +from typing import Optional, Dict, Any +import os + +from database.models import User +from api.auth_utils import verify_password, jwt_encode +from api.auth_deps import get_current_user + + +router = APIRouter(prefix="/api/auth", tags=["auth"]) + + +class LoginReq(BaseModel): + username: str = Field(..., min_length=1, max_length=64) + password: str = Field(..., min_length=1, max_length=200) + + +class LoginResp(BaseModel): + access_token: str + token_type: str = "bearer" + user: Dict[str, Any] + + +def _auth_enabled() -> bool: + v = (os.getenv("ATS_AUTH_ENABLED") or "true").strip().lower() + return v not in {"0", "false", "no"} + + +@router.post("/login", response_model=LoginResp) +async def login(payload: LoginReq): + if not _auth_enabled(): + raise HTTPException(status_code=400, detail="当前环境未启用登录(ATS_AUTH_ENABLED=false)") + + u = User.get_by_username(payload.username) + if not u: + raise HTTPException(status_code=401, detail="用户名或密码错误") + if (u.get("status") or "active") != "active": + raise HTTPException(status_code=403, detail="用户已被禁用") + if not verify_password(payload.password, u.get("password_hash") or ""): + raise HTTPException(status_code=401, detail="用户名或密码错误") + + token = jwt_encode({"sub": str(u["id"]), "role": u.get("role") or "user"}, exp_sec=24 * 3600) + return { + "access_token": token, + "token_type": "bearer", + "user": {"id": u["id"], "username": u["username"], "role": u.get("role") or "user", "status": u.get("status") or "active"}, + } + + +class MeResp(BaseModel): + id: int + username: str + role: str + status: str + + +@router.get("/me", response_model=MeResp) +async def me(user: Dict[str, Any] = Depends(get_current_user)): + return { + "id": int(user["id"]), + "username": user.get("username") or "", + "role": user.get("role") or "user", + "status": user.get("status") or "active", + } + + diff --git a/backend/api/routes/config.py b/backend/api/routes/config.py index 416ac53..1b8834f 100644 --- a/backend/api/routes/config.py +++ b/backend/api/routes/config.py @@ -1,11 +1,12 @@ """ 配置管理API """ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Header, Depends from api.models.config import ConfigItem, ConfigUpdate import sys from pathlib import Path import logging +from typing import Dict, Any # 添加项目根目录到路径 project_root = Path(__file__).parent.parent.parent.parent @@ -13,11 +14,20 @@ sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root / 'backend')) sys.path.insert(0, str(project_root / 'trading_system')) -from database.models import TradingConfig +from database.models import TradingConfig, Account +from api.auth_deps import get_current_user, get_account_id, require_admin logger = logging.getLogger(__name__) router = APIRouter() +# API key/secret 脱敏 +def _mask(s: str) -> str: + s = "" if s is None else str(s) + if not s: + return "" + if len(s) <= 8: + return "****" + return f"{s[:4]}...{s[-4:]}" # 智能入场(方案C)配置:为了“配置页可见”,即使数据库尚未创建,也在 GET /api/config 返回默认项 SMART_ENTRY_CONFIG_DEFAULTS = { "SMART_ENTRY_ENABLED": { @@ -101,10 +111,13 @@ AUTO_TRADE_FILTER_DEFAULTS = { @router.get("") @router.get("/") -async def get_all_configs(): +async def get_all_configs( + user: Dict[str, Any] = Depends(get_current_user), + account_id: int = Depends(get_account_id), +): """获取所有配置""" try: - configs = TradingConfig.get_all() + configs = TradingConfig.get_all(account_id=account_id) result = {} for config in configs: result[config['config_key']] = { @@ -117,6 +130,31 @@ async def get_all_configs(): 'description': config['description'] } + # 合并账号级 API Key/Secret(从 accounts 表读,避免把密钥当普通配置存) + try: + api_key, api_secret, use_testnet = Account.get_credentials(account_id) + except Exception: + api_key, api_secret, use_testnet = "", "", False + # 仅用于配置页展示/更新:不返回 secret 明文;api_key 仅脱敏展示 + result["BINANCE_API_KEY"] = { + "value": _mask(api_key or ""), + "type": "string", + "category": "api", + "description": "币安API密钥(账号私有,仅脱敏展示;仅管理员可修改)", + } + result["BINANCE_API_SECRET"] = { + "value": "", + "type": "string", + "category": "api", + "description": "币安API密钥Secret(账号私有,不回传明文;仅管理员可修改)", + } + result["USE_TESTNET"] = { + "value": bool(use_testnet), + "type": "boolean", + "category": "api", + "description": "是否使用测试网(账号私有)", + } + # 合并“默认但未入库”的配置项(用于新功能上线时直接在配置页可见) for k, meta in SMART_ENTRY_CONFIG_DEFAULTS.items(): if k not in result: @@ -131,7 +169,9 @@ async def get_all_configs(): @router.get("/feasibility-check") -async def check_config_feasibility(): +async def check_config_feasibility( + account_id: int = Depends(get_account_id), +): """ 检查配置可行性,基于当前账户余额和杠杆倍数计算可行的配置建议 """ @@ -139,7 +179,7 @@ async def check_config_feasibility(): # 获取账户余额 try: from api.routes.account import get_realtime_account_data - account_data = await get_realtime_account_data() + account_data = await get_realtime_account_data(account_id=account_id) available_balance = account_data.get('available_balance', 0) total_balance = account_data.get('total_balance', 0) except Exception as e: @@ -155,12 +195,12 @@ async def check_config_feasibility(): } # 获取当前配置 - min_margin_usdt = TradingConfig.get_value('MIN_MARGIN_USDT', 5.0) - min_position_percent = TradingConfig.get_value('MIN_POSITION_PERCENT', 0.02) - max_position_percent = TradingConfig.get_value('MAX_POSITION_PERCENT', 0.08) - base_leverage = TradingConfig.get_value('LEVERAGE', 10) - max_leverage = TradingConfig.get_value('MAX_LEVERAGE', 15) - use_dynamic_leverage = TradingConfig.get_value('USE_DYNAMIC_LEVERAGE', True) + min_margin_usdt = TradingConfig.get_value('MIN_MARGIN_USDT', 5.0, account_id=account_id) + min_position_percent = TradingConfig.get_value('MIN_POSITION_PERCENT', 0.02, account_id=account_id) + max_position_percent = TradingConfig.get_value('MAX_POSITION_PERCENT', 0.08, account_id=account_id) + base_leverage = TradingConfig.get_value('LEVERAGE', 10, account_id=account_id) + max_leverage = TradingConfig.get_value('MAX_LEVERAGE', 15, account_id=account_id) + use_dynamic_leverage = TradingConfig.get_value('USE_DYNAMIC_LEVERAGE', True, account_id=account_id) # 检查所有可能的杠杆倍数(考虑动态杠杆) leverage_to_check = [base_leverage] @@ -417,10 +457,23 @@ async def check_config_feasibility(): @router.get("/{key}") -async def get_config(key: str): +async def get_config( + key: str, + user: Dict[str, Any] = Depends(get_current_user), + account_id: int = Depends(get_account_id), +): """获取单个配置""" try: - config = TradingConfig.get(key) + # 虚拟字段:从 accounts 表读取 + if key in {"BINANCE_API_KEY", "BINANCE_API_SECRET", "USE_TESTNET"}: + api_key, api_secret, use_testnet = Account.get_credentials(account_id) + if key == "BINANCE_API_KEY": + return {"key": key, "value": _mask(api_key or ""), "type": "string", "category": "api", "description": "币安API密钥(仅脱敏展示)"} + if key == "BINANCE_API_SECRET": + return {"key": key, "value": "", "type": "string", "category": "api", "description": "币安API密钥Secret(不回传明文)"} + return {"key": key, "value": bool(use_testnet), "type": "boolean", "category": "api", "description": "是否使用测试网(账号私有)"} + + config = TradingConfig.get(key, account_id=account_id) if not config: raise HTTPException(status_code=404, detail="Config not found") @@ -441,11 +494,35 @@ async def get_config(key: str): @router.put("/{key}") -async def update_config(key: str, item: ConfigUpdate): +async def update_config( + key: str, + item: ConfigUpdate, + user: Dict[str, Any] = Depends(get_current_user), + account_id: int = Depends(get_account_id), +): """更新配置""" try: + # API Key/Secret/Testnet:写入 accounts 表(账号私有) + if key in {"BINANCE_API_KEY", "BINANCE_API_SECRET", "USE_TESTNET"}: + require_admin(user) + try: + if key == "BINANCE_API_KEY": + Account.update_credentials(account_id, api_key=str(item.value or "")) + elif key == "BINANCE_API_SECRET": + Account.update_credentials(account_id, api_secret=str(item.value or "")) + else: + Account.update_credentials(account_id, use_testnet=bool(item.value)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"更新账号API配置失败: {e}") + return { + "message": "配置已更新", + "key": key, + "value": item.value, + "note": "账号API配置已更新(建议重启对应账号的交易进程以立即生效)", + } + # 获取现有配置以确定类型和分类 - existing = TradingConfig.get(key) + existing = TradingConfig.get(key, account_id=account_id) if existing: config_type = item.type or existing['config_type'] category = item.category or existing['category'] @@ -481,12 +558,16 @@ async def update_config(key: str, item: ConfigUpdate): ) # 更新配置(会同时更新数据库和Redis缓存) - TradingConfig.set(key, item.value, config_type, category, description) + TradingConfig.set(key, item.value, config_type, category, description, account_id=account_id) # 更新config_manager的缓存(包括Redis) try: import config_manager - if hasattr(config_manager, 'config_manager') and config_manager.config_manager: + if hasattr(config_manager, 'ConfigManager') and hasattr(config_manager.ConfigManager, "for_account"): + mgr = config_manager.ConfigManager.for_account(account_id) + mgr.set(key, item.value, config_type, category, description) + logger.info(f"配置已更新到Redis缓存(account_id={account_id}): {key} = {item.value}") + elif hasattr(config_manager, 'config_manager') and config_manager.config_manager: # 调用set方法会同时更新数据库、Redis和本地缓存 config_manager.config_manager.set(key, item.value, config_type, category, description) logger.info(f"配置已更新到Redis缓存: {key} = {item.value}") @@ -506,7 +587,11 @@ async def update_config(key: str, item: ConfigUpdate): @router.post("/batch") -async def update_configs_batch(configs: list[ConfigItem]): +async def update_configs_batch( + configs: list[ConfigItem], + user: Dict[str, Any] = Depends(get_current_user), + account_id: int = Depends(get_account_id), +): """批量更新配置""" try: updated_count = 0 @@ -514,6 +599,8 @@ async def update_configs_batch(configs: list[ConfigItem]): for item in configs: try: + if item.key in {"BINANCE_API_KEY", "BINANCE_API_SECRET", "USE_TESTNET"}: + require_admin(user) # 验证配置值 if item.type == 'number': try: @@ -528,13 +615,23 @@ async def update_configs_batch(configs: list[ConfigItem]): errors.append(f"{item.key}: Must be between 0 and 1") continue - TradingConfig.set( - item.key, - item.value, - item.type, - item.category, - item.description - ) + if item.key in {"BINANCE_API_KEY", "BINANCE_API_SECRET", "USE_TESTNET"}: + # 账号私有API配置:写入 accounts + if item.key == "BINANCE_API_KEY": + Account.update_credentials(account_id, api_key=str(item.value or "")) + elif item.key == "BINANCE_API_SECRET": + Account.update_credentials(account_id, api_secret=str(item.value or "")) + else: + Account.update_credentials(account_id, use_testnet=bool(item.value)) + else: + TradingConfig.set( + item.key, + item.value, + item.type, + item.category, + item.description, + account_id=account_id, + ) updated_count += 1 except Exception as e: errors.append(f"{item.key}: {str(e)}") diff --git a/backend/api/routes/stats.py b/backend/api/routes/stats.py index 25d61bf..266ac9f 100644 --- a/backend/api/routes/stats.py +++ b/backend/api/routes/stats.py @@ -1,7 +1,7 @@ """ 统计分析API """ -from fastapi import APIRouter, Query +from fastapi import APIRouter, Query, Header, Depends import sys from pathlib import Path from datetime import datetime, timedelta @@ -13,21 +13,25 @@ sys.path.insert(0, str(project_root / 'backend')) from database.models import AccountSnapshot, Trade, MarketScan, TradingSignal from fastapi import HTTPException +from api.auth_deps import get_account_id logger = logging.getLogger(__name__) router = APIRouter() @router.get("/performance") -async def get_performance_stats(days: int = Query(7, ge=1, le=365)): +async def get_performance_stats( + days: int = Query(7, ge=1, le=365), + account_id: int = Depends(get_account_id), +): """获取性能统计""" try: # 账户快照 - snapshots = AccountSnapshot.get_recent(days) + snapshots = AccountSnapshot.get_recent(days, account_id=account_id) # 交易统计 start_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d') - trades = Trade.get_all(start_date=start_date) + trades = Trade.get_all(start_date=start_date, account_id=account_id) return { "snapshots": snapshots, @@ -39,7 +43,7 @@ async def get_performance_stats(days: int = Query(7, ge=1, le=365)): @router.get("/dashboard") -async def get_dashboard_data(): +async def get_dashboard_data(account_id: int = Depends(get_account_id)): """获取仪表板数据""" try: account_data = None @@ -48,7 +52,7 @@ async def get_dashboard_data(): # 优先尝试获取实时账户数据 try: from api.routes.account import get_realtime_account_data - account_data = await get_realtime_account_data() + account_data = await get_realtime_account_data(account_id=account_id) logger.info("成功获取实时账户数据") except HTTPException as e: # HTTPException 需要特殊处理,提取错误信息 @@ -56,7 +60,7 @@ async def get_dashboard_data(): logger.warning(f"获取实时账户数据失败 (HTTP {e.status_code}): {account_error}") # 回退到数据库快照 try: - snapshots = AccountSnapshot.get_recent(1) + snapshots = AccountSnapshot.get_recent(1, account_id=account_id) if snapshots: account_data = { "total_balance": snapshots[0].get('total_balance', 0), @@ -75,7 +79,7 @@ async def get_dashboard_data(): logger.warning(f"获取实时账户数据失败: {account_error}", exc_info=True) # 回退到数据库快照 try: - snapshots = AccountSnapshot.get_recent(1) + snapshots = AccountSnapshot.get_recent(1, account_id=account_id) if snapshots: account_data = { "total_balance": snapshots[0].get('total_balance', 0), @@ -93,7 +97,7 @@ async def get_dashboard_data(): positions_error = None try: from api.routes.account import get_realtime_positions - positions = await get_realtime_positions() + positions = await get_realtime_positions(account_id=account_id) # 转换为前端需要的格式 open_trades = positions logger.info(f"成功获取实时持仓数据: {len(open_trades)} 个持仓") @@ -102,7 +106,7 @@ async def get_dashboard_data(): logger.warning(f"获取实时持仓失败 (HTTP {e.status_code}): {positions_error}") # 回退到数据库记录 try: - db_trades = Trade.get_all(status='open')[:10] + db_trades = Trade.get_all(status='open', account_id=account_id)[:10] # 格式化数据库记录,添加 entry_value_usdt 字段 open_trades = [] for trade in db_trades: diff --git a/backend/api/routes/system.py b/backend/api/routes/system.py index 5aa0ee6..be002e0 100644 --- a/backend/api/routes/system.py +++ b/backend/api/routes/system.py @@ -6,7 +6,7 @@ import time from pathlib import Path from typing import Any, Dict, Optional, Tuple -from fastapi import APIRouter, HTTPException, Header +from fastapi import APIRouter, HTTPException, Header, Depends from pydantic import BaseModel import logging @@ -15,6 +15,14 @@ logger = logging.getLogger(__name__) # 路由统一挂在 /api/system 下,前端直接调用 /api/system/... router = APIRouter(prefix="/api/system") +# JWT 管理员鉴权(启用后替代 X-Admin-Token;未启用登录时仍可使用 X-Admin-Token 做保护) +from api.auth_deps import get_admin_user # noqa: E402 + + +def _auth_enabled() -> bool: + v = (os.getenv("ATS_AUTH_ENABLED") or "true").strip().lower() + return v not in {"0", "false", "no"} + LOG_GROUPS = ("error", "warning", "info") # 后端服务启动时间(用于前端展示“运行多久/是否已重启”) @@ -175,13 +183,11 @@ def _beijing_time_str() -> str: @router.post("/logs/test-write") async def logs_test_write( - x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token"), + _admin: Dict[str, Any] = Depends(require_system_admin), ) -> Dict[str, Any]: """ 写入 3 条测试日志到 Redis(error/warning/info),用于验证“是否写入到同一台 Redis、同一组 key”。 """ - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) - client = _get_redis_client_for_logs() if client is None: raise HTTPException(status_code=503, detail="Redis 不可用,无法写入测试日志") @@ -311,7 +317,7 @@ async def get_logs( start: int = 0, service: Optional[str] = None, level: Optional[str] = None, - x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token"), + _admin: Dict[str, Any] = Depends(require_system_admin), ) -> Dict[str, Any]: """ 从 Redis List 读取最新日志(默认 group=error -> ats:logs:error)。 @@ -322,8 +328,6 @@ async def get_logs( - service: 过滤(backend / trading_system) - level: 过滤(ERROR / CRITICAL ...) """ - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) - if limit <= 0: limit = 200 if limit > 20000: @@ -414,8 +418,7 @@ async def get_logs( @router.get("/logs/overview") -async def logs_overview(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) +async def logs_overview(_admin: Dict[str, Any] = Depends(require_system_admin)) -> Dict[str, Any]: client = _get_redis_client_for_logs() if client is None: @@ -472,10 +475,8 @@ async def logs_overview(x_admin_token: Optional[str] = Header(default=None, alia @router.put("/logs/config") async def update_logs_config( payload: LogsConfigUpdate, - x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token"), + _admin: Dict[str, Any] = Depends(require_system_admin), ) -> Dict[str, Any]: - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) - client = _get_redis_client_for_logs() if client is None: raise HTTPException(status_code=503, detail="Redis 不可用,无法更新日志配置") @@ -525,6 +526,16 @@ def _require_admin(token: Optional[str], provided: Optional[str]) -> None: raise HTTPException(status_code=401, detail="Unauthorized") +def require_system_admin( + x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token"), + user: Dict[str, Any] = Depends(get_admin_user), +) -> Dict[str, Any]: + # 未启用登录:仍允许使用历史 token 保护 + if not _auth_enabled(): + _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) + return user + + def _build_supervisorctl_cmd(args: list[str]) -> list[str]: supervisorctl_path = os.getenv("SUPERVISORCTL_PATH", "supervisorctl") supervisor_conf = os.getenv("SUPERVISOR_CONF", "").strip() @@ -677,16 +688,22 @@ def _action_with_fallback(action: str, program: str) -> Tuple[str, Optional[str] @router.post("/clear-cache") -async def clear_cache(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: +async def clear_cache( + _admin: Dict[str, Any] = Depends(require_system_admin), + x_account_id: Optional[int] = Header(default=None, alias="X-Account-Id"), +) -> Dict[str, Any]: """ 清理配置缓存(Redis Hash: trading_config),并从数据库回灌到 Redis。 """ - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) - try: import config_manager - cm = getattr(config_manager, "config_manager", None) + account_id = int(x_account_id or 1) + cm = None + if hasattr(config_manager, "ConfigManager") and hasattr(config_manager.ConfigManager, "for_account"): + cm = config_manager.ConfigManager.for_account(account_id) + else: + cm = getattr(config_manager, "config_manager", None) if cm is None: raise HTTPException(status_code=500, detail="config_manager 未初始化") @@ -710,10 +727,16 @@ async def clear_cache(x_admin_token: Optional[str] = Header(default=None, alias= if redis_client is not None and redis_connected: try: - redis_client.delete("trading_config") - deleted_keys.append("trading_config") + key = getattr(cm, "_redis_hash_key", "trading_config") + redis_client.delete(key) + deleted_keys.append(str(key)) + # 兼容:老 key(仅 default 账号) + legacy = getattr(cm, "_legacy_hash_key", None) + if legacy and legacy != key: + redis_client.delete(legacy) + deleted_keys.append(str(legacy)) except Exception as e: - logger.warning(f"删除 Redis key trading_config 失败: {e}") + logger.warning(f"删除 Redis key 失败: {e}") # 可选:实时推荐缓存(如果存在) try: @@ -743,8 +766,7 @@ async def clear_cache(x_admin_token: Optional[str] = Header(default=None, alias= @router.get("/trading/status") -async def trading_status(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) +async def trading_status(_admin: Dict[str, Any] = Depends(require_system_admin)) -> Dict[str, Any]: program = _get_program_name() try: @@ -770,8 +792,7 @@ async def trading_status(x_admin_token: Optional[str] = Header(default=None, ali @router.post("/trading/start") -async def trading_start(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) +async def trading_start(_admin: Dict[str, Any] = Depends(require_system_admin)) -> Dict[str, Any]: program = _get_program_name() try: @@ -797,8 +818,7 @@ async def trading_start(x_admin_token: Optional[str] = Header(default=None, alia @router.post("/trading/stop") -async def trading_stop(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) +async def trading_stop(_admin: Dict[str, Any] = Depends(require_system_admin)) -> Dict[str, Any]: program = _get_program_name() try: @@ -824,8 +844,7 @@ async def trading_stop(x_admin_token: Optional[str] = Header(default=None, alias @router.post("/trading/restart") -async def trading_restart(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) +async def trading_restart(_admin: Dict[str, Any] = Depends(require_system_admin)) -> Dict[str, Any]: program = _get_program_name() try: @@ -868,7 +887,7 @@ async def trading_restart(x_admin_token: Optional[str] = Header(default=None, al @router.get("/backend/status") -async def backend_status(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: +async def backend_status(_admin: Dict[str, Any] = Depends(require_system_admin)) -> Dict[str, Any]: """ 查看后端服务状态(当前 uvicorn 进程)。 @@ -876,7 +895,6 @@ async def backend_status(x_admin_token: Optional[str] = Header(default=None, ali - pid 使用 os.getpid()(当前 FastAPI 进程) - last_restart 从 Redis 读取(若可用) """ - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) meta = _system_meta_read("backend:last_restart") or {} return { "running": True, @@ -888,7 +906,7 @@ async def backend_status(x_admin_token: Optional[str] = Header(default=None, ali @router.post("/backend/restart") -async def backend_restart(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> Dict[str, Any]: +async def backend_restart(_admin: Dict[str, Any] = Depends(require_system_admin)) -> Dict[str, Any]: """ 重启后端服务(uvicorn)。 @@ -901,8 +919,6 @@ async def backend_restart(x_admin_token: Optional[str] = Header(default=None, al 注意: - 为了让接口能先返回,这里会延迟 1s 再执行 restart.sh """ - _require_admin(os.getenv("SYSTEM_CONTROL_TOKEN", "").strip(), x_admin_token) - backend_dir = Path(__file__).parent.parent.parent # backend/ restart_script = backend_dir / "restart.sh" if not restart_script.exists(): diff --git a/backend/api/routes/trades.py b/backend/api/routes/trades.py index 780e01e..b69fcaf 100644 --- a/backend/api/routes/trades.py +++ b/backend/api/routes/trades.py @@ -1,7 +1,7 @@ """ 交易记录API """ -from fastapi import APIRouter, Query, HTTPException +from fastapi import APIRouter, Query, HTTPException, Header, Depends from typing import Optional from datetime import datetime, timedelta from collections import Counter @@ -17,6 +17,7 @@ sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root / 'backend')) from database.models import Trade +from api.auth_deps import get_account_id router = APIRouter() # 在模块级别创建logger(与其他路由文件保持一致) @@ -69,6 +70,7 @@ def get_timestamp_range(period: Optional[str] = None): @router.get("") @router.get("/") async def get_trades( + account_id: int = Depends(get_account_id), start_date: Optional[str] = Query(None, description="开始日期 (YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS)"), end_date: Optional[str] = Query(None, description="结束日期 (YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS)"), period: Optional[str] = Query(None, description="快速时间段筛选: '1d'(最近1天), '7d'(最近7天), '30d'(最近30天), 'today'(今天), 'week'(本周), 'month'(本月)"), @@ -122,7 +124,7 @@ async def get_trades( except ValueError: logger.warning(f"无效的结束日期格式: {end_date}") - trades = Trade.get_all(start_timestamp, end_timestamp, symbol, status, trade_type, exit_reason) + trades = Trade.get_all(start_timestamp, end_timestamp, symbol, status, trade_type, exit_reason, account_id=account_id) logger.info(f"查询到 {len(trades)} 条交易记录") # 格式化交易记录,添加平仓类型的中文显示 @@ -169,6 +171,7 @@ async def get_trades( @router.get("/stats") async def get_trade_stats( + account_id: int = Depends(get_account_id), start_date: Optional[str] = Query(None, description="开始日期 (YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS)"), end_date: Optional[str] = Query(None, description="结束日期 (YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS)"), period: Optional[str] = Query(None, description="快速时间段筛选: '1d', '7d', '30d', 'today', 'week', 'month'"), @@ -209,7 +212,7 @@ async def get_trade_stats( except ValueError: logger.warning(f"无效的结束日期格式: {end_date}") - trades = Trade.get_all(start_timestamp, end_timestamp, symbol, None) + trades = Trade.get_all(start_timestamp, end_timestamp, symbol, None, account_id=account_id) closed_trades = [t for t in trades if t['status'] == 'closed'] # 排除0盈亏的订单(abs(pnl) < 0.01 USDT视为0盈亏),这些订单不应该影响胜率统计 diff --git a/backend/config_manager.py b/backend/config_manager.py index 5defc0d..69d7624 100644 --- a/backend/config_manager.py +++ b/backend/config_manager.py @@ -35,9 +35,10 @@ sys.path.insert(0, str(project_root)) # 延迟导入,避免在trading_system中导入时因为缺少依赖而失败 try: - from database.models import TradingConfig + from database.models import TradingConfig, Account except ImportError as e: TradingConfig = None + Account = None import logging logger = logging.getLogger(__name__) logger.warning(f"无法导入TradingConfig: {e},配置管理器将无法使用数据库") @@ -58,12 +59,27 @@ except ImportError: class ConfigManager: """配置管理器 - 优先从Redis缓存读取,其次从数据库读取,回退到环境变量和默认值""" - def __init__(self): + _instances = {} + + def __init__(self, account_id: int = 1): + self.account_id = int(account_id or 1) self._cache = {} self._redis_client: Optional[redis.Redis] = None self._redis_connected = False + self._redis_hash_key = f"trading_config:{self.account_id}" + self._legacy_hash_key = "trading_config" if self.account_id == 1 else None self._init_redis() self._load_from_db() + + @classmethod + def for_account(cls, account_id: int): + aid = int(account_id or 1) + inst = cls._instances.get(aid) + if inst: + return inst + inst = cls(account_id=aid) + cls._instances[aid] = inst + return inst def _init_redis(self): """初始化Redis客户端(同步)""" @@ -151,8 +167,10 @@ class ConfigManager: return None try: - # 使用Hash存储所有配置,键为 trading_config:{key} - value = self._redis_client.hget('trading_config', key) + # 使用账号维度 Hash 存储所有配置 + value = self._redis_client.hget(self._redis_hash_key, key) + if (value is None or value == '') and self._legacy_hash_key: + value = self._redis_client.hget(self._legacy_hash_key, key) if value is not None and value != '': return self._coerce_redis_value(value) except Exception as e: @@ -217,21 +235,22 @@ class ConfigManager: return s def _set_to_redis(self, key: str, value: Any): - """设置配置到Redis""" + """设置配置到Redis(账号维度 + legacy兼容)""" if not self._redis_connected or not self._redis_client: return False try: - # 使用Hash存储所有配置,键为 trading_config:{key} # 将值序列化:复杂类型/基础类型使用 JSON,避免 bool 被写成 "False" 字符串后逻辑误判 if isinstance(value, (dict, list, bool, int, float)): value_str = json.dumps(value, ensure_ascii=False) else: value_str = str(value) - self._redis_client.hset('trading_config', key, value_str) - # 设置整个Hash的过期时间为7天(配置不会频繁变化,但需要定期刷新) - self._redis_client.expire('trading_config', 7 * 24 * 3600) + self._redis_client.hset(self._redis_hash_key, key, value_str) + self._redis_client.expire(self._redis_hash_key, 3600) + if self._legacy_hash_key: + self._redis_client.hset(self._legacy_hash_key, key, value_str) + self._redis_client.expire(self._legacy_hash_key, 3600) return True except Exception as e: logger.debug(f"设置配置到Redis失败 {key}: {e}") @@ -244,8 +263,11 @@ class ConfigManager: value_str = json.dumps(value, ensure_ascii=False) else: value_str = str(value) - self._redis_client.hset('trading_config', key, value_str) - self._redis_client.expire('trading_config', 7 * 24 * 3600) + self._redis_client.hset(self._redis_hash_key, key, value_str) + self._redis_client.expire(self._redis_hash_key, 3600) + if self._legacy_hash_key: + self._redis_client.hset(self._legacy_hash_key, key, value_str) + self._redis_client.expire(self._legacy_hash_key, 3600) return True except: self._redis_connected = False @@ -257,15 +279,23 @@ class ConfigManager: return try: - # 批量设置所有配置到Redis + # 批量设置所有配置到Redis(账号维度) pipe = self._redis_client.pipeline() for key, value in self._cache.items(): if isinstance(value, (dict, list, bool, int, float)): value_str = json.dumps(value, ensure_ascii=False) else: value_str = str(value) - pipe.hset('trading_config', key, value_str) - pipe.expire('trading_config', 7 * 24 * 3600) + pipe.hset(self._redis_hash_key, key, value_str) + pipe.expire(self._redis_hash_key, 3600) + if self._legacy_hash_key: + for key, value in self._cache.items(): + if isinstance(value, (dict, list, bool, int, float)): + value_str = json.dumps(value, ensure_ascii=False) + else: + value_str = str(value) + pipe.hset(self._legacy_hash_key, key, value_str) + pipe.expire(self._legacy_hash_key, 3600) pipe.execute() logger.debug(f"已将 {len(self._cache)} 个配置项同步到Redis") except Exception as e: @@ -284,7 +314,9 @@ class ConfigManager: try: # 测试连接是否真正可用 self._redis_client.ping() - redis_configs = self._redis_client.hgetall('trading_config') + redis_configs = self._redis_client.hgetall(self._redis_hash_key) + if (not redis_configs) and self._legacy_hash_key: + redis_configs = self._redis_client.hgetall(self._legacy_hash_key) if redis_configs and len(redis_configs) > 0: # 解析Redis中的配置 for key, value_str in redis_configs.items(): @@ -303,7 +335,7 @@ class ConfigManager: self._redis_connected = False # 从数据库加载配置(仅在Redis不可用或Redis中没有数据时) - configs = TradingConfig.get_all() + configs = TradingConfig.get_all(account_id=self.account_id) for config in configs: key = config['config_key'] value = TradingConfig._convert_value( @@ -321,6 +353,19 @@ class ConfigManager: def get(self, key, default=None): """获取配置值""" + # 账号私有:API Key/Secret/Testnet 从 accounts 表读取(不走 trading_config) + if key in ("BINANCE_API_KEY", "BINANCE_API_SECRET", "USE_TESTNET") and Account is not None: + try: + api_key, api_secret, use_testnet = Account.get_credentials(self.account_id) + if key == "BINANCE_API_KEY": + return api_key if api_key else default + if key == "BINANCE_API_SECRET": + return api_secret if api_secret else default + return bool(use_testnet) + except Exception: + # 回退到后续逻辑(旧数据/无表) + pass + # 1. 优先从Redis缓存读取(最新) # 注意:只在Redis连接正常时尝试读取,避免频繁连接失败 if self._redis_connected and self._redis_client: @@ -341,9 +386,24 @@ class ConfigManager: # 4. 返回默认值 return default - + def set(self, key, value, config_type='string', category='general', description=None): """设置配置(同时更新数据库、Redis缓存和本地缓存)""" + # 账号私有:API Key/Secret/Testnet 写入 accounts 表 + if key in ("BINANCE_API_KEY", "BINANCE_API_SECRET", "USE_TESTNET") and Account is not None: + try: + if key == "BINANCE_API_KEY": + Account.update_credentials(self.account_id, api_key=str(value or "")) + elif key == "BINANCE_API_SECRET": + Account.update_credentials(self.account_id, api_secret=str(value or "")) + else: + Account.update_credentials(self.account_id, use_testnet=bool(value)) + self._cache[key] = value + return + except Exception as e: + logger.error(f"更新账号API配置失败: {e}") + raise + if TradingConfig is None: logger.warning("TradingConfig未导入,无法更新数据库配置") self._cache[key] = value @@ -353,7 +413,7 @@ class ConfigManager: try: # 1. 更新数据库 - TradingConfig.set(key, value, config_type, category, description) + TradingConfig.set(key, value, config_type, category, description, account_id=self.account_id) # 2. 更新本地缓存 self._cache[key] = value @@ -387,7 +447,9 @@ class ConfigManager: return try: - redis_configs = self._redis_client.hgetall('trading_config') + redis_configs = self._redis_client.hgetall(self._redis_hash_key) + if (not redis_configs) and self._legacy_hash_key: + redis_configs = self._redis_client.hgetall(self._legacy_hash_key) if redis_configs and len(redis_configs) > 0: self._cache = {} # 清空缓存 for key, value_str in redis_configs.items(): @@ -474,9 +536,26 @@ class ConfigManager: } + def _sync_to_redis(self): + """将配置同步到Redis缓存(账号维度)""" + if not self._redis_connected or not self._redis_client: + return + try: + payload = {k: json.dumps(v) for k, v in self._cache.items()} + self._redis_client.hset(self._redis_hash_key, mapping=payload) + self._redis_client.expire(self._redis_hash_key, 3600) + if self._legacy_hash_key: + self._redis_client.hset(self._legacy_hash_key, mapping=payload) + self._redis_client.expire(self._legacy_hash_key, 3600) + except Exception as e: + logger.debug(f"同步配置到Redis失败: {e}") -# 全局配置管理器实例 -config_manager = ConfigManager() +# 全局配置管理器实例(默认账号;trading_system 进程可通过 ATS_ACCOUNT_ID 指定) +try: + _default_account_id = int(os.getenv("ATS_ACCOUNT_ID") or os.getenv("ACCOUNT_ID") or 1) +except Exception: + _default_account_id = 1 +config_manager = ConfigManager.for_account(_default_account_id) # 兼容原有config.py的接口 def get_config(key, default=None): diff --git a/backend/database/add_auth.sql b/backend/database/add_auth.sql new file mode 100644 index 0000000..7c57660 --- /dev/null +++ b/backend/database/add_auth.sql @@ -0,0 +1,31 @@ +-- 登录与权限系统迁移脚本(在已有库上执行一次) +-- 目标: +-- 1) 新增 users 表(管理员/普通用户) +-- 2) 新增 user_account_memberships 表(用户可访问哪些交易账号) +-- +-- 执行前建议备份数据库。 + +USE `auto_trade_sys`; + +CREATE TABLE IF NOT EXISTS `users` ( + `id` INT PRIMARY KEY AUTO_INCREMENT, + `username` VARCHAR(64) NOT NULL, + `password_hash` VARCHAR(255) NOT NULL, + `role` VARCHAR(20) NOT NULL DEFAULT 'user' COMMENT 'admin, user', + `status` VARCHAR(20) NOT NULL DEFAULT 'active' COMMENT 'active, disabled', + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY `uk_username` (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='登录用户'; + +CREATE TABLE IF NOT EXISTS `user_account_memberships` ( + `id` INT PRIMARY KEY AUTO_INCREMENT, + `user_id` INT NOT NULL, + `account_id` INT NOT NULL, + `role` VARCHAR(20) NOT NULL DEFAULT 'viewer' COMMENT 'owner, viewer', + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY `uk_user_account` (`user_id`, `account_id`), + INDEX `idx_user_id` (`user_id`), + INDEX `idx_account_id` (`account_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='用户-交易账号授权'; + diff --git a/backend/database/add_multi_account.sql b/backend/database/add_multi_account.sql new file mode 100644 index 0000000..d0ad15a --- /dev/null +++ b/backend/database/add_multi_account.sql @@ -0,0 +1,91 @@ +-- 多账号迁移脚本(在已有库上执行一次) +-- 目标: +-- 1) 新增 accounts 表(存加密后的 API KEY/SECRET) +-- 2) trading_config/trades/account_snapshots 增加 account_id(默认=1) +-- 3) trading_config 的唯一约束从 config_key 改为 (account_id, config_key) +-- +-- ⚠️ 注意: +-- - 不同 MySQL 版本对 "ADD COLUMN IF NOT EXISTS" 支持不一致,因此这里用 INFORMATION_SCHEMA + 动态SQL。 +-- - 执行前建议先备份数据库。 + +USE `auto_trade_sys`; + +-- 1) accounts 表 +CREATE TABLE IF NOT EXISTS `accounts` ( + `id` INT PRIMARY KEY AUTO_INCREMENT, + `name` VARCHAR(100) NOT NULL, + `status` VARCHAR(20) DEFAULT 'active' COMMENT 'active, disabled', + `api_key_enc` TEXT NULL COMMENT '加密后的 API KEY(enc:v1:...)', + `api_secret_enc` TEXT NULL COMMENT '加密后的 API SECRET(enc:v1:...)', + `use_testnet` BOOLEAN DEFAULT FALSE, + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='账号表(多账号)'; + +INSERT INTO `accounts` (`id`, `name`, `status`, `use_testnet`) +VALUES (1, 'default', 'active', false) +ON DUPLICATE KEY UPDATE `name`=VALUES(`name`); + +-- 2) trading_config.account_id +SET @has_col := ( + SELECT COUNT(1) + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'trading_config' + AND COLUMN_NAME = 'account_id' +); +SET @sql := IF(@has_col = 0, 'ALTER TABLE trading_config ADD COLUMN account_id INT NOT NULL DEFAULT 1 AFTER id', 'SELECT 1'); +PREPARE stmt FROM @sql; EXECUTE stmt; DEALLOCATE PREPARE stmt; + +-- 3) trades.account_id +SET @has_col := ( + SELECT COUNT(1) + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'trades' + AND COLUMN_NAME = 'account_id' +); +SET @sql := IF(@has_col = 0, 'ALTER TABLE trades ADD COLUMN account_id INT NOT NULL DEFAULT 1 AFTER id', 'SELECT 1'); +PREPARE stmt FROM @sql; EXECUTE stmt; DEALLOCATE PREPARE stmt; + +-- 4) account_snapshots.account_id +SET @has_col := ( + SELECT COUNT(1) + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'account_snapshots' + AND COLUMN_NAME = 'account_id' +); +SET @sql := IF(@has_col = 0, 'ALTER TABLE account_snapshots ADD COLUMN account_id INT NOT NULL DEFAULT 1 AFTER id', 'SELECT 1'); +PREPARE stmt FROM @sql; EXECUTE stmt; DEALLOCATE PREPARE stmt; + +-- 5) trading_config 唯一键:改为 (account_id, config_key) +-- 尝试删除旧 UNIQUE(config_key)(名字可能是 config_key 或其他) +SET @idx_name := ( + SELECT INDEX_NAME + FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'trading_config' + AND NON_UNIQUE = 0 + AND COLUMN_NAME = 'config_key' + LIMIT 1 +); +SET @sql := IF(@idx_name IS NOT NULL, CONCAT('ALTER TABLE trading_config DROP INDEX ', @idx_name), 'SELECT 1'); +PREPARE stmt FROM @sql; EXECUTE stmt; DEALLOCATE PREPARE stmt; + +-- 添加新唯一键(如果不存在) +SET @has_uk := ( + SELECT COUNT(1) + FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'trading_config' + AND INDEX_NAME = 'uk_account_config_key' +); +SET @sql := IF(@has_uk = 0, 'ALTER TABLE trading_config ADD UNIQUE KEY uk_account_config_key (account_id, config_key)', 'SELECT 1'); +PREPARE stmt FROM @sql; EXECUTE stmt; DEALLOCATE PREPARE stmt; + +-- 6) 索引(可选,老版本 MySQL 不支持 IF NOT EXISTS,可忽略报错后手动检查) +-- 如果你看到 “Duplicate key name” 可直接忽略。 +CREATE INDEX idx_trades_account_id ON trades(account_id); +CREATE INDEX idx_account_snapshots_account_id ON account_snapshots(account_id); + diff --git a/backend/database/init.sql b/backend/database/init.sql index c5400e2..c6217bf 100644 --- a/backend/database/init.sql +++ b/backend/database/init.sql @@ -4,22 +4,69 @@ CREATE DATABASE IF NOT EXISTS `auto_trade_sys` DEFAULT CHARACTER SET utf8mb4 COL USE `auto_trade_sys`; +-- 用户表(登录用户:管理员/普通用户) +CREATE TABLE IF NOT EXISTS `users` ( + `id` INT PRIMARY KEY AUTO_INCREMENT, + `username` VARCHAR(64) NOT NULL, + `password_hash` VARCHAR(255) NOT NULL, + `role` VARCHAR(20) NOT NULL DEFAULT 'user' COMMENT 'admin, user', + `status` VARCHAR(20) NOT NULL DEFAULT 'active' COMMENT 'active, disabled', + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY `uk_username` (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='登录用户'; + +-- 用户-交易账号授权关系 +CREATE TABLE IF NOT EXISTS `user_account_memberships` ( + `id` INT PRIMARY KEY AUTO_INCREMENT, + `user_id` INT NOT NULL, + `account_id` INT NOT NULL, + `role` VARCHAR(20) NOT NULL DEFAULT 'viewer' COMMENT 'owner, viewer', + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY `uk_user_account` (`user_id`, `account_id`), + INDEX `idx_user_id` (`user_id`), + INDEX `idx_account_id` (`account_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='用户-交易账号授权'; + +-- 账号表(多账号) +CREATE TABLE IF NOT EXISTS `accounts` ( + `id` INT PRIMARY KEY AUTO_INCREMENT, + `name` VARCHAR(100) NOT NULL, + `status` VARCHAR(20) DEFAULT 'active' COMMENT 'active, disabled', + `api_key_enc` TEXT NULL COMMENT '加密后的 API KEY(enc:v1:...)', + `api_secret_enc` TEXT NULL COMMENT '加密后的 API SECRET(enc:v1:...)', + `use_testnet` BOOLEAN DEFAULT FALSE, + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='账号表(多账号)'; + +-- 默认账号(兼容单账号) +INSERT INTO `accounts` (`id`, `name`, `status`, `use_testnet`) +VALUES (1, 'default', 'active', false) +ON DUPLICATE KEY UPDATE `name`=VALUES(`name`); + -- 配置表 CREATE TABLE IF NOT EXISTS `trading_config` ( `id` INT PRIMARY KEY AUTO_INCREMENT, - `config_key` VARCHAR(100) UNIQUE NOT NULL, + `account_id` INT NOT NULL DEFAULT 1, + `config_key` VARCHAR(100) NOT NULL, `config_value` TEXT NOT NULL, `config_type` VARCHAR(50) NOT NULL COMMENT 'string, number, boolean, json', `category` VARCHAR(50) NOT NULL COMMENT 'position, risk, scan, strategy, api', `description` TEXT, `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, `updated_by` VARCHAR(50), - INDEX `idx_category` (`category`) + INDEX `idx_category` (`category`), + INDEX `idx_account_id` (`account_id`), + UNIQUE KEY `uk_account_config_key` (`account_id`, `config_key`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='交易配置表'; +-- 注意:多账号需要 (account_id, config_key) 唯一。旧库升级请跑迁移脚本(见 add_multi_account.sql)。 + -- 交易记录表 CREATE TABLE IF NOT EXISTS `trades` ( `id` INT PRIMARY KEY AUTO_INCREMENT, + `account_id` INT NOT NULL DEFAULT 1, `symbol` VARCHAR(20) NOT NULL, `side` VARCHAR(10) NOT NULL COMMENT 'BUY, SELL', `quantity` DECIMAL(20, 8) NOT NULL, @@ -45,6 +92,7 @@ CREATE TABLE IF NOT EXISTS `trades` ( `take_profit_2` DECIMAL(20, 8) NULL COMMENT '第二目标止盈价(用于展示与分步止盈)', `status` VARCHAR(20) DEFAULT 'open' COMMENT 'open, closed, cancelled', `created_at` INT UNSIGNED NOT NULL DEFAULT (UNIX_TIMESTAMP()) COMMENT '创建时间(Unix时间戳秒数)', + INDEX `idx_account_id` (`account_id`), INDEX `idx_symbol` (`symbol`), INDEX `idx_entry_time` (`entry_time`), INDEX `idx_status` (`status`), @@ -57,12 +105,14 @@ CREATE TABLE IF NOT EXISTS `trades` ( -- 账户快照表 CREATE TABLE IF NOT EXISTS `account_snapshots` ( `id` INT PRIMARY KEY AUTO_INCREMENT, + `account_id` INT NOT NULL DEFAULT 1, `total_balance` DECIMAL(20, 8) NOT NULL, `available_balance` DECIMAL(20, 8) NOT NULL, `total_position_value` DECIMAL(20, 8) DEFAULT 0, `total_pnl` DECIMAL(20, 8) DEFAULT 0, `open_positions` INT DEFAULT 0, `snapshot_time` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + INDEX `idx_account_id` (`account_id`), INDEX `idx_snapshot_time` (`snapshot_time`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='账户快照表'; diff --git a/backend/database/models.py b/backend/database/models.py index 8150c4c..e9a9692 100644 --- a/backend/database/models.py +++ b/backend/database/models.py @@ -5,6 +5,7 @@ from database.connection import db from datetime import datetime, timezone, timedelta import json import logging +import os # 北京时间时区(UTC+8) BEIJING_TZ = timezone(timedelta(hours=8)) @@ -15,46 +16,241 @@ def get_beijing_time(): logger = logging.getLogger(__name__) +def _resolve_default_account_id() -> int: + """ + 默认账号ID: + - trading_system 多进程:每个进程可通过 ATS_ACCOUNT_ID 指定自己的 account_id + - backend:未传 account_id 时默认走 1(兼容单账号) + """ + for k in ("ATS_ACCOUNT_ID", "ACCOUNT_ID", "ATS_DEFAULT_ACCOUNT_ID"): + v = (os.getenv(k, "") or "").strip() + if v: + try: + return int(v) + except Exception: + continue + return 1 + + +DEFAULT_ACCOUNT_ID = _resolve_default_account_id() + + +def _table_has_column(table: str, col: str) -> bool: + try: + db.execute_one(f"SELECT {col} FROM {table} LIMIT 1") + return True + except Exception: + return False + + +class Account: + """ + 账号模型(多账号) + - API Key/Secret 建议加密存储在 accounts 表中,而不是 trading_config + """ + + @staticmethod + def get(account_id: int): + return db.execute_one("SELECT * FROM accounts WHERE id = %s", (int(account_id),)) + + @staticmethod + def list_all(): + return db.execute_query("SELECT id, name, status, created_at, updated_at FROM accounts ORDER BY id ASC") + + @staticmethod + def create(name: str, api_key: str = "", api_secret: str = "", use_testnet: bool = False, status: str = "active"): + from security.crypto import encrypt_str # 延迟导入,避免无依赖时直接崩 + + api_key_enc = encrypt_str(api_key or "") + api_secret_enc = encrypt_str(api_secret or "") + db.execute_update( + """INSERT INTO accounts (name, status, api_key_enc, api_secret_enc, use_testnet) + VALUES (%s, %s, %s, %s, %s)""", + (name, status, api_key_enc, api_secret_enc, bool(use_testnet)), + ) + return db.execute_one("SELECT LAST_INSERT_ID() as id")["id"] + + @staticmethod + def update_credentials(account_id: int, api_key: str = None, api_secret: str = None, use_testnet: bool = None): + from security.crypto import encrypt_str # 延迟导入 + + fields = [] + params = [] + if api_key is not None: + fields.append("api_key_enc = %s") + params.append(encrypt_str(api_key)) + if api_secret is not None: + fields.append("api_secret_enc = %s") + params.append(encrypt_str(api_secret)) + if use_testnet is not None: + fields.append("use_testnet = %s") + params.append(bool(use_testnet)) + if not fields: + return + params.append(int(account_id)) + db.execute_update(f"UPDATE accounts SET {', '.join(fields)} WHERE id = %s", tuple(params)) + + @staticmethod + def get_credentials(account_id: int): + """ + 返回 (api_key, api_secret, use_testnet);密文字段会自动解密。 + 若未配置 master key 且库里是明文,仍可工作(但不安全)。 + """ + row = Account.get(account_id) + if not row: + return "", "", False + try: + from security.crypto import decrypt_str + + api_key = decrypt_str(row.get("api_key_enc") or "") + api_secret = decrypt_str(row.get("api_secret_enc") or "") + except Exception: + # 兼容:无 cryptography 或未配 master key 时,先按明文兜底 + api_key = row.get("api_key_enc") or "" + api_secret = row.get("api_secret_enc") or "" + use_testnet = bool(row.get("use_testnet") or False) + return api_key, api_secret, use_testnet + + +class User: + """登录用户(管理员/普通用户)""" + + @staticmethod + def get_by_username(username: str): + return db.execute_one("SELECT * FROM users WHERE username = %s", (str(username),)) + + @staticmethod + def get_by_id(user_id: int): + return db.execute_one("SELECT * FROM users WHERE id = %s", (int(user_id),)) + + @staticmethod + def list_all(): + return db.execute_query("SELECT id, username, role, status, created_at, updated_at FROM users ORDER BY id ASC") + + @staticmethod + def create(username: str, password_hash: str, role: str = "user", status: str = "active"): + db.execute_update( + "INSERT INTO users (username, password_hash, role, status) VALUES (%s, %s, %s, %s)", + (username, password_hash, role, status), + ) + return db.execute_one("SELECT LAST_INSERT_ID() as id")["id"] + + @staticmethod + def set_password(user_id: int, password_hash: str): + db.execute_update("UPDATE users SET password_hash = %s WHERE id = %s", (password_hash, int(user_id))) + + @staticmethod + def set_status(user_id: int, status: str): + db.execute_update("UPDATE users SET status = %s WHERE id = %s", (status, int(user_id))) + + @staticmethod + def set_role(user_id: int, role: str): + db.execute_update("UPDATE users SET role = %s WHERE id = %s", (role, int(user_id))) + + +class UserAccountMembership: + """用户-交易账号授权关系""" + + @staticmethod + def add(user_id: int, account_id: int, role: str = "viewer"): + db.execute_update( + """INSERT INTO user_account_memberships (user_id, account_id, role) + VALUES (%s, %s, %s) + ON DUPLICATE KEY UPDATE role = VALUES(role)""", + (int(user_id), int(account_id), role), + ) + + @staticmethod + def remove(user_id: int, account_id: int): + db.execute_update( + "DELETE FROM user_account_memberships WHERE user_id = %s AND account_id = %s", + (int(user_id), int(account_id)), + ) + + @staticmethod + def list_for_user(user_id: int): + return db.execute_query( + "SELECT * FROM user_account_memberships WHERE user_id = %s ORDER BY account_id ASC", + (int(user_id),), + ) + + @staticmethod + def list_for_account(account_id: int): + return db.execute_query( + "SELECT * FROM user_account_memberships WHERE account_id = %s ORDER BY user_id ASC", + (int(account_id),), + ) + + @staticmethod + def has_access(user_id: int, account_id: int) -> bool: + row = db.execute_one( + "SELECT 1 as ok FROM user_account_memberships WHERE user_id = %s AND account_id = %s", + (int(user_id), int(account_id)), + ) + return bool(row) + class TradingConfig: """交易配置模型""" @staticmethod - def get_all(): + def get_all(account_id: int = None): """获取所有配置""" - return db.execute_query( - "SELECT * FROM trading_config ORDER BY category, config_key" - ) + aid = int(account_id or DEFAULT_ACCOUNT_ID) + if _table_has_column("trading_config", "account_id"): + return db.execute_query( + "SELECT * FROM trading_config WHERE account_id = %s ORDER BY category, config_key", + (aid,), + ) + return db.execute_query("SELECT * FROM trading_config ORDER BY category, config_key") @staticmethod - def get(key): + def get(key, account_id: int = None): """获取单个配置""" - return db.execute_one( - "SELECT * FROM trading_config WHERE config_key = %s", - (key,) - ) + aid = int(account_id or DEFAULT_ACCOUNT_ID) + if _table_has_column("trading_config", "account_id"): + return db.execute_one( + "SELECT * FROM trading_config WHERE account_id = %s AND config_key = %s", + (aid, key), + ) + return db.execute_one("SELECT * FROM trading_config WHERE config_key = %s", (key,)) @staticmethod - def set(key, value, config_type, category, description=None): + def set(key, value, config_type, category, description=None, account_id: int = None): """设置配置""" value_str = TradingConfig._convert_to_string(value, config_type) - db.execute_update( - """INSERT INTO trading_config - (config_key, config_value, config_type, category, description) - VALUES (%s, %s, %s, %s, %s) - ON DUPLICATE KEY UPDATE - config_value = VALUES(config_value), - config_type = VALUES(config_type), - category = VALUES(category), - description = VALUES(description), - updated_at = CURRENT_TIMESTAMP""", - (key, value_str, config_type, category, description) - ) + aid = int(account_id or DEFAULT_ACCOUNT_ID) + if _table_has_column("trading_config", "account_id"): + db.execute_update( + """INSERT INTO trading_config + (account_id, config_key, config_value, config_type, category, description) + VALUES (%s, %s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + config_value = VALUES(config_value), + config_type = VALUES(config_type), + category = VALUES(category), + description = VALUES(description), + updated_at = CURRENT_TIMESTAMP""", + (aid, key, value_str, config_type, category, description), + ) + else: + db.execute_update( + """INSERT INTO trading_config + (config_key, config_value, config_type, category, description) + VALUES (%s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + config_value = VALUES(config_value), + config_type = VALUES(config_type), + category = VALUES(category), + description = VALUES(description), + updated_at = CURRENT_TIMESTAMP""", + (key, value_str, config_type, category, description), + ) @staticmethod - def get_value(key, default=None): + def get_value(key, default=None, account_id: int = None): """获取配置值(自动转换类型)""" - result = TradingConfig.get(key) + result = TradingConfig.get(key, account_id=account_id) if result: return TradingConfig._convert_value(result['config_value'], result['config_type']) return default @@ -103,6 +299,7 @@ class Trade: atr=None, notional_usdt=None, margin_usdt=None, + account_id: int = None, ): """创建交易记录(使用北京时间) @@ -148,6 +345,10 @@ class Trade: columns = ["symbol", "side", "quantity", "entry_price", "leverage", "entry_reason", "status", "entry_time"] values = [symbol, side, quantity, entry_price, leverage, entry_reason, "open", entry_time] + if _has_column("account_id"): + columns.insert(0, "account_id") + values.insert(0, int(account_id or DEFAULT_ACCOUNT_ID)) + if _has_column("entry_order_id"): columns.append("entry_order_id") values.append(entry_order_id) @@ -325,7 +526,7 @@ class Trade: ) @staticmethod - def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None, trade_type=None, exit_reason=None): + def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None, trade_type=None, exit_reason=None, account_id: int = None): """获取交易记录 Args: @@ -338,6 +539,14 @@ class Trade: """ query = "SELECT * FROM trades WHERE 1=1" params = [] + + # 多账号隔离(兼容旧schema) + try: + if _table_has_column("trades", "account_id"): + query += " AND account_id = %s" + params.append(int(account_id or DEFAULT_ACCOUNT_ID)) + except Exception: + pass if start_timestamp is not None: query += " AND created_at >= %s" @@ -366,11 +575,17 @@ class Trade: return result @staticmethod - def get_by_symbol(symbol, status='open'): + def get_by_symbol(symbol, status='open', account_id: int = None): """根据交易对获取持仓""" + aid = int(account_id or DEFAULT_ACCOUNT_ID) + if _table_has_column("trades", "account_id"): + return db.execute_query( + "SELECT * FROM trades WHERE account_id = %s AND symbol = %s AND status = %s", + (aid, symbol, status), + ) return db.execute_query( "SELECT * FROM trades WHERE symbol = %s AND status = %s", - (symbol, status) + (symbol, status), ) @@ -378,24 +593,40 @@ class AccountSnapshot: """账户快照模型""" @staticmethod - def create(total_balance, available_balance, total_position_value, total_pnl, open_positions): + def create(total_balance, available_balance, total_position_value, total_pnl, open_positions, account_id: int = None): """创建账户快照(使用北京时间)""" snapshot_time = get_beijing_time() - db.execute_update( - """INSERT INTO account_snapshots - (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) - VALUES (%s, %s, %s, %s, %s, %s)""", - (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) - ) + if _table_has_column("account_snapshots", "account_id"): + db.execute_update( + """INSERT INTO account_snapshots + (account_id, total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) + VALUES (%s, %s, %s, %s, %s, %s, %s)""", + (int(account_id or DEFAULT_ACCOUNT_ID), total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time), + ) + else: + db.execute_update( + """INSERT INTO account_snapshots + (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) + VALUES (%s, %s, %s, %s, %s, %s)""", + (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time), + ) @staticmethod - def get_recent(days=7): + def get_recent(days=7, account_id: int = None): """获取最近的快照""" + aid = int(account_id or DEFAULT_ACCOUNT_ID) + if _table_has_column("account_snapshots", "account_id"): + return db.execute_query( + """SELECT * FROM account_snapshots + WHERE account_id = %s AND snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY) + ORDER BY snapshot_time DESC""", + (aid, days), + ) return db.execute_query( """SELECT * FROM account_snapshots WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY) ORDER BY snapshot_time DESC""", - (days,) + (days,), ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 2cbf821..9fdb9fd 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -24,3 +24,9 @@ aiohttp==3.9.1 redis>=4.2.0 # 保留aioredis作为备选(如果某些代码仍在使用aioredis接口) aioredis==2.0.1 + +# 安全:加密存储敏感字段(API KEY/SECRET) +cryptography>=42.0.0 + +# 登录鉴权:JWT +python-jose[cryptography]>=3.3.0 diff --git a/backend/security/__init__.py b/backend/security/__init__.py new file mode 100644 index 0000000..7d1d34b --- /dev/null +++ b/backend/security/__init__.py @@ -0,0 +1,4 @@ +""" +安全相关工具(加密/解密等) +""" + diff --git a/backend/security/crypto.py b/backend/security/crypto.py new file mode 100644 index 0000000..8b70acd --- /dev/null +++ b/backend/security/crypto.py @@ -0,0 +1,119 @@ +""" +对称加密工具(用于存储 API Key/Secret 等敏感字段) + +说明: +- 使用 AES-GCM(需要 cryptography 依赖) +- master key 来自环境变量: + - ATS_MASTER_KEY(推荐):32字节 key 的 base64(urlsafe) 或 hex + - AUTO_TRADE_SYS_MASTER_KEY(兼容) +""" + +from __future__ import annotations + +import base64 +import binascii +import os +from typing import Optional + + +def _load_master_key_bytes() -> Optional[bytes]: + raw = ( + os.getenv("ATS_MASTER_KEY") + or os.getenv("AUTO_TRADE_SYS_MASTER_KEY") + or os.getenv("MASTER_KEY") + or "" + ).strip() + if not raw: + return None + + # 1) hex + try: + b = bytes.fromhex(raw) + if len(b) == 32: + return b + except Exception: + pass + + # 2) urlsafe base64 + try: + padded = raw + ("=" * (-len(raw) % 4)) + b = base64.urlsafe_b64decode(padded.encode("utf-8")) + if len(b) == 32: + return b + except binascii.Error: + pass + except Exception: + pass + + return None + + +def _aesgcm(): + try: + from cryptography.hazmat.primitives.ciphers.aead import AESGCM # type: ignore + + return AESGCM + except Exception as e: # pragma: no cover + raise RuntimeError( + "缺少加密依赖 cryptography,无法安全存储敏感字段。请安装 cryptography 并设置 ATS_MASTER_KEY。" + ) from e + + +def encrypt_str(plaintext: str) -> str: + """ + 加密字符串,返回带版本前缀的密文: + enc:v1:: + """ + if plaintext is None: + plaintext = "" + s = str(plaintext) + if s == "": + return "" + + key = _load_master_key_bytes() + if not key: + # 允许降级:不加密直接存(避免线上因缺KEY彻底不可用),但强烈建议尽快配置 master key + return s + + import os as _os + + AESGCM = _aesgcm() + nonce = _os.urandom(12) + aes = AESGCM(key) + ct = aes.encrypt(nonce, s.encode("utf-8"), None) + return "enc:v1:{}:{}".format( + base64.urlsafe_b64encode(nonce).decode("utf-8").rstrip("="), + base64.urlsafe_b64encode(ct).decode("utf-8").rstrip("="), + ) + + +def decrypt_str(ciphertext: str) -> str: + """ + 解密 encrypt_str 的输出;若不是 enc:v1 前缀,则视为明文原样返回(兼容旧数据)。 + """ + if ciphertext is None: + return "" + s = str(ciphertext) + if s == "": + return "" + if not s.startswith("enc:v1:"): + return s + + key = _load_master_key_bytes() + if not key: + raise RuntimeError("密文存在但未配置 ATS_MASTER_KEY,无法解密敏感字段。") + + parts = s.split(":") + if len(parts) != 4: + raise ValueError("密文格式不正确") + + b64_nonce = parts[2] + ("=" * (-len(parts[2]) % 4)) + b64_ct = parts[3] + ("=" * (-len(parts[3]) % 4)) + nonce = base64.urlsafe_b64decode(b64_nonce.encode("utf-8")) + ct = base64.urlsafe_b64decode(b64_ct.encode("utf-8")) + + AESGCM = _aesgcm() + aes = AESGCM(key) + pt = aes.decrypt(nonce, ct, None) + return pt.decode("utf-8") + diff --git a/frontend/.gitignore b/frontend/.gitignore index a547bf3..c398e44 100644 --- a/frontend/.gitignore +++ b/frontend/.gitignore @@ -22,3 +22,4 @@ dist-ssr *.njsproj *.sln *.sw? +.npm-cache/ diff --git a/frontend/src/App.css b/frontend/src/App.css index 11c9fa8..f45d24f 100644 --- a/frontend/src/App.css +++ b/frontend/src/App.css @@ -34,10 +34,48 @@ font-weight: bold; } +.nav-left { + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.nav-account { + display: flex; + align-items: center; + gap: 0.5rem; +} + +.nav-account-label { + font-size: 0.9rem; + opacity: 0.9; + font-weight: 700; +} + +.nav-account-select { + width: auto; + min-width: 180px; + padding: 0.45rem 0.7rem; + border-radius: 8px; + border: 1px solid rgba(255,255,255,0.25); + background: rgba(255,255,255,0.08); + color: #fff; + outline: none; +} + +.nav-account-select option { + color: #111; +} + @media (min-width: 768px) { .nav-title { font-size: 1.5rem; } + .nav-left { + flex-direction: row; + align-items: center; + gap: 1rem; + } } .nav-links { @@ -76,6 +114,30 @@ background-color: #34495e; } +.nav-user { + display: flex; + align-items: center; + gap: 0.75rem; +} + +.nav-user-name { + font-size: 0.9rem; + opacity: 0.9; +} + +.nav-logout { + padding: 0.45rem 0.75rem; + border-radius: 8px; + border: 1px solid rgba(255, 255, 255, 0.25); + background: rgba(255, 255, 255, 0.08); + color: white; + cursor: pointer; +} + +.nav-logout:hover { + background: rgba(255, 255, 255, 0.14); +} + .main-content { max-width: 1200px; margin: 1rem auto; diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index cf7249c..e9faecd 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -1,4 +1,4 @@ -import React, { useState } from 'react' +import React, { useEffect, useState } from 'react' import { BrowserRouter as Router, Routes, Route, Link } from 'react-router-dom' import ConfigPanel from './components/ConfigPanel' import ConfigGuide from './components/ConfigGuide' @@ -6,21 +6,78 @@ import TradeList from './components/TradeList' import StatsDashboard from './components/StatsDashboard' import Recommendations from './components/Recommendations' import LogMonitor from './components/LogMonitor' +import AccountSelector from './components/AccountSelector' +import Login from './components/Login' +import { api, clearAuthToken } from './services/api' import './App.css' function App() { + const [me, setMe] = useState(null) + const [checking, setChecking] = useState(true) + + const refreshMe = async () => { + try { + const u = await api.me() + setMe(u) + } catch (e) { + setMe(null) + } finally { + setChecking(false) + } + } + + useEffect(() => { + refreshMe() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + const isAdmin = (me?.role || '') === 'admin' + + if (checking) { + return ( +
+
+ 正在初始化... +
+
+ ) + } + + if (!me) { + return + } + return (
@@ -29,10 +86,10 @@ function App() { } /> } /> - } /> + } /> } /> } /> - } /> + :
无权限
} />
diff --git a/frontend/src/components/AccountSelector.jsx b/frontend/src/components/AccountSelector.jsx new file mode 100644 index 0000000..4b464ad --- /dev/null +++ b/frontend/src/components/AccountSelector.jsx @@ -0,0 +1,66 @@ +import React, { useEffect, useState } from 'react' +import { api, getCurrentAccountId, setCurrentAccountId } from '../services/api' + +const AccountSelector = ({ onChanged }) => { + const [accounts, setAccounts] = useState([]) + const [accountId, setAccountId] = useState(getCurrentAccountId()) + + useEffect(() => { + const load = () => { + api.getAccounts() + .then((list) => setAccounts(Array.isArray(list) ? list : [])) + .catch(() => setAccounts([])) + } + load() + + // 配置页创建/更新账号后会触发该事件,用于即时刷新下拉列表 + const onUpdated = () => load() + window.addEventListener('ats:accounts:updated', onUpdated) + return () => window.removeEventListener('ats:accounts:updated', onUpdated) + }, []) + + useEffect(() => { + setCurrentAccountId(accountId) + if (typeof onChanged === 'function') onChanged(accountId) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [accountId]) + + const list = Array.isArray(accounts) ? accounts : [] + const options = (list.length ? list : [{ id: 1, name: 'default' }]).reduce((acc, cur) => { + if (!cur || !cur.id) return acc + if (acc.some((x) => x.id === cur.id)) return acc + acc.push(cur) + return acc + }, []) + + useEffect(() => { + if (!options.length) return + if (options.some((a) => a.id === accountId)) return + setAccountId(options[0].id) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [options.length]) + + return ( +
+ 账号 + +
+ ) +} + +export default AccountSelector + diff --git a/frontend/src/components/ConfigPanel.css b/frontend/src/components/ConfigPanel.css index 8eaf12a..e241781 100644 --- a/frontend/src/components/ConfigPanel.css +++ b/frontend/src/components/ConfigPanel.css @@ -22,6 +22,39 @@ margin-bottom: 1rem; } +.header-left { + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.account-switch { + display: flex; + flex-wrap: wrap; + align-items: center; + gap: 0.5rem; +} + +.account-label { + font-size: 0.9rem; + color: #666; + font-weight: 700; +} + +.account-switch select { + width: auto; + min-width: 180px; + padding: 0.5rem 0.75rem; + border: 1px solid #ddd; + border-radius: 8px; + background: #fff; +} + +.account-hint { + font-size: 0.85rem; + color: #666; +} + @media (min-width: 768px) { .header-top { flex-direction: row; @@ -29,6 +62,11 @@ align-items: center; gap: 0; } + .header-left { + flex-direction: row; + align-items: center; + gap: 1rem; + } } .guide-link { @@ -505,6 +543,165 @@ line-height: 1.4; } +/* 账号管理(超管) */ +.accounts-admin-section { + margin-top: 1rem; + padding: 1rem; + border-radius: 8px; + border: 1px solid #e9ecef; + background: #fff; +} + +.accounts-admin-header { + display: flex; + flex-direction: column; + gap: 0.5rem; + justify-content: space-between; + align-items: flex-start; + margin-bottom: 0.75rem; +} + +.accounts-admin-header h3 { + margin: 0; + color: #34495e; + font-size: 1rem; +} + +.accounts-admin-actions { + display: flex; + gap: 0.5rem; + flex-wrap: wrap; +} + +.accounts-admin-body { + display: flex; + flex-direction: column; + gap: 0.75rem; +} + +.accounts-admin-card { + background: #f8f9fa; + border: 1px solid #e9ecef; + border-radius: 10px; + padding: 0.9rem; +} + +.accounts-admin-card-title { + font-weight: 800; + color: #2c3e50; + margin-bottom: 0.6rem; +} + +.accounts-form { + display: grid; + grid-template-columns: 1fr; + gap: 0.75rem; +} + +.accounts-form label { + display: flex; + flex-direction: column; + gap: 0.35rem; + font-size: 0.9rem; + color: #555; + font-weight: 600; +} + +.accounts-form input, +.accounts-form select { + padding: 0.6rem 0.75rem; + border: 1px solid #ddd; + border-radius: 8px; + font-size: 0.95rem; + background: #fff; +} + +.accounts-inline { + flex-direction: row !important; + align-items: center; + justify-content: space-between; + padding: 0.2rem 0; +} + +.accounts-form-actions { + display: flex; + gap: 0.5rem; + flex-wrap: wrap; +} + +.accounts-table { + overflow: auto; +} + +.accounts-table table { + width: 100%; + border-collapse: collapse; + background: #fff; + border-radius: 8px; + overflow: hidden; +} + +.accounts-table th, +.accounts-table td { + padding: 0.6rem 0.65rem; + border-bottom: 1px solid #eee; + font-size: 0.9rem; + text-align: left; + white-space: nowrap; +} + +.accounts-table th { + background: #f5f5f5; + font-weight: 800; + color: #555; +} + +.accounts-actions-cell { + display: flex; + gap: 0.5rem; + flex-wrap: wrap; +} + +.acct-badge { + display: inline-flex; + padding: 0.15rem 0.55rem; + border-radius: 999px; + font-weight: 800; + font-size: 0.8rem; + border: 1px solid transparent; +} + +.acct-badge.ok { + background: #e8f5e9; + border-color: #c8e6c9; + color: #2e7d32; +} + +.acct-badge.off { + background: #fff3e0; + border-color: #ffe0b2; + color: #e65100; +} + +.accounts-empty { + padding: 0.75rem; + color: #666; + font-size: 0.9rem; +} + +@media (min-width: 768px) { + .accounts-admin-header { + flex-direction: row; + align-items: center; + } + .accounts-form { + grid-template-columns: repeat(2, minmax(0, 1fr)); + } + .accounts-form-actions { + grid-column: 1 / -1; + } +} + .message { padding: 1rem; margin-bottom: 1rem; diff --git a/frontend/src/components/ConfigPanel.jsx b/frontend/src/components/ConfigPanel.jsx index 1f058ec..8fd1822 100644 --- a/frontend/src/components/ConfigPanel.jsx +++ b/frontend/src/components/ConfigPanel.jsx @@ -1,9 +1,9 @@ import React, { useState, useEffect } from 'react' import { Link } from 'react-router-dom' -import { api } from '../services/api' +import { api, getCurrentAccountId, setCurrentAccountId } from '../services/api' import './ConfigPanel.css' -const ConfigPanel = () => { +const ConfigPanel = ({ currentUser }) => { const [configs, setConfigs] = useState({}) const [loading, setLoading] = useState(true) const [saving, setSaving] = useState(false) @@ -14,6 +14,25 @@ const ConfigPanel = () => { const [backendStatus, setBackendStatus] = useState(null) const [systemBusy, setSystemBusy] = useState(false) + // 多账号:当前账号(仅用于配置页提示;全局切换器在顶部导航) + const [accountId, setAccountId] = useState(getCurrentAccountId()) + + const isAdmin = (currentUser?.role || '') === 'admin' + + // 账号管理(超管) + const [accountsAdmin, setAccountsAdmin] = useState([]) + const [accountsBusy, setAccountsBusy] = useState(false) + const [showAccountsAdmin, setShowAccountsAdmin] = useState(false) + const [newAccount, setNewAccount] = useState({ + name: '', + api_key: '', + api_secret: '', + use_testnet: false, + status: 'active', + }) + const [credEditId, setCredEditId] = useState(null) + const [credForm, setCredForm] = useState({ api_key: '', api_secret: '', use_testnet: false }) + // “PCT”类配置里有少数是“百分比数值(<=1表示<=1%)”,而不是“0~1比例” // 例如 LIMIT_ORDER_OFFSET_PCT=0.5 表示 0.5%(而不是 50%) const PCT_LIKE_KEYS = new Set([ @@ -276,6 +295,43 @@ const ConfigPanel = () => { return () => clearInterval(timer) }, []) + + const loadAccountsAdmin = async () => { + try { + const list = await api.getAccounts() + setAccountsAdmin(Array.isArray(list) ? list : []) + } catch (e) { + setAccountsAdmin([]) + } + } + + const notifyAccountsUpdated = () => { + try { + window.dispatchEvent(new Event('ats:accounts:updated')) + } catch (e) { + // ignore + } + } + + // 切换账号时,刷新页面数据 + useEffect(() => { + setCurrentAccountId(accountId) + setMessage('') + setLoading(true) + loadConfigs() + checkFeasibility() + loadSystemStatus() + loadBackendStatus() + }, [accountId]) + + // 顶部导航切换账号时(localStorage更新),这里做一个轻量同步 + useEffect(() => { + const timer = setInterval(() => { + const cur = getCurrentAccountId() + if (cur !== accountId) setAccountId(cur) + }, 1000) + return () => clearInterval(timer) + }, [accountId]) const checkFeasibility = async () => { setCheckingFeasibility(true) @@ -613,7 +669,12 @@ const ConfigPanel = () => {
-

交易配置

+
+

交易配置

+
+ 当前账号:#{accountId}(在顶部导航切换) +
+
{/* 系统控制:清缓存 / 启停 / 重启(supervisor) */} + {isAdmin ? (

系统控制

@@ -702,6 +764,274 @@ const ConfigPanel = () => { 建议流程:先更新配置里的 Key → 点击“清除缓存” → 点击“重启交易系统”,确保不再使用旧账号下单。
+ ) : null} + + {/* 账号管理(超管) */} + {isAdmin ? ( +
+
+

账号管理(多账号)

+
+ + +
+
+ + {showAccountsAdmin ? ( +
+
+
新增账号
+
+ + + + + +
+ +
+
+
+ +
+
账号列表
+
+ {(accountsAdmin || []).length ? ( + + + + + + + + + + + + + + {accountsAdmin.map((a) => ( + + + + + + + + + + ))} + +
ID名称状态测试网API KEYSECRET操作
#{a.id}{a.name || '-'} + + {a.status === 'active' ? '启用' : '禁用'} + + {a.use_testnet ? '是' : '否'}{a.api_key_masked || (a.has_api_key ? '已配置' : '未配置')}{a.has_api_secret ? '已配置' : '未配置'} + + +
+ ) : ( +
暂无账号(默认账号 #1 会自动存在)
+ )} +
+
+ + {credEditId ? ( +
+
更新账号 #{credEditId} 的密钥
+
+ + + +
+ + +
+
+
+ ) : null} +
+ ) : null} +
+ ) : null} {/* 预设方案快速切换 */}
diff --git a/frontend/src/components/Login.css b/frontend/src/components/Login.css new file mode 100644 index 0000000..710a2f6 --- /dev/null +++ b/frontend/src/components/Login.css @@ -0,0 +1,86 @@ +.login-page { + min-height: 100vh; + display: flex; + align-items: center; + justify-content: center; + background: linear-gradient(135deg, #0b1220, #101a33); + padding: 24px; +} + +.login-card { + width: 420px; + max-width: 100%; + background: rgba(255, 255, 255, 0.08); + border: 1px solid rgba(255, 255, 255, 0.12); + border-radius: 14px; + padding: 22px; + color: #e8eefc; + backdrop-filter: blur(10px); +} + +.login-title { + font-size: 20px; + font-weight: 700; + letter-spacing: 0.4px; +} + +.login-subtitle { + margin-top: 6px; + color: rgba(232, 238, 252, 0.78); + font-size: 13px; +} + +.login-field { + display: block; + margin-top: 14px; +} + +.login-label { + font-size: 12px; + color: rgba(232, 238, 252, 0.8); + margin-bottom: 6px; +} + +.login-input { + width: 100%; + box-sizing: border-box; + border-radius: 10px; + padding: 10px 12px; + border: 1px solid rgba(255, 255, 255, 0.14); + background: rgba(0, 0, 0, 0.24); + color: #e8eefc; + outline: none; +} + +.login-input:focus { + border-color: rgba(99, 179, 237, 0.7); + box-shadow: 0 0 0 3px rgba(99, 179, 237, 0.18); +} + +.login-error { + margin-top: 12px; + color: #ffd1d1; + background: rgba(255, 0, 0, 0.12); + border: 1px solid rgba(255, 0, 0, 0.22); + border-radius: 10px; + padding: 10px 12px; + font-size: 13px; +} + +.login-btn { + width: 100%; + margin-top: 14px; + border-radius: 10px; + padding: 10px 12px; + border: 1px solid rgba(255, 255, 255, 0.14); + background: rgba(59, 130, 246, 0.9); + color: white; + font-weight: 600; + cursor: pointer; +} + +.login-btn:disabled { + opacity: 0.6; + cursor: not-allowed; +} + diff --git a/frontend/src/components/Login.jsx b/frontend/src/components/Login.jsx new file mode 100644 index 0000000..8f1ba03 --- /dev/null +++ b/frontend/src/components/Login.jsx @@ -0,0 +1,72 @@ +import React, { useState } from 'react' +import { api } from '../services/api' +import './Login.css' + +const Login = ({ onLoggedIn }) => { + const [username, setUsername] = useState('') + const [password, setPassword] = useState('') + const [busy, setBusy] = useState(false) + const [error, setError] = useState('') + + const doLogin = async () => { + setBusy(true) + setError('') + try { + await api.login(username, password) + if (typeof onLoggedIn === 'function') await onLoggedIn() + } catch (e) { + setError(e?.message || '登录失败') + } finally { + setBusy(false) + } + } + + return ( +
+
+
自动交易系统
+
请先登录
+ + + + + + {error ?
{error}
: null} + + +
+
+ ) +} + +export default Login + diff --git a/frontend/src/services/api.js b/frontend/src/services/api.js index 049cbe1..1a0fddc 100644 --- a/frontend/src/services/api.js +++ b/frontend/src/services/api.js @@ -1,6 +1,63 @@ // 如果设置了VITE_API_URL环境变量,使用它;否则在开发环境使用相对路径(通过vite代理),生产环境使用默认值 const API_BASE_URL = import.meta.env.VITE_API_URL || (import.meta.env.DEV ? '' : 'http://localhost:8000'); +// 登录鉴权:JWT token(Authorization: Bearer xxx) +const AUTH_TOKEN_STORAGE_KEY = 'ats_auth_token'; +export const getAuthToken = () => { + try { + return localStorage.getItem(AUTH_TOKEN_STORAGE_KEY) || ''; + } catch (e) { + return ''; + } +}; + +export const setAuthToken = (token) => { + try { + const t = String(token || '').trim(); + if (!t) { + localStorage.removeItem(AUTH_TOKEN_STORAGE_KEY); + return; + } + localStorage.setItem(AUTH_TOKEN_STORAGE_KEY, t); + } catch (e) { + // ignore + } +}; + +export const clearAuthToken = () => setAuthToken(''); + +// 多账号:前端通过 Header 选择账号(默认 1) +const ACCOUNT_ID_STORAGE_KEY = 'ats_account_id'; +export const getCurrentAccountId = () => { + try { + const v = localStorage.getItem(ACCOUNT_ID_STORAGE_KEY); + const n = parseInt(v || '1', 10); + return Number.isFinite(n) && n > 0 ? n : 1; + } catch (e) { + return 1; + } +}; + +export const setCurrentAccountId = (accountId) => { + try { + const n = parseInt(String(accountId || '1'), 10); + localStorage.setItem(ACCOUNT_ID_STORAGE_KEY, String(Number.isFinite(n) && n > 0 ? n : 1)); + } catch (e) { + // ignore + } +}; + +const withAuthHeaders = (headers = {}) => { + const token = getAuthToken(); + if (!token) return { ...headers }; + return { ...headers, Authorization: `Bearer ${token}` }; +}; + +const withAccountHeaders = (headers = {}) => { + const aid = getCurrentAccountId(); + return withAuthHeaders({ ...headers, 'X-Account-Id': String(aid) }); +}; + // 构建API URL的辅助函数,避免双斜杠和格式问题 const buildUrl = (path) => { const baseUrl = API_BASE_URL.endsWith('/') ? API_BASE_URL.slice(0, -1) : API_BASE_URL; @@ -9,9 +66,79 @@ const buildUrl = (path) => { }; export const api = { + // 登录 + login: async (username, password) => { + const response = await fetch(buildUrl('/api/auth/login'), { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ username, password }), + }); + if (!response.ok) { + const error = await response.json().catch(() => ({ detail: '登录失败' })); + throw new Error(error.detail || '登录失败'); + } + const data = await response.json(); + if (data?.access_token) setAuthToken(data.access_token); + return data; + }, + me: async () => { + const response = await fetch(buildUrl('/api/auth/me'), { headers: withAuthHeaders() }); + if (!response.ok) { + const error = await response.json().catch(() => ({ detail: '获取登录信息失败' })); + throw new Error(error.detail || '获取登录信息失败'); + } + return response.json(); + }, + + // 账号管理 + getAccounts: async () => { + const response = await fetch(buildUrl('/api/accounts'), { headers: withAccountHeaders() }); + if (!response.ok) { + const error = await response.json().catch(() => ({ detail: '获取账号列表失败' })); + throw new Error(error.detail || '获取账号列表失败'); + } + return response.json(); + }, + createAccount: async (data) => { + const response = await fetch(buildUrl('/api/accounts'), { + method: 'POST', + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), + body: JSON.stringify(data || {}), + }); + if (!response.ok) { + const error = await response.json().catch(() => ({ detail: '创建账号失败' })); + throw new Error(error.detail || '创建账号失败'); + } + return response.json(); + }, + updateAccount: async (accountId, data) => { + const response = await fetch(buildUrl(`/api/accounts/${accountId}`), { + method: 'PUT', + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), + body: JSON.stringify(data || {}), + }); + if (!response.ok) { + const error = await response.json().catch(() => ({ detail: '更新账号失败' })); + throw new Error(error.detail || '更新账号失败'); + } + return response.json(); + }, + updateAccountCredentials: async (accountId, data) => { + const response = await fetch(buildUrl(`/api/accounts/${accountId}/credentials`), { + method: 'PUT', + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), + body: JSON.stringify(data || {}), + }); + if (!response.ok) { + const error = await response.json().catch(() => ({ detail: '更新账号密钥失败' })); + throw new Error(error.detail || '更新账号密钥失败'); + } + return response.json(); + }, + // 配置管理 getConfigs: async () => { - const response = await fetch(buildUrl('/api/config')); + const response = await fetch(buildUrl('/api/config'), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取配置失败' })); throw new Error(error.detail || '获取配置失败'); @@ -20,7 +147,7 @@ export const api = { }, getConfig: async (key) => { - const response = await fetch(buildUrl(`/api/config/${key}`)); + const response = await fetch(buildUrl(`/api/config/${key}`), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取配置失败' })); throw new Error(error.detail || '获取配置失败'); @@ -31,7 +158,7 @@ export const api = { updateConfig: async (key, data) => { const response = await fetch(buildUrl(`/api/config/${key}`), { method: 'PUT', - headers: {'Content-Type': 'application/json'}, + headers: withAccountHeaders({'Content-Type': 'application/json'}), body: JSON.stringify(data) }); if (!response.ok) { @@ -44,7 +171,7 @@ export const api = { updateConfigsBatch: async (configs) => { const response = await fetch(buildUrl('/api/config/batch'), { method: 'POST', - headers: {'Content-Type': 'application/json'}, + headers: withAccountHeaders({'Content-Type': 'application/json'}), body: JSON.stringify(configs) }); if (!response.ok) { @@ -56,7 +183,7 @@ export const api = { // 检查配置可行性 checkConfigFeasibility: async () => { - const response = await fetch(buildUrl('/api/config/feasibility-check')); + const response = await fetch(buildUrl('/api/config/feasibility-check'), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '检查配置可行性失败' })); throw new Error(error.detail || '检查配置可行性失败'); @@ -71,7 +198,7 @@ export const api = { const response = await fetch(url, { method: 'GET', headers: { - 'Content-Type': 'application/json', + ...withAccountHeaders({ 'Content-Type': 'application/json' }), }, }); if (!response.ok) { @@ -87,7 +214,7 @@ export const api = { const response = await fetch(url, { method: 'GET', headers: { - 'Content-Type': 'application/json', + ...withAccountHeaders({ 'Content-Type': 'application/json' }), }, }); if (!response.ok) { @@ -99,7 +226,7 @@ export const api = { // 统计 getPerformance: async (days = 7) => { - const response = await fetch(buildUrl(`/api/stats/performance?days=${days}`)); + const response = await fetch(buildUrl(`/api/stats/performance?days=${days}`), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取性能统计失败' })); throw new Error(error.detail || '获取性能统计失败'); @@ -108,7 +235,7 @@ export const api = { }, getDashboard: async () => { - const response = await fetch(buildUrl('/api/dashboard')); + const response = await fetch(buildUrl('/api/dashboard'), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取仪表板数据失败' })); throw new Error(error.detail || '获取仪表板数据失败'); @@ -121,7 +248,7 @@ export const api = { const response = await fetch(buildUrl(`/api/account/positions/${symbol}/close`), { method: 'POST', headers: { - 'Content-Type': 'application/json', + ...withAccountHeaders({ 'Content-Type': 'application/json' }), }, }); if (!response.ok) { @@ -135,7 +262,7 @@ export const api = { ensurePositionSLTP: async (symbol) => { const response = await fetch(buildUrl(`/api/account/positions/${symbol}/sltp/ensure`), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '补挂止盈止损失败' })); @@ -147,7 +274,7 @@ export const api = { ensureAllPositionsSLTP: async (limit = 50) => { const response = await fetch(buildUrl(`/api/account/positions/sltp/ensure-all?limit=${encodeURIComponent(limit)}`), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '批量补挂止盈止损失败' })); @@ -161,7 +288,7 @@ export const api = { const response = await fetch(buildUrl('/api/account/positions/sync'), { method: 'POST', headers: { - 'Content-Type': 'application/json', + ...withAccountHeaders({ 'Content-Type': 'application/json' }), }, }); if (!response.ok) { @@ -179,7 +306,7 @@ export const api = { } const query = new URLSearchParams(params).toString(); const url = query ? `${buildUrl('/api/recommendations')}?${query}` : buildUrl('/api/recommendations'); - const response = await fetch(url); + const response = await fetch(url, { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取推荐失败' })); throw new Error(error.detail || '获取推荐失败'); @@ -190,7 +317,7 @@ export const api = { bookmarkRecommendation: async (recommendationData) => { const response = await fetch(buildUrl('/api/recommendations/bookmark'), { method: 'POST', - headers: {'Content-Type': 'application/json'}, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), body: JSON.stringify(recommendationData) }); if (!response.ok) { @@ -201,7 +328,7 @@ export const api = { }, getActiveRecommendations: async () => { - const response = await fetch(buildUrl('/api/recommendations/active')); + const response = await fetch(buildUrl('/api/recommendations/active'), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取有效推荐失败' })); throw new Error(error.detail || '获取有效推荐失败'); @@ -210,7 +337,7 @@ export const api = { }, getRecommendation: async (id) => { - const response = await fetch(buildUrl(`/api/recommendations/${id}`)); + const response = await fetch(buildUrl(`/api/recommendations/${id}`), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取推荐详情失败' })); throw new Error(error.detail || '获取推荐详情失败'); @@ -223,9 +350,7 @@ export const api = { buildUrl(`/api/recommendations/generate?min_signal_strength=${minSignalStrength}&max_recommendations=${maxRecommendations}`), { method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), } ); if (!response.ok) { @@ -238,9 +363,7 @@ export const api = { markRecommendationExecuted: async (id, tradeId = null) => { const response = await fetch(buildUrl(`/api/recommendations/${id}/execute`), { method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), body: JSON.stringify({ trade_id: tradeId }), }); if (!response.ok) { @@ -253,9 +376,7 @@ export const api = { cancelRecommendation: async (id, notes = null) => { const response = await fetch(buildUrl(`/api/recommendations/${id}/cancel`), { method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), body: JSON.stringify({ notes }), }); if (!response.ok) { @@ -269,7 +390,7 @@ export const api = { clearSystemCache: async () => { const response = await fetch(buildUrl('/api/system/clear-cache'), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '清理缓存失败' })); @@ -279,7 +400,7 @@ export const api = { }, getTradingSystemStatus: async () => { - const response = await fetch(buildUrl('/api/system/trading/status')); + const response = await fetch(buildUrl('/api/system/trading/status'), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取交易系统状态失败' })); throw new Error(error.detail || '获取交易系统状态失败'); @@ -290,7 +411,7 @@ export const api = { startTradingSystem: async () => { const response = await fetch(buildUrl('/api/system/trading/start'), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '启动交易系统失败' })); @@ -302,7 +423,7 @@ export const api = { stopTradingSystem: async () => { const response = await fetch(buildUrl('/api/system/trading/stop'), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '停止交易系统失败' })); @@ -314,7 +435,7 @@ export const api = { restartTradingSystem: async () => { const response = await fetch(buildUrl('/api/system/trading/restart'), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '重启交易系统失败' })); @@ -325,7 +446,7 @@ export const api = { // 后端控制(uvicorn) getBackendStatus: async () => { - const response = await fetch(buildUrl('/api/system/backend/status')); + const response = await fetch(buildUrl('/api/system/backend/status'), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取后端状态失败' })); throw new Error(error.detail || '获取后端状态失败'); @@ -336,7 +457,7 @@ export const api = { restartBackend: async () => { const response = await fetch(buildUrl('/api/system/backend/restart'), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '重启后端失败' })); @@ -349,7 +470,7 @@ export const api = { getSystemLogs: async (params = {}) => { const query = new URLSearchParams(params).toString(); const url = query ? `${buildUrl('/api/system/logs')}?${query}` : buildUrl('/api/system/logs'); - const response = await fetch(url); + const response = await fetch(url, { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取日志失败' })); throw new Error(error.detail || '获取日志失败'); @@ -358,7 +479,7 @@ export const api = { }, getLogsOverview: async () => { - const response = await fetch(buildUrl('/api/system/logs/overview')); + const response = await fetch(buildUrl('/api/system/logs/overview'), { headers: withAccountHeaders() }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '获取日志概览失败' })); throw new Error(error.detail || '获取日志概览失败'); @@ -369,7 +490,7 @@ export const api = { updateLogsConfig: async (data) => { const response = await fetch(buildUrl('/api/system/logs/config'), { method: 'PUT', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), body: JSON.stringify(data || {}), }); if (!response.ok) { @@ -382,7 +503,7 @@ export const api = { writeLogsTest: async () => { const response = await fetch(buildUrl('/api/system/logs/test-write'), { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: withAccountHeaders({ 'Content-Type': 'application/json' }), }); if (!response.ok) { const error = await response.json().catch(() => ({ detail: '写入测试日志失败' })); diff --git a/trading_system/requirements.txt b/trading_system/requirements.txt index b666dc3..d3b6d47 100644 --- a/trading_system/requirements.txt +++ b/trading_system/requirements.txt @@ -14,3 +14,6 @@ python-dotenv==1.0.0 redis>=4.2.0 # 保留aioredis作为备选(向后兼容,如果某些代码仍在使用) aioredis==2.0.1 + +# 安全:加密存储敏感字段(API KEY/SECRET,后端 models 复用) +cryptography>=42.0.0