This commit is contained in:
薇薇安 2026-01-13 15:23:36 +08:00
parent 4266d52bc8
commit 3fe84ea2f1
2 changed files with 579 additions and 0 deletions

272
indicators.py Normal file
View File

@ -0,0 +1,272 @@
"""
技术指标模块 - 计算各种技术指标用于交易决策
"""
import logging
from typing import List, Dict, Optional, Tuple
import math
logger = logging.getLogger(__name__)
class TechnicalIndicators:
"""技术指标计算类"""
@staticmethod
def calculate_rsi(prices: List[float], period: int = 14) -> Optional[float]:
"""
计算RSI相对强弱指标
Args:
prices: 价格列表从旧到新
period: 计算周期默认14
Returns:
RSI值0-100None表示数据不足
"""
if len(prices) < period + 1:
return None
gains = []
losses = []
for i in range(len(prices) - period, len(prices)):
if i == 0:
continue
change = prices[i] - prices[i - 1]
if change > 0:
gains.append(change)
losses.append(0)
else:
gains.append(0)
losses.append(abs(change))
if len(gains) == 0:
return None
avg_gain = sum(gains) / len(gains)
avg_loss = sum(losses) / len(losses) if sum(losses) > 0 else 0.0001
if avg_loss == 0:
return 100
rs = avg_gain / avg_loss
rsi = 100 - (100 / (1 + rs))
return rsi
@staticmethod
def calculate_macd(
prices: List[float],
fast_period: int = 12,
slow_period: int = 26,
signal_period: int = 9
) -> Optional[Dict[str, float]]:
"""
计算MACD指标
Args:
prices: 价格列表从旧到新
fast_period: 快线周期默认12
slow_period: 慢线周期默认26
signal_period: 信号线周期默认9
Returns:
{'macd': MACD值, 'signal': 信号线值, 'histogram': 柱状图值}
"""
if len(prices) < slow_period + signal_period:
return None
# 计算EMA
def ema(data, period):
multiplier = 2 / (period + 1)
ema_values = [data[0]]
for price in data[1:]:
ema_values.append((price - ema_values[-1]) * multiplier + ema_values[-1])
return ema_values
fast_ema = ema(prices, fast_period)
slow_ema = ema(prices, slow_period)
# 对齐长度
min_len = min(len(fast_ema), len(slow_ema))
fast_ema = fast_ema[-min_len:]
slow_ema = slow_ema[-min_len:]
# 计算MACD线
macd_line = [fast_ema[i] - slow_ema[i] for i in range(len(fast_ema))]
if len(macd_line) < signal_period:
return None
# 计算信号线
signal_line = ema(macd_line, signal_period)
# 计算柱状图
macd_value = macd_line[-1]
signal_value = signal_line[-1]
histogram = macd_value - signal_value
return {
'macd': macd_value,
'signal': signal_value,
'histogram': histogram
}
@staticmethod
def calculate_bollinger_bands(
prices: List[float],
period: int = 20,
std_dev: float = 2.0
) -> Optional[Dict[str, float]]:
"""
计算布林带
Args:
prices: 价格列表从旧到新
period: 计算周期默认20
std_dev: 标准差倍数默认2.0
Returns:
{'upper': 上轨, 'middle': 中轨, 'lower': 下轨}
"""
if len(prices) < period:
return None
recent_prices = prices[-period:]
middle = sum(recent_prices) / len(recent_prices)
# 计算标准差
variance = sum((p - middle) ** 2 for p in recent_prices) / len(recent_prices)
std = math.sqrt(variance)
upper = middle + (std * std_dev)
lower = middle - (std * std_dev)
return {
'upper': upper,
'middle': middle,
'lower': lower
}
@staticmethod
def calculate_atr(
high_prices: List[float],
low_prices: List[float],
close_prices: List[float],
period: int = 14
) -> Optional[float]:
"""
计算ATR平均真实波幅
Args:
high_prices: 最高价列表
low_prices: 最低价列表
close_prices: 收盘价列表
period: 计算周期默认14
Returns:
ATR值
"""
if len(high_prices) < period + 1 or len(low_prices) < period + 1 or len(close_prices) < period + 1:
return None
true_ranges = []
for i in range(1, len(high_prices)):
tr1 = high_prices[i] - low_prices[i]
tr2 = abs(high_prices[i] - close_prices[i - 1])
tr3 = abs(low_prices[i] - close_prices[i - 1])
true_range = max(tr1, tr2, tr3)
true_ranges.append(true_range)
if len(true_ranges) < period:
return None
recent_tr = true_ranges[-period:]
atr = sum(recent_tr) / len(recent_tr)
return atr
@staticmethod
def calculate_ema(prices: List[float], period: int = 20) -> Optional[float]:
"""
计算EMA指数移动平均
Args:
prices: 价格列表从旧到新
period: 计算周期
Returns:
EMA值
"""
if len(prices) < period:
return None
multiplier = 2 / (period + 1)
ema_value = prices[0]
for price in prices[1:]:
ema_value = (price - ema_value) * multiplier + ema_value
return ema_value
@staticmethod
def calculate_sma(prices: List[float], period: int = 20) -> Optional[float]:
"""
计算SMA简单移动平均
Args:
prices: 价格列表从旧到新
period: 计算周期
Returns:
SMA值
"""
if len(prices) < period:
return None
recent_prices = prices[-period:]
return sum(recent_prices) / len(recent_prices)
@staticmethod
def detect_market_regime(
prices: List[float],
short_period: int = 20,
long_period: int = 50
) -> str:
"""
判断市场状态趋势或震荡
Args:
prices: 价格列表
short_period: 短期均线周期
long_period: 长期均线周期
Returns:
'trending' 'ranging'
"""
if len(prices) < long_period:
return 'unknown'
short_ma = TechnicalIndicators.calculate_sma(prices, short_period)
long_ma = TechnicalIndicators.calculate_sma(prices, long_period)
if short_ma is None or long_ma is None:
return 'unknown'
# 计算价格波动率
recent_prices = prices[-20:]
if len(recent_prices) < 2:
return 'unknown'
volatility = sum(abs(recent_prices[i] - recent_prices[i-1]) for i in range(1, len(recent_prices))) / len(recent_prices)
avg_price = sum(recent_prices) / len(recent_prices)
volatility_pct = (volatility / avg_price) * 100 if avg_price > 0 else 0
# 如果短期均线明显高于或低于长期均线,且波动率较大,判断为趋势
ma_diff_pct = abs(short_ma - long_ma) / long_ma * 100 if long_ma > 0 else 0
if ma_diff_pct > 2 and volatility_pct > 1:
return 'trending'
else:
return 'ranging'

307
unicorn_websocket.py Normal file
View File

@ -0,0 +1,307 @@
"""
Unicorn WebSocket模块 - 提供高性能实时数据流
"""
import asyncio
import logging
from typing import Dict, List, Optional, Callable
from unicorn_binance_websocket_api.unicorn_binance_websocket_api_manager import BinanceWebSocketApiManager
import config
logger = logging.getLogger(__name__)
class UnicornWebSocketManager:
"""Unicorn WebSocket管理器"""
def __init__(self, testnet: bool = False):
"""
初始化Unicorn WebSocket管理器
Args:
testnet: 是否使用测试网
"""
self.testnet = testnet or config.USE_TESTNET
self.manager: Optional[BinanceWebSocketApiManager] = None
self.stream_ids: Dict[str, str] = {} # symbol -> stream_id
self.price_callbacks: Dict[str, List[Callable]] = {} # symbol -> callbacks
self.running = False
def start(self):
"""启动WebSocket管理器"""
try:
# 创建管理器
self.manager = BinanceWebSocketApiManager(
exchange="binance.com-futures" if not self.testnet else "binance.com-futures-testnet",
throw_exception_if_unrepairable=True,
high_performance=True
)
self.running = True
logger.info(f"Unicorn WebSocket管理器启动成功 (测试网: {self.testnet})")
return True
except Exception as e:
logger.error(f"启动Unicorn WebSocket管理器失败: {e}")
return False
def stop(self):
"""停止WebSocket管理器"""
self.running = False
if self.manager:
# 停止所有流
for stream_id in self.stream_ids.values():
try:
self.manager.stop_stream(stream_id)
except:
pass
# 停止管理器
try:
self.manager.stop_manager_with_all_streams()
except:
pass
logger.info("Unicorn WebSocket管理器已停止")
def subscribe_ticker(self, symbols: List[str], callback: Callable) -> Dict[str, str]:
"""
订阅交易对的价格流
Args:
symbols: 交易对列表
callback: 价格更新回调函数 callback(symbol, price_data)
Returns:
交易对到stream_id的映射
"""
if not self.manager:
logger.error("WebSocket管理器未启动")
return {}
stream_ids = {}
try:
# 构建流名称列表
streams = []
for symbol in symbols:
# 转换为小写(币安要求)
symbol_lower = symbol.lower()
# 订阅ticker流
stream_name = f"{symbol_lower}@ticker"
streams.append(stream_name)
# 注册回调
if symbol not in self.price_callbacks:
self.price_callbacks[symbol] = []
self.price_callbacks[symbol].append(callback)
# 创建多路复用流
if streams:
stream_id = self.manager.create_stream(
["arr"],
streams,
output="UnicornFy"
)
# 记录stream_id
for symbol in symbols:
self.stream_ids[symbol] = stream_id
stream_ids[symbol] = stream_id
logger.info(f"订阅 {len(symbols)} 个交易对的价格流")
return stream_ids
except Exception as e:
logger.error(f"订阅价格流失败: {e}")
return {}
def subscribe_kline(
self,
symbols: List[str],
interval: str = "5m",
callback: Optional[Callable] = None
) -> Dict[str, str]:
"""
订阅K线数据流
Args:
symbols: 交易对列表
interval: K线周期1m, 5m, 15m等
callback: K线更新回调函数 callback(symbol, kline_data)
Returns:
交易对到stream_id的映射
"""
if not self.manager:
logger.error("WebSocket管理器未启动")
return {}
stream_ids = {}
try:
# 构建流名称列表
streams = []
for symbol in symbols:
symbol_lower = symbol.lower()
stream_name = f"{symbol_lower}@kline_{interval}"
streams.append(stream_name)
# 创建多路复用流
if streams:
stream_id = self.manager.create_stream(
["arr"],
streams,
output="UnicornFy"
)
for symbol in symbols:
stream_ids[symbol] = stream_id
logger.info(f"订阅 {len(symbols)} 个交易对的K线流 ({interval})")
return stream_ids
except Exception as e:
logger.error(f"订阅K线流失败: {e}")
return {}
def get_realtime_price(self, symbol: str) -> Optional[float]:
"""
获取实时价格从WebSocket流缓冲区中
Args:
symbol: 交易对
Returns:
实时价格如果未订阅则返回None
"""
if not self.manager or symbol not in self.stream_ids:
return None
try:
stream_id = self.stream_ids[symbol]
# 从流缓冲区获取最新数据
data = self.manager.pop_stream_data_from_stream_buffer(stream_id)
if data and isinstance(data, dict):
# 解析ticker数据
if 'event_type' in data and data['event_type'] == '24hrTicker':
price_data = data.get('data', {})
if 'c' in price_data: # 最新价格
return float(price_data['c'])
elif 'data' in data and isinstance(data['data'], dict):
if 'c' in data['data']: # 最新价格
return float(data['data']['c'])
elif 'close' in data['data']: # K线收盘价
return float(data['data']['close'])
return None
except Exception as e:
logger.debug(f"获取 {symbol} 实时价格失败: {e}")
return None
async def process_stream_data(self):
"""
处理WebSocket流数据异步
"""
if not self.manager:
return
while self.running:
try:
# 处理所有流的数据
for symbol, stream_id in list(self.stream_ids.items()):
try:
# 获取流数据(非阻塞)
stream_data = self.manager.pop_stream_data_from_stream_buffer(stream_id)
if stream_data:
# 处理数据
await self._handle_stream_data(symbol, stream_data)
except Exception as e:
logger.debug(f"处理 {symbol} 流数据失败: {e}")
# 短暂休眠避免CPU占用过高
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"处理WebSocket流数据失败: {e}")
await asyncio.sleep(1)
async def _handle_stream_data(self, symbol: str, data: Dict):
"""
处理单个流的数据
Args:
symbol: 交易对
data: 流数据
"""
try:
if not data or not isinstance(data, dict):
return
# 处理ticker数据
if 'event_type' in data and data['event_type'] == '24hrTicker':
price_data = data.get('data', {})
if 'c' in price_data: # 最新价格
price = float(price_data['c'])
# 调用所有注册的回调
if symbol in self.price_callbacks:
for callback in self.price_callbacks[symbol]:
try:
if asyncio.iscoroutinefunction(callback):
await callback(symbol, price, price_data)
else:
callback(symbol, price, price_data)
except Exception as e:
logger.debug(f"回调函数执行失败: {e}")
# 处理K线数据
elif 'event_type' in data and data['event_type'] == 'kline':
kline_data = data.get('data', {})
if 'k' in kline_data:
kline = kline_data['k']
# 可以在这里处理K线更新
logger.debug(f"{symbol} K线更新: {kline.get('c', 'N/A')}")
except Exception as e:
logger.debug(f"处理 {symbol} 流数据失败: {e}")
def unsubscribe(self, symbol: str):
"""
取消订阅交易对
Args:
symbol: 交易对
"""
if symbol in self.stream_ids:
stream_id = self.stream_ids[symbol]
try:
self.manager.stop_stream(stream_id)
del self.stream_ids[symbol]
if symbol in self.price_callbacks:
del self.price_callbacks[symbol]
logger.info(f"取消订阅 {symbol}")
except Exception as e:
logger.error(f"取消订阅 {symbol} 失败: {e}")
def get_stream_statistics(self) -> Dict:
"""
获取流统计信息
Returns:
统计信息字典
"""
if not self.manager:
return {}
try:
stats = {
'total_streams': len(self.stream_ids),
'active_streams': len([s for s in self.stream_ids.values() if s]),
'subscribed_symbols': list(self.stream_ids.keys())
}
return stats
except Exception as e:
logger.error(f"获取流统计信息失败: {e}")
return {}