a
This commit is contained in:
parent
32170b3b0a
commit
c535a7b1ae
|
|
@ -9,8 +9,10 @@ from binance.exceptions import BinanceAPIException
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from . import config
|
from . import config
|
||||||
|
from .redis_cache import RedisCache
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import config
|
import config
|
||||||
|
from redis_cache import RedisCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -39,6 +41,20 @@ class BinanceClient:
|
||||||
self._price_cache: Dict[str, Dict] = {} # WebSocket价格缓存 {symbol: {price, volume, changePercent, timestamp}}
|
self._price_cache: Dict[str, Dict] = {} # WebSocket价格缓存 {symbol: {price, volume, changePercent, timestamp}}
|
||||||
self._price_cache_ttl = 60 # 价格缓存有效期(秒)
|
self._price_cache_ttl = 60 # 价格缓存有效期(秒)
|
||||||
|
|
||||||
|
# 初始化 Redis 缓存
|
||||||
|
self.redis_cache = RedisCache(
|
||||||
|
redis_url=config.REDIS_URL,
|
||||||
|
use_tls=config.REDIS_USE_TLS,
|
||||||
|
ssl_cert_reqs=config.REDIS_SSL_CERT_REQS,
|
||||||
|
<<<<<<< Current (Your changes)
|
||||||
|
ssl_ca_certs=config.REDIS_SSL_CA_CERTS
|
||||||
|
=======
|
||||||
|
ssl_ca_certs=config.REDIS_SSL_CA_CERTS,
|
||||||
|
username=config.REDIS_USERNAME,
|
||||||
|
password=config.REDIS_PASSWORD
|
||||||
|
>>>>>>> Incoming (Background Agent changes)
|
||||||
|
)
|
||||||
|
|
||||||
async def connect(self, timeout: int = None, retries: int = None):
|
async def connect(self, timeout: int = None, retries: int = None):
|
||||||
"""
|
"""
|
||||||
连接币安API
|
连接币安API
|
||||||
|
|
@ -75,6 +91,9 @@ class BinanceClient:
|
||||||
self.socket_manager = BinanceSocketManager(self.client)
|
self.socket_manager = BinanceSocketManager(self.client)
|
||||||
logger.info(f"✓ 币安客户端连接成功 (测试网: {self.testnet})")
|
logger.info(f"✓ 币安客户端连接成功 (测试网: {self.testnet})")
|
||||||
|
|
||||||
|
# 连接 Redis 缓存
|
||||||
|
await self.redis_cache.connect()
|
||||||
|
|
||||||
# 验证API密钥权限
|
# 验证API密钥权限
|
||||||
await self._verify_api_permissions()
|
await self._verify_api_permissions()
|
||||||
|
|
||||||
|
|
@ -129,6 +148,9 @@ class BinanceClient:
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""断开连接"""
|
"""断开连接"""
|
||||||
|
|
||||||
|
# 关闭 Redis 连接
|
||||||
|
await self.redis_cache.close()
|
||||||
|
|
||||||
if self.client:
|
if self.client:
|
||||||
await self.client.close_connection()
|
await self.client.close_connection()
|
||||||
logger.info("币安客户端已断开连接")
|
logger.info("币安客户端已断开连接")
|
||||||
|
|
@ -177,6 +199,7 @@ class BinanceClient:
|
||||||
async def get_klines(self, symbol: str, interval: str = '5m', limit: int = 2) -> List[List]:
|
async def get_klines(self, symbol: str, interval: str = '5m', limit: int = 2) -> List[List]:
|
||||||
"""
|
"""
|
||||||
获取K线数据(合约市场)
|
获取K线数据(合约市场)
|
||||||
|
优先从 Redis 缓存读取,如果缓存不可用或过期则使用 REST API
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
symbol: 交易对
|
symbol: 交易对
|
||||||
|
|
@ -186,11 +209,41 @@ class BinanceClient:
|
||||||
Returns:
|
Returns:
|
||||||
K线数据列表
|
K线数据列表
|
||||||
"""
|
"""
|
||||||
|
# 先查 Redis 缓存
|
||||||
|
cache_key = f"klines:{symbol}:{interval}:{limit}"
|
||||||
|
cached = await self.redis_cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"从缓存获取 {symbol} K线数据: {interval} x{limit}")
|
||||||
|
return cached
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 缓存未命中,调用 API
|
||||||
klines = await self._rate_limited_request(
|
klines = await self._rate_limited_request(
|
||||||
f'klines_{symbol}_{interval}',
|
f'klines_{symbol}_{interval}',
|
||||||
self.client.futures_klines(symbol=symbol, interval=interval, limit=limit)
|
self.client.futures_klines(symbol=symbol, interval=interval, limit=limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 写入 Redis 缓存(根据 interval 动态设置 TTL)
|
||||||
|
if klines:
|
||||||
|
# TTL 设置:1m=10秒, 5m=30秒, 15m=1分钟, 1h=5分钟, 4h=15分钟, 1d=1小时
|
||||||
|
ttl_map = {
|
||||||
|
'1m': 10,
|
||||||
|
'3m': 20,
|
||||||
|
'5m': 30,
|
||||||
|
'15m': 60,
|
||||||
|
'30m': 120,
|
||||||
|
'1h': 300,
|
||||||
|
'2h': 600,
|
||||||
|
'4h': 900,
|
||||||
|
'6h': 1200,
|
||||||
|
'8h': 1800,
|
||||||
|
'12h': 2400,
|
||||||
|
'1d': 3600
|
||||||
|
}
|
||||||
|
ttl = ttl_map.get(interval, 300) # 默认 5 分钟
|
||||||
|
await self.redis_cache.set(cache_key, klines, ttl=ttl)
|
||||||
|
logger.debug(f"已缓存 {symbol} K线数据: {interval} x{limit} (TTL: {ttl}秒)")
|
||||||
|
|
||||||
return klines
|
return klines
|
||||||
except BinanceAPIException as e:
|
except BinanceAPIException as e:
|
||||||
error_code = e.code if hasattr(e, 'code') else None
|
error_code = e.code if hasattr(e, 'code') else None
|
||||||
|
|
@ -203,7 +256,7 @@ class BinanceClient:
|
||||||
async def get_ticker_24h(self, symbol: str) -> Optional[Dict]:
|
async def get_ticker_24h(self, symbol: str) -> Optional[Dict]:
|
||||||
"""
|
"""
|
||||||
获取24小时行情数据(合约市场)
|
获取24小时行情数据(合约市场)
|
||||||
优先从WebSocket缓存读取,如果缓存不可用或过期则使用REST API
|
优先从WebSocket缓存读取,其次从Redis缓存读取,最后使用REST API
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
symbol: 交易对
|
symbol: 交易对
|
||||||
|
|
@ -213,12 +266,12 @@ class BinanceClient:
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
# 优先从WebSocket缓存读取
|
# 1. 优先从WebSocket缓存读取
|
||||||
if symbol in self._price_cache:
|
if symbol in self._price_cache:
|
||||||
cached = self._price_cache[symbol]
|
cached = self._price_cache[symbol]
|
||||||
cache_age = time.time() - cached.get('timestamp', 0)
|
cache_age = time.time() - cached.get('timestamp', 0)
|
||||||
if cache_age < self._price_cache_ttl:
|
if cache_age < self._price_cache_ttl:
|
||||||
logger.debug(f"从缓存获取 {symbol} 价格: {cached['price']:.8f} (缓存年龄: {cache_age:.1f}秒)")
|
logger.debug(f"从WebSocket缓存获取 {symbol} 价格: {cached['price']:.8f} (缓存年龄: {cache_age:.1f}秒)")
|
||||||
return {
|
return {
|
||||||
'symbol': symbol,
|
'symbol': symbol,
|
||||||
'price': cached['price'],
|
'price': cached['price'],
|
||||||
|
|
@ -226,10 +279,17 @@ class BinanceClient:
|
||||||
'changePercent': cached.get('changePercent', 0)
|
'changePercent': cached.get('changePercent', 0)
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{symbol} 缓存已过期 ({cache_age:.1f}秒 > {self._price_cache_ttl}秒),使用REST API")
|
logger.debug(f"{symbol} WebSocket缓存已过期 ({cache_age:.1f}秒 > {self._price_cache_ttl}秒)")
|
||||||
|
|
||||||
# 如果缓存不可用或过期,使用REST API(fallback)
|
# 2. 从 Redis 缓存读取
|
||||||
logger.debug(f"{symbol} 未在价格缓存中,使用REST API获取")
|
cache_key = f"ticker_24h:{symbol}"
|
||||||
|
cached = await self.redis_cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"从Redis缓存获取 {symbol} 24小时行情数据")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# 3. 如果缓存不可用或过期,使用REST API(fallback)
|
||||||
|
logger.debug(f"{symbol} 未在缓存中,使用REST API获取")
|
||||||
try:
|
try:
|
||||||
ticker = await self._rate_limited_request(
|
ticker = await self._rate_limited_request(
|
||||||
f'ticker_{symbol}',
|
f'ticker_{symbol}',
|
||||||
|
|
@ -245,12 +305,16 @@ class BinanceClient:
|
||||||
'volume': float(stats.get('quoteVolume', 0)),
|
'volume': float(stats.get('quoteVolume', 0)),
|
||||||
'changePercent': float(stats.get('priceChangePercent', 0))
|
'changePercent': float(stats.get('priceChangePercent', 0))
|
||||||
}
|
}
|
||||||
# 更新缓存
|
|
||||||
import time
|
# 更新 WebSocket 缓存
|
||||||
self._price_cache[symbol] = {
|
self._price_cache[symbol] = {
|
||||||
**result,
|
**result,
|
||||||
'timestamp': time.time()
|
'timestamp': time.time()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 写入 Redis 缓存(TTL: 30秒)
|
||||||
|
await self.redis_cache.set(cache_key, result, ttl=30)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except BinanceAPIException as e:
|
except BinanceAPIException as e:
|
||||||
error_code = e.code if hasattr(e, 'code') else None
|
error_code = e.code if hasattr(e, 'code') else None
|
||||||
|
|
@ -263,10 +327,18 @@ class BinanceClient:
|
||||||
async def get_all_tickers_24h(self) -> Dict[str, Dict]:
|
async def get_all_tickers_24h(self) -> Dict[str, Dict]:
|
||||||
"""
|
"""
|
||||||
批量获取所有交易对的24小时行情数据(更高效)
|
批量获取所有交易对的24小时行情数据(更高效)
|
||||||
|
优先从 Redis 缓存读取,如果缓存不可用或过期则使用 REST API
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
交易对行情数据字典 {symbol: {price, volume, changePercent}}
|
交易对行情数据字典 {symbol: {price, volume, changePercent}}
|
||||||
"""
|
"""
|
||||||
|
# 先查 Redis 缓存
|
||||||
|
cache_key = "ticker_24h:all"
|
||||||
|
cached = await self.redis_cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"从Redis缓存获取所有交易对的24小时行情数据: {len(cached)} 个交易对")
|
||||||
|
return cached
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用批量API,一次获取所有交易对的数据
|
# 使用批量API,一次获取所有交易对的数据
|
||||||
tickers = await self._rate_limited_request(
|
tickers = await self._rate_limited_request(
|
||||||
|
|
@ -285,7 +357,9 @@ class BinanceClient:
|
||||||
'changePercent': float(ticker.get('priceChangePercent', 0))
|
'changePercent': float(ticker.get('priceChangePercent', 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"批量获取到 {len(result)} 个交易对的24小时行情数据")
|
# 写入 Redis 缓存(TTL: 30秒)
|
||||||
|
await self.redis_cache.set(cache_key, result, ttl=30)
|
||||||
|
logger.debug(f"批量获取到 {len(result)} 个交易对的24小时行情数据,已缓存")
|
||||||
return result
|
return result
|
||||||
except BinanceAPIException as e:
|
except BinanceAPIException as e:
|
||||||
error_code = e.code if hasattr(e, 'code') else None
|
error_code = e.code if hasattr(e, 'code') else None
|
||||||
|
|
@ -375,6 +449,7 @@ class BinanceClient:
|
||||||
async def get_symbol_info(self, symbol: str) -> Optional[Dict]:
|
async def get_symbol_info(self, symbol: str) -> Optional[Dict]:
|
||||||
"""
|
"""
|
||||||
获取交易对的精度和限制信息
|
获取交易对的精度和限制信息
|
||||||
|
优先从 Redis 缓存读取,如果缓存不可用或过期则使用 REST API
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
symbol: 交易对
|
symbol: 交易对
|
||||||
|
|
@ -382,10 +457,20 @@ class BinanceClient:
|
||||||
Returns:
|
Returns:
|
||||||
交易对信息字典,包含 quantityPrecision, minQty, stepSize 等
|
交易对信息字典,包含 quantityPrecision, minQty, stepSize 等
|
||||||
"""
|
"""
|
||||||
# 先检查缓存
|
# 1. 先检查内存缓存
|
||||||
if symbol in self._symbol_info_cache:
|
if symbol in self._symbol_info_cache:
|
||||||
return self._symbol_info_cache[symbol]
|
return self._symbol_info_cache[symbol]
|
||||||
|
|
||||||
|
# 2. 从 Redis 缓存读取
|
||||||
|
cache_key = f"symbol_info:{symbol}"
|
||||||
|
cached = await self.redis_cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"从Redis缓存获取 {symbol} 交易对信息")
|
||||||
|
# 同时更新内存缓存
|
||||||
|
self._symbol_info_cache[symbol] = cached
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# 3. 缓存未命中,调用 API
|
||||||
try:
|
try:
|
||||||
exchange_info = await self.client.futures_exchange_info()
|
exchange_info = await self.client.futures_exchange_info()
|
||||||
for s in exchange_info['symbols']:
|
for s in exchange_info['symbols']:
|
||||||
|
|
@ -415,7 +500,10 @@ class BinanceClient:
|
||||||
'minNotional': min_notional
|
'minNotional': min_notional
|
||||||
}
|
}
|
||||||
|
|
||||||
# 缓存信息
|
# 写入 Redis 缓存(TTL: 1小时)
|
||||||
|
await self.redis_cache.set(cache_key, info, ttl=3600)
|
||||||
|
|
||||||
|
# 同时更新内存缓存
|
||||||
self._symbol_info_cache[symbol] = info
|
self._symbol_info_cache[symbol] = info
|
||||||
logger.debug(f"获取 {symbol} 精度信息: {info}")
|
logger.debug(f"获取 {symbol} 精度信息: {info}")
|
||||||
return info
|
return info
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,31 @@
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 加载 .env 文件(优先从 trading_system/.env,其次从项目根目录/.env)
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
trading_system_dir = Path(__file__).parent
|
||||||
|
project_root = trading_system_dir.parent
|
||||||
|
env_files = [
|
||||||
|
trading_system_dir / '.env', # trading_system/.env
|
||||||
|
project_root / '.env', # 项目根目录/.env
|
||||||
|
]
|
||||||
|
for env_file in env_files:
|
||||||
|
if env_file.exists():
|
||||||
|
load_dotenv(env_file, override=True)
|
||||||
|
print(f"[config.py] 已加载 .env 文件: {env_file}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 如果都不存在,尝试加载但不报错
|
||||||
|
load_dotenv(project_root / '.env', override=False)
|
||||||
|
except ImportError:
|
||||||
|
# python-dotenv 未安装时忽略
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
# 加载 .env 文件失败时忽略,不影响程序运行
|
||||||
|
print(f"[config.py] 加载 .env 文件时出错(可忽略): {e}")
|
||||||
|
|
||||||
# 尝试从数据库加载配置
|
# 尝试从数据库加载配置
|
||||||
USE_DB_CONFIG = False
|
USE_DB_CONFIG = False
|
||||||
|
|
@ -183,6 +208,7 @@ for key, value in defaults.items():
|
||||||
def reload_config():
|
def reload_config():
|
||||||
"""重新加载配置(供外部调用)"""
|
"""重新加载配置(供外部调用)"""
|
||||||
global TRADING_CONFIG, BINANCE_API_KEY, BINANCE_API_SECRET, USE_TESTNET, _config_manager
|
global TRADING_CONFIG, BINANCE_API_KEY, BINANCE_API_SECRET, USE_TESTNET, _config_manager
|
||||||
|
global REDIS_URL, REDIS_USE_TLS, REDIS_SSL_CERT_REQS, REDIS_SSL_CA_CERTS, REDIS_USERNAME, REDIS_PASSWORD
|
||||||
_init_config_manager() # 重新初始化配置管理器
|
_init_config_manager() # 重新初始化配置管理器
|
||||||
if _config_manager:
|
if _config_manager:
|
||||||
_config_manager.reload()
|
_config_manager.reload()
|
||||||
|
|
@ -190,6 +216,13 @@ def reload_config():
|
||||||
BINANCE_API_SECRET = _get_config_value('BINANCE_API_SECRET', BINANCE_API_SECRET)
|
BINANCE_API_SECRET = _get_config_value('BINANCE_API_SECRET', BINANCE_API_SECRET)
|
||||||
USE_TESTNET = _get_config_value('USE_TESTNET', False) if _get_config_value('USE_TESTNET') is not None else os.getenv('USE_TESTNET', 'False').lower() == 'true'
|
USE_TESTNET = _get_config_value('USE_TESTNET', False) if _get_config_value('USE_TESTNET') is not None else os.getenv('USE_TESTNET', 'False').lower() == 'true'
|
||||||
TRADING_CONFIG = _get_trading_config()
|
TRADING_CONFIG = _get_trading_config()
|
||||||
|
# 重新加载 Redis 配置
|
||||||
|
REDIS_URL = _get_config_value('REDIS_URL', os.getenv('REDIS_URL', REDIS_URL))
|
||||||
|
REDIS_USE_TLS = _get_config_value('REDIS_USE_TLS', False) if _get_config_value('REDIS_USE_TLS') is not None else os.getenv('REDIS_USE_TLS', 'False').lower() == 'true'
|
||||||
|
REDIS_SSL_CERT_REQS = _get_config_value('REDIS_SSL_CERT_REQS', REDIS_SSL_CERT_REQS)
|
||||||
|
REDIS_SSL_CA_CERTS = _get_config_value('REDIS_SSL_CA_CERTS', REDIS_SSL_CA_CERTS)
|
||||||
|
REDIS_USERNAME = _get_config_value('REDIS_USERNAME', os.getenv('REDIS_USERNAME', REDIS_USERNAME))
|
||||||
|
REDIS_PASSWORD = _get_config_value('REDIS_PASSWORD', os.getenv('REDIS_PASSWORD', REDIS_PASSWORD))
|
||||||
# 确保默认值
|
# 确保默认值
|
||||||
for key, value in defaults.items():
|
for key, value in defaults.items():
|
||||||
if key not in TRADING_CONFIG:
|
if key not in TRADING_CONFIG:
|
||||||
|
|
@ -199,6 +232,14 @@ def reload_config():
|
||||||
CONNECTION_TIMEOUT = int(os.getenv('CONNECTION_TIMEOUT', '30')) # 连接超时时间(秒)
|
CONNECTION_TIMEOUT = int(os.getenv('CONNECTION_TIMEOUT', '30')) # 连接超时时间(秒)
|
||||||
CONNECTION_RETRIES = int(os.getenv('CONNECTION_RETRIES', '3')) # 连接重试次数
|
CONNECTION_RETRIES = int(os.getenv('CONNECTION_RETRIES', '3')) # 连接重试次数
|
||||||
|
|
||||||
|
# Redis 缓存配置(优先从数据库,回退到环境变量和默认值)
|
||||||
|
REDIS_URL = _get_config_value('REDIS_URL', os.getenv('REDIS_URL', 'redis://localhost:6379'))
|
||||||
|
REDIS_USE_TLS = _get_config_value('REDIS_USE_TLS', False) if _get_config_value('REDIS_USE_TLS') is not None else os.getenv('REDIS_USE_TLS', 'False').lower() == 'true'
|
||||||
|
REDIS_SSL_CERT_REQS = _get_config_value('REDIS_SSL_CERT_REQS', 'required')
|
||||||
|
REDIS_SSL_CA_CERTS = _get_config_value('REDIS_SSL_CA_CERTS', None)
|
||||||
|
REDIS_USERNAME = _get_config_value('REDIS_USERNAME', os.getenv('REDIS_USERNAME', None))
|
||||||
|
REDIS_PASSWORD = _get_config_value('REDIS_PASSWORD', os.getenv('REDIS_PASSWORD', None))
|
||||||
|
|
||||||
# 日志配置
|
# 日志配置
|
||||||
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
|
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
|
||||||
LOG_FILE = 'trading_bot.log'
|
LOG_FILE = 'trading_bot.log'
|
||||||
|
|
|
||||||
|
|
@ -32,12 +32,22 @@ class MarketScanner:
|
||||||
async def scan_market(self) -> List[Dict]:
|
async def scan_market(self) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
扫描市场,找出涨跌幅最大的前N个货币对
|
扫描市场,找出涨跌幅最大的前N个货币对
|
||||||
|
优先从 Redis 缓存读取扫描结果,如果缓存不可用或过期则重新扫描
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
前N个货币对列表,包含涨跌幅信息
|
前N个货币对列表,包含涨跌幅信息
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
self._scan_start_time = time.time()
|
self._scan_start_time = time.time()
|
||||||
|
|
||||||
|
# 先查 Redis 缓存(扫描结果缓存,TTL: 30秒)
|
||||||
|
cache_key = "scan_result:top_symbols"
|
||||||
|
cached = await self.client.redis_cache.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
logger.info(f"从Redis缓存获取扫描结果: {len(cached)} 个交易对")
|
||||||
|
self.top_symbols = cached
|
||||||
|
return cached
|
||||||
|
|
||||||
logger.info("开始扫描市场...")
|
logger.info("开始扫描市场...")
|
||||||
|
|
||||||
# 获取所有USDT交易对
|
# 获取所有USDT交易对
|
||||||
|
|
@ -111,6 +121,10 @@ class MarketScanner:
|
||||||
|
|
||||||
self.top_symbols = top_n
|
self.top_symbols = top_n
|
||||||
|
|
||||||
|
# 写入 Redis 缓存(TTL: 30秒)
|
||||||
|
await self.client.redis_cache.set(cache_key, top_n, ttl=30)
|
||||||
|
logger.debug(f"扫描结果已缓存: {len(top_n)} 个交易对 (TTL: 30秒)")
|
||||||
|
|
||||||
# 记录扫描结果到数据库
|
# 记录扫描结果到数据库
|
||||||
try:
|
try:
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
221
trading_system/redis_cache.py
Normal file
221
trading_system/redis_cache.py
Normal file
|
|
@ -0,0 +1,221 @@
|
||||||
|
"""
|
||||||
|
Redis 缓存管理器 - 支持 TLS 连接
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Any, Dict, List
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aioredis
|
||||||
|
from aioredis import Redis
|
||||||
|
AIOREDIS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
AIOREDIS_AVAILABLE = False
|
||||||
|
Redis = None
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisCache:
|
||||||
|
"""Redis 缓存管理器 - 支持 TLS 连接和降级到内存缓存"""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str = None, use_tls: bool = False,
|
||||||
|
ssl_cert_reqs: str = 'required', ssl_ca_certs: str = None,
|
||||||
|
username: str = None, password: str = None):
|
||||||
|
"""
|
||||||
|
初始化 Redis 缓存管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis_url: Redis 连接 URL(例如: redis://localhost:6379 或 rediss://localhost:6380)
|
||||||
|
如果 URL 中包含用户名和密码,会优先使用 URL 中的认证信息
|
||||||
|
use_tls: 是否使用 TLS(如果 redis_url 以 rediss:// 开头,会自动启用)
|
||||||
|
ssl_cert_reqs: SSL 证书验证要求 ('none', 'optional', 'required')
|
||||||
|
ssl_ca_certs: SSL CA 证书路径
|
||||||
|
username: Redis 用户名(如果 URL 中未包含)
|
||||||
|
password: Redis 密码(如果 URL 中未包含)
|
||||||
|
"""
|
||||||
|
self.redis_url = redis_url or "redis://localhost:6379"
|
||||||
|
self.use_tls = use_tls or self.redis_url.startswith('rediss://')
|
||||||
|
self.ssl_cert_reqs = ssl_cert_reqs
|
||||||
|
self.ssl_ca_certs = ssl_ca_certs
|
||||||
|
self.username = username
|
||||||
|
self.password = password
|
||||||
|
self.redis: Optional[Redis] = None
|
||||||
|
self._memory_cache: Dict[str, Any] = {} # 降级到内存缓存
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""连接 Redis"""
|
||||||
|
if not AIOREDIS_AVAILABLE:
|
||||||
|
logger.warning("aioredis 未安装,将使用内存缓存")
|
||||||
|
self.redis = None
|
||||||
|
self._connected = False
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建连接参数
|
||||||
|
connection_kwargs = {}
|
||||||
|
|
||||||
|
# 如果使用 TLS
|
||||||
|
if self.use_tls or self.redis_url.startswith('rediss://'):
|
||||||
|
# 配置 SSL 上下文
|
||||||
|
ssl_context = ssl.create_default_context()
|
||||||
|
|
||||||
|
# 设置证书验证要求
|
||||||
|
if self.ssl_cert_reqs == 'none':
|
||||||
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
|
elif self.ssl_cert_reqs == 'optional':
|
||||||
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_OPTIONAL
|
||||||
|
else: # required
|
||||||
|
ssl_context.check_hostname = True
|
||||||
|
ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
|
||||||
|
# 如果提供了 CA 证书路径
|
||||||
|
if self.ssl_ca_certs:
|
||||||
|
ssl_context.load_verify_locations(self.ssl_ca_certs)
|
||||||
|
|
||||||
|
connection_kwargs['ssl'] = ssl_context
|
||||||
|
logger.info(f"使用 TLS 连接 Redis: {self.redis_url}")
|
||||||
|
|
||||||
|
# 如果 URL 中不包含用户名和密码,且提供了独立的用户名和密码参数,则添加到连接参数中
|
||||||
|
# 注意:如果 URL 中已经包含认证信息(如 redis://user:pass@host:port),则优先使用 URL 中的
|
||||||
|
if self.username and self.password:
|
||||||
|
# 检查 URL 中是否已包含认证信息(格式:redis://user:pass@host:port)
|
||||||
|
url_parts = self.redis_url.split('://')
|
||||||
|
if len(url_parts) == 2:
|
||||||
|
url_after_scheme = url_parts[1]
|
||||||
|
# 如果 URL 中不包含 @ 符号,说明没有在 URL 中指定认证信息
|
||||||
|
if '@' not in url_after_scheme:
|
||||||
|
# URL 中不包含认证信息,使用独立的用户名和密码参数
|
||||||
|
connection_kwargs['username'] = self.username
|
||||||
|
connection_kwargs['password'] = self.password
|
||||||
|
logger.info(f"使用独立的用户名和密码进行认证: {self.username}")
|
||||||
|
else:
|
||||||
|
logger.info("URL 中已包含认证信息,优先使用 URL 中的认证信息")
|
||||||
|
else:
|
||||||
|
# URL 格式异常,尝试使用独立的用户名和密码
|
||||||
|
connection_kwargs['username'] = self.username
|
||||||
|
connection_kwargs['password'] = self.password
|
||||||
|
logger.info(f"URL 格式异常,使用独立的用户名和密码进行认证: {self.username}")
|
||||||
|
|
||||||
|
# 创建 Redis 连接
|
||||||
|
self.redis = await aioredis.from_url(
|
||||||
|
self.redis_url,
|
||||||
|
**connection_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试连接
|
||||||
|
await self.redis.ping()
|
||||||
|
self._connected = True
|
||||||
|
logger.info(f"✓ Redis 连接成功: {self.redis_url}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Redis 连接失败: {e},将使用内存缓存")
|
||||||
|
self.redis = None
|
||||||
|
self._connected = False
|
||||||
|
if self.redis:
|
||||||
|
try:
|
||||||
|
await self.redis.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.redis = None
|
||||||
|
|
||||||
|
async def get(self, key: str) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
获取缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 缓存键
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存值,如果不存在则返回 None
|
||||||
|
"""
|
||||||
|
# 先尝试从 Redis 获取
|
||||||
|
if self.redis and self._connected:
|
||||||
|
try:
|
||||||
|
data = await self.redis.get(key)
|
||||||
|
if data:
|
||||||
|
return json.loads(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Redis 获取失败 {key}: {e}")
|
||||||
|
# Redis 失败时,尝试重新连接
|
||||||
|
if not self._connected:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
# 降级到内存缓存
|
||||||
|
if key in self._memory_cache:
|
||||||
|
return self._memory_cache[key]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set(self, key: str, value: Any, ttl: int = 3600):
|
||||||
|
"""
|
||||||
|
设置缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 缓存键
|
||||||
|
value: 缓存值
|
||||||
|
ttl: 过期时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
# 先尝试写入 Redis
|
||||||
|
if self.redis and self._connected:
|
||||||
|
try:
|
||||||
|
await self.redis.setex(key, ttl, json.dumps(value))
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Redis 设置失败 {key}: {e}")
|
||||||
|
# Redis 失败时,尝试重新连接
|
||||||
|
if not self._connected:
|
||||||
|
await self.connect()
|
||||||
|
if self.redis and self._connected:
|
||||||
|
try:
|
||||||
|
await self.redis.setex(key, ttl, json.dumps(value))
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 降级到内存缓存(不设置 TTL,因为内存缓存不支持)
|
||||||
|
self._memory_cache[key] = value
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def delete(self, key: str):
|
||||||
|
"""删除缓存"""
|
||||||
|
if self.redis and self._connected:
|
||||||
|
try:
|
||||||
|
await self.redis.delete(key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Redis 删除失败 {key}: {e}")
|
||||||
|
|
||||||
|
# 同时删除内存缓存
|
||||||
|
if key in self._memory_cache:
|
||||||
|
del self._memory_cache[key]
|
||||||
|
|
||||||
|
async def exists(self, key: str) -> bool:
|
||||||
|
"""检查缓存是否存在"""
|
||||||
|
if self.redis and self._connected:
|
||||||
|
try:
|
||||||
|
return await self.redis.exists(key) > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Redis 检查失败 {key}: {e}")
|
||||||
|
|
||||||
|
return key in self._memory_cache
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""关闭连接"""
|
||||||
|
if self.redis:
|
||||||
|
try:
|
||||||
|
await self.redis.close()
|
||||||
|
self._connected = False
|
||||||
|
logger.info("Redis 连接已关闭")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"关闭 Redis 连接时出错: {e}")
|
||||||
|
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""检查是否已连接"""
|
||||||
|
return self._connected and self.redis is not None
|
||||||
|
|
@ -6,3 +6,6 @@ aiohttp==3.9.1
|
||||||
# 数据库依赖(用于从数据库读取配置)
|
# 数据库依赖(用于从数据库读取配置)
|
||||||
pymysql==1.1.0
|
pymysql==1.1.0
|
||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
|
|
||||||
|
# Redis 缓存依赖
|
||||||
|
aioredis==2.0.1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user