"""
Unicorn WebSocket module - provides high-performance real-time data streams.
"""

import asyncio
import logging
from typing import Dict, List, Optional, Callable

from unicorn_binance_websocket_api.unicorn_binance_websocket_api_manager import BinanceWebSocketApiManager

import config

logger = logging.getLogger(__name__)

class UnicornWebSocketManager:
    """Binance futures WebSocket manager built on unicorn-binance-websocket-api.

    Each ``subscribe_*`` call opens ONE multiplexed stream for all symbols it
    is given, so several symbols may share a single stream_id. Incoming frames
    are therefore routed to callbacks by the symbol carried inside each frame,
    not by the stream buffer they were popped from.

    Typical use: ``start()``, ``subscribe_ticker(...)`` / ``subscribe_kline(...)``,
    run ``await process_stream_data()`` as a task, then ``stop()``.
    """

    # Max frames popped from one stream buffer per polling pass, so a single
    # busy stream cannot monopolise a pass.
    MAX_EVENTS_PER_POLL = 1000

    def __init__(self, testnet: bool = False):
        """Create an idle manager; call ``start()`` before subscribing.

        Args:
            testnet: Use the futures testnet. A true ``config.USE_TESTNET``
                also selects the testnet regardless of this flag.
        """
        self.testnet = testnet or config.USE_TESTNET
        self.manager: Optional[BinanceWebSocketApiManager] = None
        self.stream_ids: Dict[str, str] = {}  # symbol -> ticker stream_id (may be shared)
        self.price_callbacks: Dict[str, List[Callable]] = {}  # symbol -> ticker callbacks
        self.kline_streams: Dict[str, List[str]] = {}  # kline stream_id -> symbols on it
        self.kline_callbacks: Dict[str, List[Callable]] = {}  # symbol -> kline callbacks
        self.latest_prices: Dict[str, float] = {}  # symbol -> last ticker price seen
        self.running = False
        # True while process_stream_data() is the consumer of the stream buffers.
        self._processing = False

    def start(self) -> bool:
        """Create the underlying BinanceWebSocketApiManager.

        Returns:
            True on success (or if already running), False if creation failed.
        """
        if self.manager is not None and self.running:
            # Creating a second manager would leak the first one's connections.
            logger.debug("Unicorn WebSocket管理器已在运行")
            return True
        try:
            self.manager = BinanceWebSocketApiManager(
                exchange="binance.com-futures-testnet" if self.testnet else "binance.com-futures",
                throw_exception_if_unrepairable=True,
                high_performance=True,
            )
            self.running = True
            logger.info(f"Unicorn WebSocket管理器启动成功 (测试网: {self.testnet})")
            return True
        except Exception as e:
            logger.error(f"启动Unicorn WebSocket管理器失败: {e}")
            return False

    def stop(self):
        """Stop all streams and the manager, and forget every subscription.

        Best effort: failures while stopping are logged, never raised.
        """
        self.running = False
        manager = self.manager
        if manager:
            # Several symbols share one stream_id; stop each stream only once.
            for stream_id in set(self.stream_ids.values()) | set(self.kline_streams):
                try:
                    manager.stop_stream(stream_id)
                except Exception as e:
                    logger.debug(f"停止流 {stream_id} 失败: {e}")
            try:
                manager.stop_manager_with_all_streams()
            except Exception as e:
                logger.warning(f"停止Unicorn WebSocket管理器失败: {e}")
        # A stopped manager is not reused; drop it and all routing state so a
        # later start()/subscribe_*() begins clean without duplicate callbacks.
        self.manager = None
        self.stream_ids.clear()
        self.price_callbacks.clear()
        self.kline_streams.clear()
        self.kline_callbacks.clear()
        self.latest_prices.clear()
        logger.info("Unicorn WebSocket管理器已停止")

    def subscribe_ticker(self, symbols: List[str], callback: Callable) -> Dict[str, str]:
        """Subscribe to the 24h ticker stream of each symbol.

        Symbols not yet subscribed share one new multiplexed stream; symbols
        already subscribed keep their existing stream.

        Args:
            symbols: Trading pairs, e.g. ["BTCUSDT", "ETHUSDT"].
            callback: Called as ``callback(symbol, price, ticker_data)`` on every
                ticker update; may be a plain function or a coroutine function.
                Registering the same callback twice for a symbol is a no-op.

        Returns:
            Mapping symbol -> stream_id, or {} if the manager is not started or
            the stream could not be created.
        """
        if not self.manager:
            logger.error("WebSocket管理器未启动")
            return {}

        try:
            stream_ids = {s: self.stream_ids[s] for s in symbols if s in self.stream_ids}
            new_symbols = [s for s in dict.fromkeys(symbols) if s not in self.stream_ids]

            if new_symbols:
                # Channel and market are passed separately; UBWA builds
                # "<market>@ticker" itself. stream_buffer_name=True gives the
                # stream its own buffer, readable via pop_...(stream_id).
                stream_id = self.manager.create_stream(
                    ["ticker"],
                    [s.lower() for s in new_symbols],
                    stream_buffer_name=True,
                    output="UnicornFy",
                )
                if not stream_id:
                    logger.error("订阅价格流失败: create_stream未返回stream_id")
                    return {}
                for s in new_symbols:
                    self.stream_ids[s] = stream_id
                    stream_ids[s] = stream_id

            # Register callbacks only after the stream exists.
            if callback is not None:
                for s in symbols:
                    registered = self.price_callbacks.setdefault(s, [])
                    if callback not in registered:
                        registered.append(callback)

            logger.info(f"订阅 {len(symbols)} 个交易对的价格流")
            return stream_ids

        except Exception as e:
            logger.error(f"订阅价格流失败: {e}")
            return {}

    def subscribe_kline(
        self,
        symbols: List[str],
        interval: str = "5m",
        callback: Optional[Callable] = None
    ) -> Dict[str, str]:
        """Subscribe to the kline stream of each symbol for one interval.

        Args:
            symbols: Trading pairs.
            interval: Kline interval (1m, 5m, 15m, ...).
            callback: Optional; called as ``callback(symbol, kline_data)`` on every
                kline update (plain or coroutine function). Callbacks are keyed
                by symbol, so they receive updates of every interval subscribed
                for that symbol.

        Returns:
            Mapping symbol -> stream_id, or {} on failure.
        """
        if not self.manager:
            logger.error("WebSocket管理器未启动")
            return {}

        stream_ids = {}

        try:
            if symbols:
                stream_id = self.manager.create_stream(
                    [f"kline_{interval}"],
                    [s.lower() for s in symbols],
                    stream_buffer_name=True,
                    output="UnicornFy",
                )
                if not stream_id:
                    logger.error("订阅K线流失败: create_stream未返回stream_id")
                    return {}
                # Record the stream so process_stream_data() drains it and stop() stops it.
                self.kline_streams[stream_id] = list(symbols)
                for s in symbols:
                    stream_ids[s] = stream_id
                    if callback is not None:
                        registered = self.kline_callbacks.setdefault(s, [])
                        if callback not in registered:
                            registered.append(callback)

            logger.info(f"订阅 {len(symbols)} 个交易对的K线流 ({interval})")
            return stream_ids

        except Exception as e:
            logger.error(f"订阅K线流失败: {e}")
            return {}

    def get_realtime_price(self, symbol: str) -> Optional[float]:
        """Return the latest known ticker price of a subscribed symbol.

        While process_stream_data() is running it is the only consumer of the
        buffers and keeps the price cache current, so this just reads the
        cache. Otherwise the symbol's ticker buffer is drained here first (no
        callbacks are invoked in that case).

        Args:
            symbol: Trading pair as passed to subscribe_ticker().

        Returns:
            The latest price, or None if not subscribed or no price seen yet.
        """
        if not self.manager or symbol not in self.stream_ids:
            return None

        try:
            if not self._processing:
                stream_id = self.stream_ids[symbol]
                self._drain_prices(stream_id, self._stream_routes().get(stream_id))
            return self.latest_prices.get(symbol)
        except Exception as e:
            logger.debug(f"获取 {symbol} 实时价格失败: {e}")
            return None

    async def process_stream_data(self):
        """Poll every subscribed stream buffer and dispatch frames to callbacks.

        Runs until ``stop()`` clears ``running``. Each pass pops up to
        MAX_EVENTS_PER_POLL frames per stream, then sleeps 0.1 s.
        """
        if not self.manager:
            return

        self._processing = True
        try:
            while self.running:
                manager = self.manager
                if manager is None:
                    break
                try:
                    for stream_id, fallback in self._stream_routes().items():
                        label = fallback or stream_id
                        try:
                            for _ in range(self.MAX_EVENTS_PER_POLL):
                                # Non-blocking; yields a falsy value once the buffer is empty.
                                stream_data = manager.pop_stream_data_from_stream_buffer(stream_id)
                                if not stream_data:
                                    break
                                await self._handle_stream_data(fallback, stream_data)
                        except Exception as e:
                            logger.debug(f"处理 {label} 流数据失败: {e}")

                    # Short sleep to avoid busy-looping between polls.
                    await asyncio.sleep(0.1)

                except Exception as e:
                    logger.error(f"处理WebSocket流数据失败: {e}")
                    await asyncio.sleep(1)
        finally:
            self._processing = False

    async def _handle_stream_data(self, symbol: Optional[str], data: Dict):
        """Dispatch one frame to the callbacks of the symbol(s) it is about.

        Args:
            symbol: Fallback symbol, used only when the frame itself carries
                none (passed for single-symbol streams; None for shared ones).
            data: A frame popped from a stream buffer (UnicornFy or raw layout).
        """
        try:
            if not data or not isinstance(data, dict):
                return

            ticker_updates = self._extract_ticker_updates(data)
            if ticker_updates:
                candidates = list(self.stream_ids) + list(self.price_callbacks)
                for raw_symbol, price, payload in ticker_updates:
                    target = self._match_symbol(raw_symbol, candidates, symbol)
                    if target is None:
                        # Not (or no longer) subscribed, e.g. after unsubscribe()
                        # on a stream still shared with other symbols.
                        continue
                    self.latest_prices[target] = price
                    await self._invoke_callbacks(
                        self.price_callbacks.get(target, ()), target, price, payload
                    )
                return

            kline_symbols = [s for members in self.kline_streams.values() for s in members]
            for raw_symbol, kline in self._extract_kline_updates(data):
                target = self._match_symbol(raw_symbol, list(self.kline_callbacks) + kline_symbols, symbol)
                label = target or raw_symbol or symbol
                logger.debug(f"{label} K线更新: {kline.get('c', kline.get('close_price', 'N/A'))}")
                if target is not None:
                    await self._invoke_callbacks(self.kline_callbacks.get(target, ()), target, kline)

        except Exception as e:
            logger.debug(f"处理 {symbol} 流数据失败: {e}")

    def unsubscribe(self, symbol: str):
        """Stop delivering ticker updates for one symbol.

        The underlying stream is stopped only when no other subscribed symbol
        still shares it; otherwise this symbol's frames are simply dropped.

        Args:
            symbol: Trading pair as passed to subscribe_ticker().
        """
        if symbol not in self.stream_ids:
            return

        stream_id = self.stream_ids[symbol]
        try:
            still_shared = any(
                other != symbol and sid == stream_id for other, sid in self.stream_ids.items()
            )
            if not still_shared and self.manager:
                self.manager.stop_stream(stream_id)
            del self.stream_ids[symbol]
            self.price_callbacks.pop(symbol, None)
            self.latest_prices.pop(symbol, None)
            logger.info(f"取消订阅 {symbol}")
        except Exception as e:
            logger.error(f"取消订阅 {symbol} 失败: {e}")

    def get_stream_statistics(self) -> Dict:
        """Return ticker-subscription statistics.

        Returns:
            Dict with 'total_streams' and 'active_streams' (both count subscribed
            symbols, kept for compatibility), 'unique_streams' (distinct ticker
            stream ids) and 'subscribed_symbols'; {} if the manager is not started.
        """
        if not self.manager:
            return {}

        try:
            return {
                'total_streams': len(self.stream_ids),
                'active_streams': len([s for s in self.stream_ids.values() if s]),
                'unique_streams': len(set(self.stream_ids.values())),
                'subscribed_symbols': list(self.stream_ids.keys()),
            }
        except Exception as e:
            logger.error(f"获取流统计信息失败: {e}")
            return {}

    # ------------------------------------------------------------------ helpers

    def _stream_routes(self) -> Dict[str, Optional[str]]:
        """Map every known stream_id to its fallback symbol (None if shared)."""
        members: Dict[str, List[str]] = {}
        for sym, sid in list(self.stream_ids.items()):
            members.setdefault(sid, []).append(sym)
        for sid, syms in list(self.kline_streams.items()):
            members.setdefault(sid, []).extend(syms)
        return {sid: (syms[0] if len(syms) == 1 else None) for sid, syms in members.items()}

    def _drain_prices(self, stream_id: str, fallback: Optional[str]):
        """Pop pending frames of one ticker stream into the price cache (no callbacks)."""
        manager = self.manager
        for _ in range(self.MAX_EVENTS_PER_POLL):
            frame = manager.pop_stream_data_from_stream_buffer(stream_id)
            if not frame:
                break
            for raw_symbol, price, _payload in self._extract_ticker_updates(frame):
                target = self._match_symbol(raw_symbol, self.stream_ids, fallback)
                if target is not None:
                    self.latest_prices[target] = price

    @staticmethod
    async def _invoke_callbacks(callbacks, *args):
        """Call each callback with *args*, awaiting coroutine results; errors are logged."""
        for callback in list(callbacks):
            try:
                result = callback(*args)
                if asyncio.iscoroutine(result):
                    await result
            except Exception as e:
                logger.debug(f"回调函数执行失败: {e}")

    @staticmethod
    def _match_symbol(raw_symbol, candidates, fallback):
        """Map a symbol reported in a frame onto the caller's spelling of it.

        Returns the case-insensitively matching entry of *candidates*, None if
        the frame names a symbol that is not among them, or *fallback* if the
        frame names no symbol at all.
        """
        if not raw_symbol:
            return fallback
        wanted = str(raw_symbol).lower()
        for known in candidates:
            if known.lower() == wanted:
                return known
        return None

    @staticmethod
    def _symbol_from_stream_name(name) -> Optional[str]:
        """Extract 'btcusdt' from a stream name like 'btcusdt@ticker'."""
        if isinstance(name, str) and '@' in name and not name.startswith('!'):
            return name.split('@', 1)[0]
        return None

    @classmethod
    def _extract_ticker_updates(cls, data) -> List[tuple]:
        """Return [(raw_symbol, price, payload), ...] for each ticker update in a frame.

        Handles ``data['data']`` as a single dict (raw-style, price under 'c') or
        as a list of dicts (UnicornFy-style, price under 'last_price').
        """
        if not isinstance(data, dict):
            return []
        payload = data.get('data')
        entries = payload if isinstance(payload, list) else [payload]
        frame_symbol = data.get('symbol') or cls._symbol_from_stream_name(
            data.get('stream_type') or data.get('stream')
        )
        updates = []
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            if (data.get('event_type') or entry.get('event_type') or entry.get('e')) != '24hrTicker':
                continue
            raw_price = entry.get('c', entry.get('last_price'))
            if raw_price is None:
                continue
            try:
                price = float(raw_price)
            except (TypeError, ValueError):
                continue
            updates.append((entry.get('s') or entry.get('symbol') or frame_symbol, price, entry))
        return updates

    @classmethod
    def _extract_kline_updates(cls, data) -> List[tuple]:
        """Return [(raw_symbol, kline_dict)] for a kline frame, else [].

        Handles UnicornFy frames (``data['kline']``) and raw-style frames
        (``data['data']['k']``).
        """
        if not isinstance(data, dict):
            return []
        frame_symbol = data.get('symbol') or cls._symbol_from_stream_name(
            data.get('stream_type') or data.get('stream')
        )
        kline = data.get('kline')
        if data.get('event_type') == 'kline' and isinstance(kline, dict):
            return [(frame_symbol or kline.get('symbol'), kline)]
        payload = data.get('data')
        if isinstance(payload, dict) and isinstance(payload.get('k'), dict):
            if (data.get('event_type') or payload.get('e')) == 'kline':
                return [(payload.get('s') or payload['k'].get('s') or frame_symbol, payload['k'])]
        return []