""" FastAPI backend:将日志写入 Redis List(按 error / warning / info 分组,仅保留最近 N 条)。 实现与 trading_system/redis_log_handler.py 保持一致(避免跨目录导入带来的 PYTHONPATH 问题)。 """ from __future__ import annotations import json import logging import os import socket import time import traceback from dataclasses import dataclass from datetime import datetime, timezone, timedelta from typing import Any, Dict, Optional, Literal def _beijing_time_str(ts: float) -> str: beijing_tz = timezone(timedelta(hours=8)) return datetime.fromtimestamp(ts, tz=beijing_tz).strftime("%Y-%m-%d %H:%M:%S") def _beijing_yyyymmdd(ts: Optional[float] = None) -> str: beijing_tz = timezone(timedelta(hours=8)) dt = datetime.fromtimestamp(ts or time.time(), tz=beijing_tz) return dt.strftime("%Y%m%d") def _safe_json_loads(s: str) -> Optional[Dict[str, Any]]: try: obj = json.loads(s) if isinstance(obj, dict): return obj except Exception: return None return None LogGroup = Literal["error", "warning", "info"] def _parse_bool(v: Any, default: bool) -> bool: if v is None: return default if isinstance(v, bool): return v s = str(v).strip().lower() if s in ("1", "true", "yes", "y", "on"): return True if s in ("0", "false", "no", "n", "off"): return False return default def _parse_int(v: Any, default: int) -> int: try: n = int(str(v).strip()) return n except Exception: return default @dataclass(frozen=True) class RedisLogConfig: redis_url: str list_key_prefix: str = "ats:logs" config_key: str = "ats:logs:config" stats_key_prefix: str = "ats:logs:stats:added" max_len_error: int = 2000 max_len_warning: int = 2000 max_len_info: int = 2000 dedupe_consecutive: bool = True enable_error: bool = True enable_warning: bool = True enable_info: bool = True include_debug_in_info: bool = False config_refresh_sec: float = 5.0 service: str = "backend" hostname: str = socket.gethostname() connect_timeout_sec: float = 1.0 socket_timeout_sec: float = 1.0 username: Optional[str] = None password: Optional[str] = None use_tls: bool = False ssl_cert_reqs: str = "required" ssl_ca_certs: Optional[str] = None class RedisErrorLogHandler(logging.Handler): def __init__(self, cfg: RedisLogConfig): super().__init__() self.cfg = cfg self._redis = None self._redis_ok = False self._last_connect_attempt_ts = 0.0 self._last_cfg_refresh_ts = 0.0 self._remote_cfg: Dict[str, Any] = {} def _connection_kwargs(self) -> Dict[str, Any]: kwargs: Dict[str, Any] = { "decode_responses": True, "socket_connect_timeout": self.cfg.connect_timeout_sec, "socket_timeout": self.cfg.socket_timeout_sec, } if self.cfg.username: kwargs["username"] = self.cfg.username if self.cfg.password: kwargs["password"] = self.cfg.password if self.cfg.redis_url.startswith("rediss://") or self.cfg.use_tls: kwargs["ssl_cert_reqs"] = self.cfg.ssl_cert_reqs if self.cfg.ssl_ca_certs: kwargs["ssl_ca_certs"] = self.cfg.ssl_ca_certs if self.cfg.ssl_cert_reqs == "none": kwargs["ssl_check_hostname"] = False elif self.cfg.ssl_cert_reqs == "required": kwargs["ssl_check_hostname"] = True else: kwargs["ssl_check_hostname"] = False return kwargs def _get_redis(self): now = time.time() if self._redis_ok and self._redis is not None: return self._redis if now - self._last_connect_attempt_ts < 5: return None self._last_connect_attempt_ts = now try: import redis # type: ignore except Exception: self._redis = None self._redis_ok = False return None try: client = redis.from_url(self.cfg.redis_url, **self._connection_kwargs()) client.ping() self._redis = client self._redis_ok = True return self._redis except Exception: self._redis = None self._redis_ok = False return None def _build_entry(self, record: logging.LogRecord) -> Dict[str, Any]: msg = record.getMessage() exc_text = None exc_type = None if record.exc_info: exc_type = getattr(record.exc_info[0], "__name__", None) exc_text = "".join(traceback.format_exception(*record.exc_info)) signature = f"{self.cfg.service}|{record.levelname}|{record.name}|{record.pathname}:{record.lineno}|{msg}|{exc_type or ''}" return { "ts": int(record.created * 1000), "time": _beijing_time_str(record.created), "service": self.cfg.service, "level": record.levelname, "logger": record.name, "message": msg, "pathname": record.pathname, "lineno": record.lineno, "funcName": record.funcName, "process": record.process, "thread": record.thread, "hostname": self.cfg.hostname, "exc_type": exc_type, "exc_text": exc_text, "signature": signature, "count": 1, } def _effective_cfg_bool(self, key: str, default: bool) -> bool: if key in self._remote_cfg: return _parse_bool(self._remote_cfg.get(key), default) return default def _refresh_remote_config_if_needed(self, client) -> None: now = time.time() if now - self._last_cfg_refresh_ts < self.cfg.config_refresh_sec: return self._last_cfg_refresh_ts = now try: cfg_key = os.getenv("REDIS_LOG_CONFIG_KEY", self.cfg.config_key).strip() or self.cfg.config_key data = client.hgetall(cfg_key) or {} normalized: Dict[str, Any] = {} for k, v in data.items(): if not k: continue normalized[str(k).strip()] = v self._remote_cfg = normalized except Exception: return def _group_for_record(self, record: logging.LogRecord) -> Optional[LogGroup]: if record.levelno >= logging.ERROR: return "error" if record.levelno >= logging.WARNING: return "warning" if record.levelno == logging.INFO: return "info" if record.levelno == logging.DEBUG and self._effective_cfg_bool("include_debug_in_info", self.cfg.include_debug_in_info): return "info" return None def _list_key_for_group(self, group: LogGroup) -> str: if group == "error": legacy = os.getenv("REDIS_LOG_LIST_KEY", "").strip() if legacy: return legacy env_key = os.getenv(f"REDIS_LOG_LIST_KEY_{group.upper()}", "").strip() if env_key: return env_key prefix = os.getenv("REDIS_LOG_LIST_PREFIX", self.cfg.list_key_prefix).strip() or self.cfg.list_key_prefix return f"{prefix}:{group}" def _max_len_for_group(self, group: LogGroup) -> int: env_specific = os.getenv(f"REDIS_LOG_LIST_MAX_LEN_{group.upper()}", "").strip() if env_specific: n = _parse_int(env_specific, 0) return n if n > 0 else (self.cfg.max_len_error if group == "error" else self.cfg.max_len_warning if group == "warning" else self.cfg.max_len_info) env_global = os.getenv("REDIS_LOG_LIST_MAX_LEN", "").strip() if env_global: n = _parse_int(env_global, 0) if n > 0: return n field = f"max_len:{group}" if field in self._remote_cfg: n = _parse_int(self._remote_cfg.get(field), 0) if n > 0: return n return self.cfg.max_len_error if group == "error" else self.cfg.max_len_warning if group == "warning" else self.cfg.max_len_info def _enabled_for_group(self, group: LogGroup) -> bool: field = f"enabled:{group}" if field in self._remote_cfg: return _parse_bool(self._remote_cfg.get(field), True) return self.cfg.enable_error if group == "error" else self.cfg.enable_warning if group == "warning" else self.cfg.enable_info def _dedupe_consecutive_enabled(self) -> bool: if "dedupe_consecutive" in self._remote_cfg: return _parse_bool(self._remote_cfg.get("dedupe_consecutive"), self.cfg.dedupe_consecutive) return self.cfg.dedupe_consecutive def _stats_key(self, group: LogGroup) -> str: prefix = os.getenv("REDIS_LOG_STATS_PREFIX", self.cfg.stats_key_prefix).strip() or self.cfg.stats_key_prefix day = _beijing_yyyymmdd() return f"{prefix}:{day}:{group}" def emit(self, record: logging.LogRecord) -> None: try: client = self._get_redis() if client is None: return self._refresh_remote_config_if_needed(client) group = self._group_for_record(record) if group is None: return if not self._enabled_for_group(group): return entry = self._build_entry(record) list_key = self._list_key_for_group(group) max_len = self._max_len_for_group(group) stats_key = self._stats_key(group) if self._dedupe_consecutive_enabled(): try: head_raw = client.lindex(list_key, 0) except Exception: head_raw = None if isinstance(head_raw, str): head = _safe_json_loads(head_raw) else: head = None if head and head.get("signature") == entry["signature"]: head["count"] = int(head.get("count", 1)) + 1 head["ts"] = entry["ts"] head["time"] = entry["time"] if entry.get("exc_text"): head["exc_text"] = entry.get("exc_text") head["exc_type"] = entry.get("exc_type") try: # 集群模式下禁用 transaction,避免 CROSSSLOT(list_key 与 stats_key 不同 slot) pipe = client.pipeline(transaction=False) pipe.lset(list_key, 0, json.dumps(head, ensure_ascii=False)) pipe.ltrim(list_key, 0, max_len - 1) pipe.incr(stats_key, 1) pipe.expire(stats_key, 14 * 24 * 3600) pipe.execute() return except Exception: pass try: # 集群模式下禁用 transaction,避免 CROSSSLOT(list_key 与 stats_key 不同 slot) pipe = client.pipeline(transaction=False) pipe.lpush(list_key, json.dumps(entry, ensure_ascii=False)) pipe.ltrim(list_key, 0, max_len - 1) pipe.incr(stats_key, 1) pipe.expire(stats_key, 14 * 24 * 3600) pipe.execute() except Exception: return except Exception: return