163 lines
5.5 KiB
Python
163 lines
5.5 KiB
Python
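"""
Account management API (multi-account).

Notes:
- This is the management entry point for "multi-account, step one":
  create / disable accounts and update their API credentials.
- Trading / config / statistics endpoints select the account through the
  X-Account-Id request header (defaults to account 1).
"""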
|
||
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from pydantic import BaseModel, Field
|
||
from typing import Optional, List, Dict, Any
|
||
import logging
|
||
|
||
from database.models import Account, UserAccountMembership
|
||
from api.auth_deps import get_current_user, get_admin_user
|
||
|
||
# Router holding the account-management endpoints defined below; presumably
# included by the application elsewhere (mount prefix not visible here).
router = APIRouter()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AccountCreate(BaseModel):
    """Request body for creating a new account.

    The API key/secret may be omitted (empty string or null) and set later
    through the credentials endpoint.
    """

    # Display name of the account, 1-100 characters.
    name: str = Field(..., min_length=1, max_length=100)
    # API key; "" means "not configured yet".
    api_key: Optional[str] = Field(default="")
    # API secret; "" means "not configured yet".
    api_secret: Optional[str] = Field(default="")
    # Whether the account targets the testnet.
    use_testnet: bool = Field(default=False)
    # Initial status; only "active" or "disabled" pass validation.
    status: str = Field(default="active", pattern=r"^(active|disabled)$")
|
||
|
||
|
||
class AccountUpdate(BaseModel):
    """Partial update of an account's basic settings.

    Every field is optional; the update endpoint skips fields left as None.
    """

    # New display name (1-100 characters), or None to keep the current one.
    name: Optional[str] = Field(default=None, min_length=1, max_length=100)
    # New status, restricted to "active" / "disabled"; None keeps it.
    status: Optional[str] = Field(default=None, pattern=r"^(active|disabled)$")
    # New testnet flag; None keeps it.
    use_testnet: Optional[bool] = Field(default=None)
|
||
|
||
|
||
class AccountCredentialsUpdate(BaseModel):
    """Request body for changing an account's API credentials.

    All fields default to None and are forwarded unchanged to
    ``Account.update_credentials``.
    """

    # NOTE(review): presumably None means "leave this value unchanged";
    # confirm against Account.update_credentials.
    api_key: Optional[str] = Field(default=None)
    api_secret: Optional[str] = Field(default=None)
    use_testnet: Optional[bool] = Field(default=None)
|
||
|
||
|
||
def _mask(s: str) -> str:
|
||
s = "" if s is None else str(s)
|
||
if not s:
|
||
return ""
|
||
if len(s) <= 8:
|
||
return "****"
|
||
return f"{s[:4]}...{s[-4:]}"
|
||
|
||
|
||
@router.get("")
@router.get("/")
async def list_accounts(user: Dict[str, Any] = Depends(get_current_user)) -> List[Dict[str, Any]]:
    """List the accounts visible to the current user.

    Admins (``user["role"] == "admin"``) see every account, including
    whether an API key/secret is configured and a masked API key.
    Other users see only the accounts they have a membership for, with no
    credential-related fields.

    Args:
        user: Authenticated user dict from ``get_current_user``; this
            endpoint reads its ``"role"`` and ``"id"`` keys.

    Returns:
        A list of account summary dicts.

    Raises:
        HTTPException: 500 if loading the accounts fails.
    """
    # NOTE(review): the Account / UserAccountMembership calls are not awaited,
    # so they run synchronously on the event loop inside this ``async def``.
    # If they do blocking DB I/O, consider a plain ``def`` endpoint or
    # run_in_threadpool.
    try:
        is_admin = (user.get("role") or "user") == "admin"
        if is_admin:
            return [_admin_account_view(r) for r in (Account.list_all() or [])]

        memberships = UserAccountMembership.list_for_user(int(user["id"]))
        # dict.fromkeys de-duplicates while keeping membership order, so an
        # account linked more than once is listed only once.
        account_ids = list(
            dict.fromkeys(
                int(m.get("account_id"))
                for m in (memberships or [])
                if m.get("account_id") is not None
            )
        )
        out: List[Dict[str, Any]] = []
        for aid in account_ids:
            row = Account.get(aid)
            if not row:
                # Membership points at an account that no longer exists.
                continue
            out.append(_member_account_view(aid, row))
        return out
    except HTTPException:
        raise
    except Exception as e:
        # Log the traceback server-side; previously it was lost entirely.
        logger.exception("Failed to list accounts")
        raise HTTPException(status_code=500, detail=f"获取账号列表失败: {e}") from e


def _admin_account_view(row: Dict[str, Any]) -> Dict[str, Any]:
    """Build the admin-facing summary of one account row, with masked credential info."""
    aid = int(row.get("id"))
    api_key, api_secret, use_testnet = Account.get_credentials(aid)
    return {
        "id": aid,
        "name": row.get("name") or "",
        "status": row.get("status") or "active",
        "use_testnet": bool(use_testnet),
        "has_api_key": bool(api_key),
        "has_api_secret": bool(api_secret),
        "api_key_masked": _mask(api_key),
    }


def _member_account_view(aid: int, row: Dict[str, Any]) -> Dict[str, Any]:
    """Build the non-admin summary of one account; no key/secret fields are exposed."""
    _, _, use_testnet = Account.get_credentials(aid)
    return {
        "id": aid,
        "name": row.get("name") or "",
        "status": row.get("status") or "active",
        "use_testnet": bool(use_testnet),
    }
|
||
|
||
|
||
@router.post("")
@router.post("/")
async def create_account(payload: AccountCreate, _admin: Dict[str, Any] = Depends(get_admin_user)):
    """Create a new account.

    Args:
        payload: Validated creation request; ``None`` credentials are
            stored as empty strings.
        _admin: Unused in the body. Declared so the ``get_admin_user``
            dependency runs, presumably rejecting non-admin callers.

    Returns:
        ``{"success": True, "id": <new account id>, "message": ...}``.

    Raises:
        HTTPException: 500 if the account could not be created.
    """
    try:
        new_id = Account.create(
            name=payload.name,
            api_key=payload.api_key or "",
            api_secret=payload.api_secret or "",
            use_testnet=bool(payload.use_testnet),
            status=payload.status,
        )
        return {"success": True, "id": int(new_id), "message": "账号已创建"}
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log instead of only echoing it to the client.
        logger.exception("Failed to create account %r", payload.name)
        raise HTTPException(status_code=500, detail=f"创建账号失败: {e}") from e
|
||
|
||
|
||
@router.put("/{account_id}")
async def update_account(account_id: int, payload: AccountUpdate, _admin: Dict[str, Any] = Depends(get_admin_user)):
    """Update an account's name, status and/or testnet flag.

    Only fields that are not ``None`` in ``payload`` are written. A payload
    with nothing set is a successful no-op.

    Args:
        account_id: Id of the account to update.
        payload: Partial update request.
        _admin: Unused in the body. Declared so the ``get_admin_user``
            dependency runs, presumably rejecting non-admin callers.

    Returns:
        ``{"success": True, "message": ...}``.

    Raises:
        HTTPException: 404 if the account does not exist; 500 on any other
            failure.
    """
    try:
        if not Account.get(int(account_id)):
            raise HTTPException(status_code=404, detail="账号不存在")

        # Column names come only from this fixed set, never from user input,
        # so putting them into the SQL text is safe. Values stay bound
        # parameters. Dict insertion order fixes the SET-clause order.
        changes: Dict[str, Any] = {}
        if payload.name is not None:
            changes["name"] = payload.name
        if payload.status is not None:
            changes["status"] = payload.status
        if payload.use_testnet is not None:
            changes["use_testnet"] = bool(payload.use_testnet)

        if changes:
            from database.connection import db

            set_clause = ", ".join(f"{column} = %s" for column in changes)
            db.execute_update(
                f"UPDATE accounts SET {set_clause} WHERE id = %s",
                (*changes.values(), int(account_id)),
            )

        return {"success": True, "message": "账号已更新"}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to update account %s", account_id)
        raise HTTPException(status_code=500, detail=f"更新账号失败: {e}") from e
|
||
|
||
|
||
@router.put("/{account_id}/credentials")
async def update_credentials(account_id: int, payload: AccountCredentialsUpdate, _admin: Dict[str, Any] = Depends(get_admin_user)):
    """Replace an account's API key/secret and/or testnet flag.

    The payload fields are forwarded as-is (including ``None``) to
    ``Account.update_credentials``. The response advises restarting the
    account's trading process, since running processes may still hold the
    old keys.

    Args:
        account_id: Id of the account to update.
        payload: Credentials update request.
        _admin: Unused in the body. Declared so the ``get_admin_user``
            dependency runs, presumably rejecting non-admin callers.

    Returns:
        ``{"success": True, "message": ...}``.

    Raises:
        HTTPException: 404 if the account does not exist; 500 on any other
            failure.
    """
    try:
        if not Account.get(int(account_id)):
            raise HTTPException(status_code=404, detail="账号不存在")

        Account.update_credentials(
            int(account_id),
            api_key=payload.api_key,
            api_secret=payload.api_secret,
            use_testnet=payload.use_testnet,
        )
        return {"success": True, "message": "账号密钥已更新(建议重启该账号交易进程)"}
    except HTTPException:
        raise
    except Exception as e:
        # Never log the payload itself: it contains the API secret.
        logger.exception("Failed to update credentials for account %s", account_id)
        raise HTTPException(status_code=500, detail=f"更新账号密钥失败: {e}") from e
|
||
|