diff --git a/indicators.py b/indicators.py new file mode 100644 index 0000000..0ad66e2 --- /dev/null +++ b/indicators.py @@ -0,0 +1,272 @@ +""" +技术指标模块 - 计算各种技术指标用于交易决策 +""" +import logging +from typing import List, Dict, Optional, Tuple +import math + +logger = logging.getLogger(__name__) + + +class TechnicalIndicators: + """技术指标计算类""" + + @staticmethod + def calculate_rsi(prices: List[float], period: int = 14) -> Optional[float]: + """ + 计算RSI(相对强弱指标) + + Args: + prices: 价格列表(从旧到新) + period: 计算周期,默认14 + + Returns: + RSI值(0-100),None表示数据不足 + """ + if len(prices) < period + 1: + return None + + gains = [] + losses = [] + + for i in range(len(prices) - period, len(prices)): + if i == 0: + continue + change = prices[i] - prices[i - 1] + if change > 0: + gains.append(change) + losses.append(0) + else: + gains.append(0) + losses.append(abs(change)) + + if len(gains) == 0: + return None + + avg_gain = sum(gains) / len(gains) + avg_loss = sum(losses) / len(losses) if sum(losses) > 0 else 0.0001 + + if avg_loss == 0: + return 100 + + rs = avg_gain / avg_loss + rsi = 100 - (100 / (1 + rs)) + + return rsi + + @staticmethod + def calculate_macd( + prices: List[float], + fast_period: int = 12, + slow_period: int = 26, + signal_period: int = 9 + ) -> Optional[Dict[str, float]]: + """ + 计算MACD指标 + + Args: + prices: 价格列表(从旧到新) + fast_period: 快线周期,默认12 + slow_period: 慢线周期,默认26 + signal_period: 信号线周期,默认9 + + Returns: + {'macd': MACD值, 'signal': 信号线值, 'histogram': 柱状图值} + """ + if len(prices) < slow_period + signal_period: + return None + + # 计算EMA + def ema(data, period): + multiplier = 2 / (period + 1) + ema_values = [data[0]] + for price in data[1:]: + ema_values.append((price - ema_values[-1]) * multiplier + ema_values[-1]) + return ema_values + + fast_ema = ema(prices, fast_period) + slow_ema = ema(prices, slow_period) + + # 对齐长度 + min_len = min(len(fast_ema), len(slow_ema)) + fast_ema = fast_ema[-min_len:] + slow_ema = slow_ema[-min_len:] + + # 计算MACD线 + macd_line = [fast_ema[i] - slow_ema[i] for i in range(len(fast_ema))] + + if len(macd_line) < signal_period: + return None + + # 计算信号线 + signal_line = ema(macd_line, signal_period) + + # 计算柱状图 + macd_value = macd_line[-1] + signal_value = signal_line[-1] + histogram = macd_value - signal_value + + return { + 'macd': macd_value, + 'signal': signal_value, + 'histogram': histogram + } + + @staticmethod + def calculate_bollinger_bands( + prices: List[float], + period: int = 20, + std_dev: float = 2.0 + ) -> Optional[Dict[str, float]]: + """ + 计算布林带 + + Args: + prices: 价格列表(从旧到新) + period: 计算周期,默认20 + std_dev: 标准差倍数,默认2.0 + + Returns: + {'upper': 上轨, 'middle': 中轨, 'lower': 下轨} + """ + if len(prices) < period: + return None + + recent_prices = prices[-period:] + middle = sum(recent_prices) / len(recent_prices) + + # 计算标准差 + variance = sum((p - middle) ** 2 for p in recent_prices) / len(recent_prices) + std = math.sqrt(variance) + + upper = middle + (std * std_dev) + lower = middle - (std * std_dev) + + return { + 'upper': upper, + 'middle': middle, + 'lower': lower + } + + @staticmethod + def calculate_atr( + high_prices: List[float], + low_prices: List[float], + close_prices: List[float], + period: int = 14 + ) -> Optional[float]: + """ + 计算ATR(平均真实波幅) + + Args: + high_prices: 最高价列表 + low_prices: 最低价列表 + close_prices: 收盘价列表 + period: 计算周期,默认14 + + Returns: + ATR值 + """ + if len(high_prices) < period + 1 or len(low_prices) < period + 1 or len(close_prices) < period + 1: + return None + + true_ranges = [] + + for i in range(1, len(high_prices)): + tr1 = high_prices[i] - low_prices[i] + tr2 = abs(high_prices[i] - close_prices[i - 1]) + tr3 = abs(low_prices[i] - close_prices[i - 1]) + true_range = max(tr1, tr2, tr3) + true_ranges.append(true_range) + + if len(true_ranges) < period: + return None + + recent_tr = true_ranges[-period:] + atr = sum(recent_tr) / len(recent_tr) + + return atr + + @staticmethod + def calculate_ema(prices: List[float], period: int = 20) -> Optional[float]: + """ + 计算EMA(指数移动平均) + + Args: + prices: 价格列表(从旧到新) + period: 计算周期 + + Returns: + EMA值 + """ + if len(prices) < period: + return None + + multiplier = 2 / (period + 1) + ema_value = prices[0] + + for price in prices[1:]: + ema_value = (price - ema_value) * multiplier + ema_value + + return ema_value + + @staticmethod + def calculate_sma(prices: List[float], period: int = 20) -> Optional[float]: + """ + 计算SMA(简单移动平均) + + Args: + prices: 价格列表(从旧到新) + period: 计算周期 + + Returns: + SMA值 + """ + if len(prices) < period: + return None + + recent_prices = prices[-period:] + return sum(recent_prices) / len(recent_prices) + + @staticmethod + def detect_market_regime( + prices: List[float], + short_period: int = 20, + long_period: int = 50 + ) -> str: + """ + 判断市场状态:趋势或震荡 + + Args: + prices: 价格列表 + short_period: 短期均线周期 + long_period: 长期均线周期 + + Returns: + 'trending' 或 'ranging' + """ + if len(prices) < long_period: + return 'unknown' + + short_ma = TechnicalIndicators.calculate_sma(prices, short_period) + long_ma = TechnicalIndicators.calculate_sma(prices, long_period) + + if short_ma is None or long_ma is None: + return 'unknown' + + # 计算价格波动率 + recent_prices = prices[-20:] + if len(recent_prices) < 2: + return 'unknown' + + volatility = sum(abs(recent_prices[i] - recent_prices[i-1]) for i in range(1, len(recent_prices))) / len(recent_prices) + avg_price = sum(recent_prices) / len(recent_prices) + volatility_pct = (volatility / avg_price) * 100 if avg_price > 0 else 0 + + # 如果短期均线明显高于或低于长期均线,且波动率较大,判断为趋势 + ma_diff_pct = abs(short_ma - long_ma) / long_ma * 100 if long_ma > 0 else 0 + + if ma_diff_pct > 2 and volatility_pct > 1: + return 'trending' + else: + return 'ranging' diff --git a/unicorn_websocket.py b/unicorn_websocket.py new file mode 100644 index 0000000..e5c9188 --- /dev/null +++ b/unicorn_websocket.py @@ -0,0 +1,307 @@ +""" +Unicorn WebSocket模块 - 提供高性能实时数据流 +""" +import asyncio +import logging +from typing import Dict, List, Optional, Callable +from unicorn_binance_websocket_api.unicorn_binance_websocket_api_manager import BinanceWebSocketApiManager +import config + +logger = logging.getLogger(__name__) + + +class UnicornWebSocketManager: + """Unicorn WebSocket管理器""" + + def __init__(self, testnet: bool = False): + """ + 初始化Unicorn WebSocket管理器 + + Args: + testnet: 是否使用测试网 + """ + self.testnet = testnet or config.USE_TESTNET + self.manager: Optional[BinanceWebSocketApiManager] = None + self.stream_ids: Dict[str, str] = {} # symbol -> stream_id + self.price_callbacks: Dict[str, List[Callable]] = {} # symbol -> callbacks + self.running = False + + def start(self): + """启动WebSocket管理器""" + try: + # 创建管理器 + self.manager = BinanceWebSocketApiManager( + exchange="binance.com-futures" if not self.testnet else "binance.com-futures-testnet", + throw_exception_if_unrepairable=True, + high_performance=True + ) + + self.running = True + logger.info(f"Unicorn WebSocket管理器启动成功 (测试网: {self.testnet})") + return True + except Exception as e: + logger.error(f"启动Unicorn WebSocket管理器失败: {e}") + return False + + def stop(self): + """停止WebSocket管理器""" + self.running = False + if self.manager: + # 停止所有流 + for stream_id in self.stream_ids.values(): + try: + self.manager.stop_stream(stream_id) + except: + pass + + # 停止管理器 + try: + self.manager.stop_manager_with_all_streams() + except: + pass + + logger.info("Unicorn WebSocket管理器已停止") + + def subscribe_ticker(self, symbols: List[str], callback: Callable) -> Dict[str, str]: + """ + 订阅交易对的价格流 + + Args: + symbols: 交易对列表 + callback: 价格更新回调函数 callback(symbol, price_data) + + Returns: + 交易对到stream_id的映射 + """ + if not self.manager: + logger.error("WebSocket管理器未启动") + return {} + + stream_ids = {} + + try: + # 构建流名称列表 + streams = [] + for symbol in symbols: + # 转换为小写(币安要求) + symbol_lower = symbol.lower() + # 订阅ticker流 + stream_name = f"{symbol_lower}@ticker" + streams.append(stream_name) + + # 注册回调 + if symbol not in self.price_callbacks: + self.price_callbacks[symbol] = [] + self.price_callbacks[symbol].append(callback) + + # 创建多路复用流 + if streams: + stream_id = self.manager.create_stream( + ["arr"], + streams, + output="UnicornFy" + ) + + # 记录stream_id + for symbol in symbols: + self.stream_ids[symbol] = stream_id + stream_ids[symbol] = stream_id + + logger.info(f"订阅 {len(symbols)} 个交易对的价格流") + + return stream_ids + + except Exception as e: + logger.error(f"订阅价格流失败: {e}") + return {} + + def subscribe_kline( + self, + symbols: List[str], + interval: str = "5m", + callback: Optional[Callable] = None + ) -> Dict[str, str]: + """ + 订阅K线数据流 + + Args: + symbols: 交易对列表 + interval: K线周期(1m, 5m, 15m等) + callback: K线更新回调函数 callback(symbol, kline_data) + + Returns: + 交易对到stream_id的映射 + """ + if not self.manager: + logger.error("WebSocket管理器未启动") + return {} + + stream_ids = {} + + try: + # 构建流名称列表 + streams = [] + for symbol in symbols: + symbol_lower = symbol.lower() + stream_name = f"{symbol_lower}@kline_{interval}" + streams.append(stream_name) + + # 创建多路复用流 + if streams: + stream_id = self.manager.create_stream( + ["arr"], + streams, + output="UnicornFy" + ) + + for symbol in symbols: + stream_ids[symbol] = stream_id + + logger.info(f"订阅 {len(symbols)} 个交易对的K线流 ({interval})") + + return stream_ids + + except Exception as e: + logger.error(f"订阅K线流失败: {e}") + return {} + + def get_realtime_price(self, symbol: str) -> Optional[float]: + """ + 获取实时价格(从WebSocket流缓冲区中) + + Args: + symbol: 交易对 + + Returns: + 实时价格,如果未订阅则返回None + """ + if not self.manager or symbol not in self.stream_ids: + return None + + try: + stream_id = self.stream_ids[symbol] + # 从流缓冲区获取最新数据 + data = self.manager.pop_stream_data_from_stream_buffer(stream_id) + + if data and isinstance(data, dict): + # 解析ticker数据 + if 'event_type' in data and data['event_type'] == '24hrTicker': + price_data = data.get('data', {}) + if 'c' in price_data: # 最新价格 + return float(price_data['c']) + elif 'data' in data and isinstance(data['data'], dict): + if 'c' in data['data']: # 最新价格 + return float(data['data']['c']) + elif 'close' in data['data']: # K线收盘价 + return float(data['data']['close']) + + return None + + except Exception as e: + logger.debug(f"获取 {symbol} 实时价格失败: {e}") + return None + + async def process_stream_data(self): + """ + 处理WebSocket流数据(异步) + """ + if not self.manager: + return + + while self.running: + try: + # 处理所有流的数据 + for symbol, stream_id in list(self.stream_ids.items()): + try: + # 获取流数据(非阻塞) + stream_data = self.manager.pop_stream_data_from_stream_buffer(stream_id) + if stream_data: + # 处理数据 + await self._handle_stream_data(symbol, stream_data) + except Exception as e: + logger.debug(f"处理 {symbol} 流数据失败: {e}") + + # 短暂休眠,避免CPU占用过高 + await asyncio.sleep(0.1) + + except Exception as e: + logger.error(f"处理WebSocket流数据失败: {e}") + await asyncio.sleep(1) + + async def _handle_stream_data(self, symbol: str, data: Dict): + """ + 处理单个流的数据 + + Args: + symbol: 交易对 + data: 流数据 + """ + try: + if not data or not isinstance(data, dict): + return + + # 处理ticker数据 + if 'event_type' in data and data['event_type'] == '24hrTicker': + price_data = data.get('data', {}) + if 'c' in price_data: # 最新价格 + price = float(price_data['c']) + # 调用所有注册的回调 + if symbol in self.price_callbacks: + for callback in self.price_callbacks[symbol]: + try: + if asyncio.iscoroutinefunction(callback): + await callback(symbol, price, price_data) + else: + callback(symbol, price, price_data) + except Exception as e: + logger.debug(f"回调函数执行失败: {e}") + + # 处理K线数据 + elif 'event_type' in data and data['event_type'] == 'kline': + kline_data = data.get('data', {}) + if 'k' in kline_data: + kline = kline_data['k'] + # 可以在这里处理K线更新 + logger.debug(f"{symbol} K线更新: {kline.get('c', 'N/A')}") + + except Exception as e: + logger.debug(f"处理 {symbol} 流数据失败: {e}") + + def unsubscribe(self, symbol: str): + """ + 取消订阅交易对 + + Args: + symbol: 交易对 + """ + if symbol in self.stream_ids: + stream_id = self.stream_ids[symbol] + try: + self.manager.stop_stream(stream_id) + del self.stream_ids[symbol] + if symbol in self.price_callbacks: + del self.price_callbacks[symbol] + logger.info(f"取消订阅 {symbol}") + except Exception as e: + logger.error(f"取消订阅 {symbol} 失败: {e}") + + def get_stream_statistics(self) -> Dict: + """ + 获取流统计信息 + + Returns: + 统计信息字典 + """ + if not self.manager: + return {} + + try: + stats = { + 'total_streams': len(self.stream_ids), + 'active_streams': len([s for s in self.stream_ids.values() if s]), + 'subscribed_symbols': list(self.stream_ids.keys()) + } + return stats + except Exception as e: + logger.error(f"获取流统计信息失败: {e}") + return {}