273 lines
7.8 KiB
Python
273 lines
7.8 KiB
Python
"""
|
||
技术指标模块 - 计算各种技术指标用于交易决策
|
||
"""
|
||
import logging
|
||
from typing import List, Dict, Optional, Tuple
|
||
import math
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TechnicalIndicators:
    """Stateless technical-indicator calculations for trading decisions.

    Every method is a ``@staticmethod`` operating on plain lists of numbers.
    The numeric indicators return ``None`` when the input is too short or a
    period is not positive, so callers can treat ``None`` as "no signal".
    """

    @staticmethod
    def calculate_rsi(prices: List[float], period: int = 14) -> Optional[float]:
        """Compute the RSI (Relative Strength Index).

        Gains and losses over the latest ``period`` price changes are
        averaged with a simple mean (no exponential/Wilder smoothing).

        Args:
            prices: Price list ordered oldest to newest.
            period: Number of price changes to average. Defaults to 14.

        Returns:
            RSI value in the range 0-100, or ``None`` when ``period`` is not
            positive or fewer than ``period + 1`` prices are available.
            A window with gains and no losses yields exactly 100.0; a
            completely flat window yields 0.0.
        """
        if period <= 0 or len(prices) < period + 1:
            return None

        # period + 1 prices produce exactly `period` consecutive changes.
        window = prices[-(period + 1):]
        changes = [current - previous for previous, current in zip(window, window[1:])]

        avg_gain = sum(change for change in changes if change > 0) / period
        avg_loss = sum(-change for change in changes if change < 0) / period

        if avg_loss == 0:
            # Handle "no losses" explicitly rather than substituting a small
            # constant for avg_loss: a constant makes the result depend on
            # the price scale (tiny-priced assets would get a near-zero RSI
            # for a strictly rising series).
            return 100.0 if avg_gain > 0 else 0.0

        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def calculate_macd(
        prices: List[float],
        fast_period: int = 12,
        slow_period: int = 26,
        signal_period: int = 9
    ) -> Optional[Dict[str, float]]:
        """Compute the MACD indicator.

        All three EMAs are seeded with their first input value and run over
        the full price history; the signal line is the EMA of the MACD line.

        Args:
            prices: Price list ordered oldest to newest.
            fast_period: Fast EMA period. Defaults to 12.
            slow_period: Slow EMA period. Defaults to 26.
            signal_period: Signal-line EMA period. Defaults to 9.

        Returns:
            ``{'macd': ..., 'signal': ..., 'histogram': ...}`` holding the
            latest values, or ``None`` when any period is not positive or
            fewer than ``slow_period + signal_period`` prices are available.
        """
        if min(fast_period, slow_period, signal_period) <= 0:
            return None
        if len(prices) < slow_period + signal_period:
            return None

        fast_k = 2 / (fast_period + 1)
        slow_k = 2 / (slow_period + 1)
        signal_k = 2 / (signal_period + 1)

        # Only the latest value of each series is returned, so the three
        # EMAs are advanced together in one pass instead of building lists.
        fast_ema = slow_ema = prices[0]
        macd_value = fast_ema - slow_ema
        signal_value = macd_value
        for price in prices[1:]:
            fast_ema = (price - fast_ema) * fast_k + fast_ema
            slow_ema = (price - slow_ema) * slow_k + slow_ema
            macd_value = fast_ema - slow_ema
            signal_value = (macd_value - signal_value) * signal_k + signal_value

        return {
            'macd': macd_value,
            'signal': signal_value,
            'histogram': macd_value - signal_value
        }

    @staticmethod
    def calculate_bollinger_bands(
        prices: List[float],
        period: int = 20,
        std_dev: float = 2.0
    ) -> Optional[Dict[str, float]]:
        """Compute Bollinger Bands over the latest ``period`` prices.

        Args:
            prices: Price list ordered oldest to newest.
            period: Window length. Defaults to 20.
            std_dev: Standard-deviation multiplier. Defaults to 2.0.

        Returns:
            ``{'upper': ..., 'middle': ..., 'lower': ...}``, or ``None`` when
            ``period`` is not positive or fewer than ``period`` prices are
            available.
        """
        if period <= 0 or len(prices) < period:
            return None

        recent_prices = prices[-period:]
        middle = sum(recent_prices) / period

        # Population standard deviation (divides by N, not N - 1).
        variance = sum((price - middle) ** 2 for price in recent_prices) / period
        band_width = math.sqrt(variance) * std_dev

        return {
            'upper': middle + band_width,
            'middle': middle,
            'lower': middle - band_width
        }

    @staticmethod
    def calculate_atr(
        high_prices: List[float],
        low_prices: List[float],
        close_prices: List[float],
        period: int = 14
    ) -> Optional[float]:
        """Compute the ATR (Average True Range).

        The true range of a bar is the largest of: high - low,
        |high - previous close| and |low - previous close|. The result is
        the simple mean of the latest ``period`` true ranges.

        Args:
            high_prices: High prices, one per bar.
            low_prices: Low prices, one per bar.
            close_prices: Close prices, one per bar.
            period: Number of true ranges to average. Defaults to 14.

        Returns:
            ATR value, or ``None`` when ``period`` is not positive or any of
            the lists holds fewer than ``period + 1`` bars.
        """
        if period <= 0:
            return None

        # Bars are matched by index from the start of each list; restricting
        # to the common length keeps every index below valid even when the
        # lists differ in length.
        bar_count = min(len(high_prices), len(low_prices), len(close_prices))
        if bar_count < period + 1:
            return None

        # Only the last `period` true ranges contribute to the average, so
        # earlier bars are not evaluated.
        true_ranges = []
        for i in range(bar_count - period, bar_count):
            previous_close = close_prices[i - 1]
            true_ranges.append(max(
                high_prices[i] - low_prices[i],
                abs(high_prices[i] - previous_close),
                abs(low_prices[i] - previous_close),
            ))

        return sum(true_ranges) / period

    @staticmethod
    def calculate_ema(prices: List[float], period: int = 20) -> Optional[float]:
        """Compute the EMA (Exponential Moving Average).

        The EMA is seeded with the first price and updated across the whole
        list using the smoothing factor ``2 / (period + 1)``.

        Args:
            prices: Price list ordered oldest to newest.
            period: EMA period. Defaults to 20.

        Returns:
            Latest EMA value, or ``None`` when ``period`` is not positive or
            fewer than ``period`` prices are available.
        """
        if period <= 0 or len(prices) < period:
            return None

        multiplier = 2 / (period + 1)
        ema_value = prices[0]
        for price in prices[1:]:
            ema_value = (price - ema_value) * multiplier + ema_value

        return ema_value

    @staticmethod
    def calculate_sma(prices: List[float], period: int = 20) -> Optional[float]:
        """Compute the SMA (Simple Moving Average) of the latest prices.

        Args:
            prices: Price list ordered oldest to newest.
            period: Window length. Defaults to 20.

        Returns:
            Mean of the last ``period`` prices, or ``None`` when ``period``
            is not positive or fewer than ``period`` prices are available.
        """
        if period <= 0 or len(prices) < period:
            return None

        return sum(prices[-period:]) / period

    @staticmethod
    def detect_market_regime(
        prices: List[float],
        short_period: int = 20,
        long_period: int = 50,
        *,
        volatility_window: int = 20,
        ma_diff_threshold_pct: float = 2.0,
        volatility_threshold_pct: float = 1.0
    ) -> str:
        """Classify the market as trending or ranging.

        The market is called trending when the short and long SMAs differ by
        more than ``ma_diff_threshold_pct`` percent of the long SMA AND the
        recent average absolute price move exceeds
        ``volatility_threshold_pct`` percent of the recent average price.

        Args:
            prices: Price list ordered oldest to newest.
            short_period: Short SMA period. Defaults to 20.
            long_period: Long SMA period. Defaults to 50.
            volatility_window: Number of latest prices used for the
                volatility estimate. Defaults to 20.
            ma_diff_threshold_pct: Minimum SMA separation, in percent.
                Defaults to 2.0.
            volatility_threshold_pct: Minimum volatility, in percent.
                Defaults to 1.0.

        Returns:
            ``'trending'``, ``'ranging'``, or ``'unknown'`` when there is not
            enough data to evaluate both moving averages and the volatility.
        """
        if len(prices) < long_period:
            return 'unknown'

        short_ma = TechnicalIndicators.calculate_sma(prices, short_period)
        long_ma = TechnicalIndicators.calculate_sma(prices, long_period)
        if short_ma is None or long_ma is None:
            return 'unknown'

        recent_prices = prices[-volatility_window:] if volatility_window > 0 else []
        if len(recent_prices) < 2:
            return 'unknown'

        total_move = sum(
            abs(current - previous)
            for previous, current in zip(recent_prices, recent_prices[1:])
        )
        # NOTE: the n - 1 moves are divided by n (the number of prices), not
        # n - 1. Kept as-is so results stay comparable with the thresholds
        # already in use.
        volatility = total_move / len(recent_prices)
        avg_price = sum(recent_prices) / len(recent_prices)
        volatility_pct = (volatility / avg_price) * 100 if avg_price > 0 else 0

        ma_diff_pct = abs(short_ma - long_ma) / long_ma * 100 if long_ma > 0 else 0

        if ma_diff_pct > ma_diff_threshold_pct and volatility_pct > volatility_threshold_pct:
            return 'trending'
        return 'ranging'
|