a
This commit is contained in:
parent
4266d52bc8
commit
3fe84ea2f1
272
indicators.py
Normal file
272
indicators.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
"""
|
||||
技术指标模块 - 计算各种技术指标用于交易决策
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import math
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TechnicalIndicators:
|
||||
"""技术指标计算类"""
|
||||
|
||||
@staticmethod
|
||||
def calculate_rsi(prices: List[float], period: int = 14) -> Optional[float]:
|
||||
"""
|
||||
计算RSI(相对强弱指标)
|
||||
|
||||
Args:
|
||||
prices: 价格列表(从旧到新)
|
||||
period: 计算周期,默认14
|
||||
|
||||
Returns:
|
||||
RSI值(0-100),None表示数据不足
|
||||
"""
|
||||
if len(prices) < period + 1:
|
||||
return None
|
||||
|
||||
gains = []
|
||||
losses = []
|
||||
|
||||
for i in range(len(prices) - period, len(prices)):
|
||||
if i == 0:
|
||||
continue
|
||||
change = prices[i] - prices[i - 1]
|
||||
if change > 0:
|
||||
gains.append(change)
|
||||
losses.append(0)
|
||||
else:
|
||||
gains.append(0)
|
||||
losses.append(abs(change))
|
||||
|
||||
if len(gains) == 0:
|
||||
return None
|
||||
|
||||
avg_gain = sum(gains) / len(gains)
|
||||
avg_loss = sum(losses) / len(losses) if sum(losses) > 0 else 0.0001
|
||||
|
||||
if avg_loss == 0:
|
||||
return 100
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
|
||||
return rsi
|
||||
|
||||
@staticmethod
|
||||
def calculate_macd(
|
||||
prices: List[float],
|
||||
fast_period: int = 12,
|
||||
slow_period: int = 26,
|
||||
signal_period: int = 9
|
||||
) -> Optional[Dict[str, float]]:
|
||||
"""
|
||||
计算MACD指标
|
||||
|
||||
Args:
|
||||
prices: 价格列表(从旧到新)
|
||||
fast_period: 快线周期,默认12
|
||||
slow_period: 慢线周期,默认26
|
||||
signal_period: 信号线周期,默认9
|
||||
|
||||
Returns:
|
||||
{'macd': MACD值, 'signal': 信号线值, 'histogram': 柱状图值}
|
||||
"""
|
||||
if len(prices) < slow_period + signal_period:
|
||||
return None
|
||||
|
||||
# 计算EMA
|
||||
def ema(data, period):
|
||||
multiplier = 2 / (period + 1)
|
||||
ema_values = [data[0]]
|
||||
for price in data[1:]:
|
||||
ema_values.append((price - ema_values[-1]) * multiplier + ema_values[-1])
|
||||
return ema_values
|
||||
|
||||
fast_ema = ema(prices, fast_period)
|
||||
slow_ema = ema(prices, slow_period)
|
||||
|
||||
# 对齐长度
|
||||
min_len = min(len(fast_ema), len(slow_ema))
|
||||
fast_ema = fast_ema[-min_len:]
|
||||
slow_ema = slow_ema[-min_len:]
|
||||
|
||||
# 计算MACD线
|
||||
macd_line = [fast_ema[i] - slow_ema[i] for i in range(len(fast_ema))]
|
||||
|
||||
if len(macd_line) < signal_period:
|
||||
return None
|
||||
|
||||
# 计算信号线
|
||||
signal_line = ema(macd_line, signal_period)
|
||||
|
||||
# 计算柱状图
|
||||
macd_value = macd_line[-1]
|
||||
signal_value = signal_line[-1]
|
||||
histogram = macd_value - signal_value
|
||||
|
||||
return {
|
||||
'macd': macd_value,
|
||||
'signal': signal_value,
|
||||
'histogram': histogram
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def calculate_bollinger_bands(
|
||||
prices: List[float],
|
||||
period: int = 20,
|
||||
std_dev: float = 2.0
|
||||
) -> Optional[Dict[str, float]]:
|
||||
"""
|
||||
计算布林带
|
||||
|
||||
Args:
|
||||
prices: 价格列表(从旧到新)
|
||||
period: 计算周期,默认20
|
||||
std_dev: 标准差倍数,默认2.0
|
||||
|
||||
Returns:
|
||||
{'upper': 上轨, 'middle': 中轨, 'lower': 下轨}
|
||||
"""
|
||||
if len(prices) < period:
|
||||
return None
|
||||
|
||||
recent_prices = prices[-period:]
|
||||
middle = sum(recent_prices) / len(recent_prices)
|
||||
|
||||
# 计算标准差
|
||||
variance = sum((p - middle) ** 2 for p in recent_prices) / len(recent_prices)
|
||||
std = math.sqrt(variance)
|
||||
|
||||
upper = middle + (std * std_dev)
|
||||
lower = middle - (std * std_dev)
|
||||
|
||||
return {
|
||||
'upper': upper,
|
||||
'middle': middle,
|
||||
'lower': lower
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def calculate_atr(
|
||||
high_prices: List[float],
|
||||
low_prices: List[float],
|
||||
close_prices: List[float],
|
||||
period: int = 14
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
计算ATR(平均真实波幅)
|
||||
|
||||
Args:
|
||||
high_prices: 最高价列表
|
||||
low_prices: 最低价列表
|
||||
close_prices: 收盘价列表
|
||||
period: 计算周期,默认14
|
||||
|
||||
Returns:
|
||||
ATR值
|
||||
"""
|
||||
if len(high_prices) < period + 1 or len(low_prices) < period + 1 or len(close_prices) < period + 1:
|
||||
return None
|
||||
|
||||
true_ranges = []
|
||||
|
||||
for i in range(1, len(high_prices)):
|
||||
tr1 = high_prices[i] - low_prices[i]
|
||||
tr2 = abs(high_prices[i] - close_prices[i - 1])
|
||||
tr3 = abs(low_prices[i] - close_prices[i - 1])
|
||||
true_range = max(tr1, tr2, tr3)
|
||||
true_ranges.append(true_range)
|
||||
|
||||
if len(true_ranges) < period:
|
||||
return None
|
||||
|
||||
recent_tr = true_ranges[-period:]
|
||||
atr = sum(recent_tr) / len(recent_tr)
|
||||
|
||||
return atr
|
||||
|
||||
@staticmethod
|
||||
def calculate_ema(prices: List[float], period: int = 20) -> Optional[float]:
|
||||
"""
|
||||
计算EMA(指数移动平均)
|
||||
|
||||
Args:
|
||||
prices: 价格列表(从旧到新)
|
||||
period: 计算周期
|
||||
|
||||
Returns:
|
||||
EMA值
|
||||
"""
|
||||
if len(prices) < period:
|
||||
return None
|
||||
|
||||
multiplier = 2 / (period + 1)
|
||||
ema_value = prices[0]
|
||||
|
||||
for price in prices[1:]:
|
||||
ema_value = (price - ema_value) * multiplier + ema_value
|
||||
|
||||
return ema_value
|
||||
|
||||
@staticmethod
|
||||
def calculate_sma(prices: List[float], period: int = 20) -> Optional[float]:
|
||||
"""
|
||||
计算SMA(简单移动平均)
|
||||
|
||||
Args:
|
||||
prices: 价格列表(从旧到新)
|
||||
period: 计算周期
|
||||
|
||||
Returns:
|
||||
SMA值
|
||||
"""
|
||||
if len(prices) < period:
|
||||
return None
|
||||
|
||||
recent_prices = prices[-period:]
|
||||
return sum(recent_prices) / len(recent_prices)
|
||||
|
||||
@staticmethod
|
||||
def detect_market_regime(
|
||||
prices: List[float],
|
||||
short_period: int = 20,
|
||||
long_period: int = 50
|
||||
) -> str:
|
||||
"""
|
||||
判断市场状态:趋势或震荡
|
||||
|
||||
Args:
|
||||
prices: 价格列表
|
||||
short_period: 短期均线周期
|
||||
long_period: 长期均线周期
|
||||
|
||||
Returns:
|
||||
'trending' 或 'ranging'
|
||||
"""
|
||||
if len(prices) < long_period:
|
||||
return 'unknown'
|
||||
|
||||
short_ma = TechnicalIndicators.calculate_sma(prices, short_period)
|
||||
long_ma = TechnicalIndicators.calculate_sma(prices, long_period)
|
||||
|
||||
if short_ma is None or long_ma is None:
|
||||
return 'unknown'
|
||||
|
||||
# 计算价格波动率
|
||||
recent_prices = prices[-20:]
|
||||
if len(recent_prices) < 2:
|
||||
return 'unknown'
|
||||
|
||||
volatility = sum(abs(recent_prices[i] - recent_prices[i-1]) for i in range(1, len(recent_prices))) / len(recent_prices)
|
||||
avg_price = sum(recent_prices) / len(recent_prices)
|
||||
volatility_pct = (volatility / avg_price) * 100 if avg_price > 0 else 0
|
||||
|
||||
# 如果短期均线明显高于或低于长期均线,且波动率较大,判断为趋势
|
||||
ma_diff_pct = abs(short_ma - long_ma) / long_ma * 100 if long_ma > 0 else 0
|
||||
|
||||
if ma_diff_pct > 2 and volatility_pct > 1:
|
||||
return 'trending'
|
||||
else:
|
||||
return 'ranging'
|
||||
307
unicorn_websocket.py
Normal file
307
unicorn_websocket.py
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
"""
|
||||
Unicorn WebSocket模块 - 提供高性能实时数据流
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Callable
|
||||
from unicorn_binance_websocket_api.unicorn_binance_websocket_api_manager import BinanceWebSocketApiManager
|
||||
import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnicornWebSocketManager:
|
||||
"""Unicorn WebSocket管理器"""
|
||||
|
||||
def __init__(self, testnet: bool = False):
|
||||
"""
|
||||
初始化Unicorn WebSocket管理器
|
||||
|
||||
Args:
|
||||
testnet: 是否使用测试网
|
||||
"""
|
||||
self.testnet = testnet or config.USE_TESTNET
|
||||
self.manager: Optional[BinanceWebSocketApiManager] = None
|
||||
self.stream_ids: Dict[str, str] = {} # symbol -> stream_id
|
||||
self.price_callbacks: Dict[str, List[Callable]] = {} # symbol -> callbacks
|
||||
self.running = False
|
||||
|
||||
def start(self):
|
||||
"""启动WebSocket管理器"""
|
||||
try:
|
||||
# 创建管理器
|
||||
self.manager = BinanceWebSocketApiManager(
|
||||
exchange="binance.com-futures" if not self.testnet else "binance.com-futures-testnet",
|
||||
throw_exception_if_unrepairable=True,
|
||||
high_performance=True
|
||||
)
|
||||
|
||||
self.running = True
|
||||
logger.info(f"Unicorn WebSocket管理器启动成功 (测试网: {self.testnet})")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"启动Unicorn WebSocket管理器失败: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""停止WebSocket管理器"""
|
||||
self.running = False
|
||||
if self.manager:
|
||||
# 停止所有流
|
||||
for stream_id in self.stream_ids.values():
|
||||
try:
|
||||
self.manager.stop_stream(stream_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 停止管理器
|
||||
try:
|
||||
self.manager.stop_manager_with_all_streams()
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.info("Unicorn WebSocket管理器已停止")
|
||||
|
||||
def subscribe_ticker(self, symbols: List[str], callback: Callable) -> Dict[str, str]:
|
||||
"""
|
||||
订阅交易对的价格流
|
||||
|
||||
Args:
|
||||
symbols: 交易对列表
|
||||
callback: 价格更新回调函数 callback(symbol, price_data)
|
||||
|
||||
Returns:
|
||||
交易对到stream_id的映射
|
||||
"""
|
||||
if not self.manager:
|
||||
logger.error("WebSocket管理器未启动")
|
||||
return {}
|
||||
|
||||
stream_ids = {}
|
||||
|
||||
try:
|
||||
# 构建流名称列表
|
||||
streams = []
|
||||
for symbol in symbols:
|
||||
# 转换为小写(币安要求)
|
||||
symbol_lower = symbol.lower()
|
||||
# 订阅ticker流
|
||||
stream_name = f"{symbol_lower}@ticker"
|
||||
streams.append(stream_name)
|
||||
|
||||
# 注册回调
|
||||
if symbol not in self.price_callbacks:
|
||||
self.price_callbacks[symbol] = []
|
||||
self.price_callbacks[symbol].append(callback)
|
||||
|
||||
# 创建多路复用流
|
||||
if streams:
|
||||
stream_id = self.manager.create_stream(
|
||||
["arr"],
|
||||
streams,
|
||||
output="UnicornFy"
|
||||
)
|
||||
|
||||
# 记录stream_id
|
||||
for symbol in symbols:
|
||||
self.stream_ids[symbol] = stream_id
|
||||
stream_ids[symbol] = stream_id
|
||||
|
||||
logger.info(f"订阅 {len(symbols)} 个交易对的价格流")
|
||||
|
||||
return stream_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"订阅价格流失败: {e}")
|
||||
return {}
|
||||
|
||||
def subscribe_kline(
|
||||
self,
|
||||
symbols: List[str],
|
||||
interval: str = "5m",
|
||||
callback: Optional[Callable] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
订阅K线数据流
|
||||
|
||||
Args:
|
||||
symbols: 交易对列表
|
||||
interval: K线周期(1m, 5m, 15m等)
|
||||
callback: K线更新回调函数 callback(symbol, kline_data)
|
||||
|
||||
Returns:
|
||||
交易对到stream_id的映射
|
||||
"""
|
||||
if not self.manager:
|
||||
logger.error("WebSocket管理器未启动")
|
||||
return {}
|
||||
|
||||
stream_ids = {}
|
||||
|
||||
try:
|
||||
# 构建流名称列表
|
||||
streams = []
|
||||
for symbol in symbols:
|
||||
symbol_lower = symbol.lower()
|
||||
stream_name = f"{symbol_lower}@kline_{interval}"
|
||||
streams.append(stream_name)
|
||||
|
||||
# 创建多路复用流
|
||||
if streams:
|
||||
stream_id = self.manager.create_stream(
|
||||
["arr"],
|
||||
streams,
|
||||
output="UnicornFy"
|
||||
)
|
||||
|
||||
for symbol in symbols:
|
||||
stream_ids[symbol] = stream_id
|
||||
|
||||
logger.info(f"订阅 {len(symbols)} 个交易对的K线流 ({interval})")
|
||||
|
||||
return stream_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"订阅K线流失败: {e}")
|
||||
return {}
|
||||
|
||||
def get_realtime_price(self, symbol: str) -> Optional[float]:
|
||||
"""
|
||||
获取实时价格(从WebSocket流缓冲区中)
|
||||
|
||||
Args:
|
||||
symbol: 交易对
|
||||
|
||||
Returns:
|
||||
实时价格,如果未订阅则返回None
|
||||
"""
|
||||
if not self.manager or symbol not in self.stream_ids:
|
||||
return None
|
||||
|
||||
try:
|
||||
stream_id = self.stream_ids[symbol]
|
||||
# 从流缓冲区获取最新数据
|
||||
data = self.manager.pop_stream_data_from_stream_buffer(stream_id)
|
||||
|
||||
if data and isinstance(data, dict):
|
||||
# 解析ticker数据
|
||||
if 'event_type' in data and data['event_type'] == '24hrTicker':
|
||||
price_data = data.get('data', {})
|
||||
if 'c' in price_data: # 最新价格
|
||||
return float(price_data['c'])
|
||||
elif 'data' in data and isinstance(data['data'], dict):
|
||||
if 'c' in data['data']: # 最新价格
|
||||
return float(data['data']['c'])
|
||||
elif 'close' in data['data']: # K线收盘价
|
||||
return float(data['data']['close'])
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"获取 {symbol} 实时价格失败: {e}")
|
||||
return None
|
||||
|
||||
async def process_stream_data(self):
|
||||
"""
|
||||
处理WebSocket流数据(异步)
|
||||
"""
|
||||
if not self.manager:
|
||||
return
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# 处理所有流的数据
|
||||
for symbol, stream_id in list(self.stream_ids.items()):
|
||||
try:
|
||||
# 获取流数据(非阻塞)
|
||||
stream_data = self.manager.pop_stream_data_from_stream_buffer(stream_id)
|
||||
if stream_data:
|
||||
# 处理数据
|
||||
await self._handle_stream_data(symbol, stream_data)
|
||||
except Exception as e:
|
||||
logger.debug(f"处理 {symbol} 流数据失败: {e}")
|
||||
|
||||
# 短暂休眠,避免CPU占用过高
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket流数据失败: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _handle_stream_data(self, symbol: str, data: Dict):
|
||||
"""
|
||||
处理单个流的数据
|
||||
|
||||
Args:
|
||||
symbol: 交易对
|
||||
data: 流数据
|
||||
"""
|
||||
try:
|
||||
if not data or not isinstance(data, dict):
|
||||
return
|
||||
|
||||
# 处理ticker数据
|
||||
if 'event_type' in data and data['event_type'] == '24hrTicker':
|
||||
price_data = data.get('data', {})
|
||||
if 'c' in price_data: # 最新价格
|
||||
price = float(price_data['c'])
|
||||
# 调用所有注册的回调
|
||||
if symbol in self.price_callbacks:
|
||||
for callback in self.price_callbacks[symbol]:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(symbol, price, price_data)
|
||||
else:
|
||||
callback(symbol, price, price_data)
|
||||
except Exception as e:
|
||||
logger.debug(f"回调函数执行失败: {e}")
|
||||
|
||||
# 处理K线数据
|
||||
elif 'event_type' in data and data['event_type'] == 'kline':
|
||||
kline_data = data.get('data', {})
|
||||
if 'k' in kline_data:
|
||||
kline = kline_data['k']
|
||||
# 可以在这里处理K线更新
|
||||
logger.debug(f"{symbol} K线更新: {kline.get('c', 'N/A')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"处理 {symbol} 流数据失败: {e}")
|
||||
|
||||
def unsubscribe(self, symbol: str):
|
||||
"""
|
||||
取消订阅交易对
|
||||
|
||||
Args:
|
||||
symbol: 交易对
|
||||
"""
|
||||
if symbol in self.stream_ids:
|
||||
stream_id = self.stream_ids[symbol]
|
||||
try:
|
||||
self.manager.stop_stream(stream_id)
|
||||
del self.stream_ids[symbol]
|
||||
if symbol in self.price_callbacks:
|
||||
del self.price_callbacks[symbol]
|
||||
logger.info(f"取消订阅 {symbol}")
|
||||
except Exception as e:
|
||||
logger.error(f"取消订阅 {symbol} 失败: {e}")
|
||||
|
||||
def get_stream_statistics(self) -> Dict:
|
||||
"""
|
||||
获取流统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
if not self.manager:
|
||||
return {}
|
||||
|
||||
try:
|
||||
stats = {
|
||||
'total_streams': len(self.stream_ids),
|
||||
'active_streams': len([s for s in self.stream_ids.values() if s]),
|
||||
'subscribed_symbols': list(self.stream_ids.keys())
|
||||
}
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"获取流统计信息失败: {e}")
|
||||
return {}
|
||||
Loading…
Reference in New Issue
Block a user