auto_trade_sys/trading_system/indicators.py
薇薇安 11e3532ac3 a
2026-01-17 20:23:49 +08:00

301 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
技术指标模块 - 计算各种技术指标用于交易决策
"""
import logging
from typing import List, Dict, Optional, Tuple
import math
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class TechnicalIndicators:
    """Technical-indicator calculations used for trading decisions.

    All methods are static and operate on plain Python lists of prices
    ordered oldest -> newest. The ``calculate_*`` methods return ``None``
    instead of raising when there is not enough data or when a period
    argument is smaller than 1.
    """

    @staticmethod
    def _ema_series(data: List[float], period: int) -> List[float]:
        """Return the running EMA of ``data`` (same length as ``data``).

        The series is seeded with ``data[0]`` and uses the smoothing factor
        ``2 / (period + 1)``. ``data`` must be non-empty and ``period`` >= 1;
        callers are responsible for checking both.
        """
        multiplier = 2 / (period + 1)
        values = [data[0]]
        for price in data[1:]:
            previous = values[-1]
            values.append((price - previous) * multiplier + previous)
        return values

    @staticmethod
    def calculate_rsi(prices: List[float], period: int = 14) -> Optional[float]:
        """Compute the Relative Strength Index over the last ``period`` price changes.

        Gains and losses are averaged with a simple mean over the window
        (not Wilder's exponential smoothing).

        Args:
            prices: Price list, ordered oldest -> newest.
            period: Number of price changes in the window (default 14).

        Returns:
            RSI in the range 0-100, or ``None`` if fewer than ``period + 1``
            prices are available or ``period`` < 1. If the window contains no
            losses, the result is 100 when there were gains and 50 (neutral)
            when prices were completely flat.
        """
        if period < 1 or len(prices) < period + 1:
            return None

        window = prices[-(period + 1):]
        total_gain = 0.0
        total_loss = 0.0
        for previous, current in zip(window, window[1:]):
            change = current - previous
            if change > 0:
                total_gain += change
            else:
                total_loss -= change  # change <= 0, so this adds |change|

        avg_gain = total_gain / period
        avg_loss = total_loss / period
        # RS is undefined when there are no losses. Resolve that explicitly
        # instead of dividing by an absolute epsilon, whose effect would
        # depend on the asset's price scale.
        if avg_loss == 0:
            return 100.0 if avg_gain > 0 else 50.0

        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def calculate_macd(
        prices: List[float],
        fast_period: int = 12,
        slow_period: int = 26,
        signal_period: int = 9
    ) -> Optional[Dict[str, float]]:
        """Compute the MACD indicator for the most recent bar.

        EMAs are seeded with the first price and run over the full history,
        so results depend on how much history is supplied.

        Args:
            prices: Price list, ordered oldest -> newest.
            fast_period: Fast EMA period (default 12).
            slow_period: Slow EMA period (default 26).
            signal_period: Signal-line EMA period (default 9).

        Returns:
            ``{'macd': MACD value, 'signal': signal-line value,
            'histogram': macd - signal}``, or ``None`` if fewer than
            ``slow_period + signal_period`` prices are available or any
            period is < 1.
        """
        if min(fast_period, slow_period, signal_period) < 1:
            return None
        if len(prices) < slow_period + signal_period:
            return None

        ema = TechnicalIndicators._ema_series
        fast_ema = ema(prices, fast_period)
        slow_ema = ema(prices, slow_period)
        # Both EMA series have one value per input price, so they are
        # already aligned bar-for-bar.
        macd_line = [fast - slow for fast, slow in zip(fast_ema, slow_ema)]
        signal_line = ema(macd_line, signal_period)

        macd_value = macd_line[-1]
        signal_value = signal_line[-1]
        return {
            'macd': macd_value,
            'signal': signal_value,
            'histogram': macd_value - signal_value
        }

    @staticmethod
    def calculate_bollinger_bands(
        prices: List[float],
        period: int = 20,
        std_dev: float = 2.0
    ) -> Optional[Dict[str, float]]:
        """Compute Bollinger Bands over the last ``period`` prices.

        Uses the population standard deviation (divides by ``period``).

        Args:
            prices: Price list, ordered oldest -> newest.
            period: Window length (default 20).
            std_dev: Band width as a multiple of the standard deviation
                (default 2.0).

        Returns:
            ``{'upper': upper band, 'middle': SMA, 'lower': lower band}``, or
            ``None`` if fewer than ``period`` prices are available or
            ``period`` < 1.
        """
        if period < 1 or len(prices) < period:
            return None

        recent_prices = prices[-period:]
        middle = sum(recent_prices) / period
        variance = sum((p - middle) ** 2 for p in recent_prices) / period
        std = math.sqrt(variance)

        return {
            'upper': middle + (std * std_dev),
            'middle': middle,
            'lower': middle - (std * std_dev)
        }

    @staticmethod
    def calculate_atr(
        high_prices: List[float],
        low_prices: List[float],
        close_prices: List[float],
        period: int = 14
    ) -> Optional[float]:
        """Compute the Average True Range as an absolute price distance.

        The ATR is the simple mean of the last ``period`` true ranges (not
        Wilder's smoothing). The three lists are aligned on their most
        recent entries, so lists of unequal length are handled safely.

        Args:
            high_prices: High prices, ordered oldest -> newest.
            low_prices: Low prices, ordered oldest -> newest.
            close_prices: Close prices, ordered oldest -> newest.
            period: Number of true ranges to average (default 14).

        Returns:
            ATR in price units, or ``None`` if any list has fewer than
            ``period + 1`` entries or ``period`` < 1.
        """
        available = min(len(high_prices), len(low_prices), len(close_prices))
        if period < 1 or available < period + 1:
            return None

        # Only the last period + 1 bars matter: each true range needs the
        # previous bar's close.
        window = period + 1
        highs = high_prices[-window:]
        lows = low_prices[-window:]
        closes = close_prices[-window:]

        true_ranges = [
            max(high - low, abs(high - prev_close), abs(low - prev_close))
            for high, low, prev_close in zip(highs[1:], lows[1:], closes[:-1])
        ]
        return sum(true_ranges) / period

    @staticmethod
    def calculate_atr_percent(
        high_prices: List[float],
        low_prices: List[float],
        close_prices: List[float],
        current_price: float,
        period: int = 14
    ) -> Optional[float]:
        """Compute ATR as a fraction of ``current_price``.

        Args:
            high_prices: High prices, ordered oldest -> newest.
            low_prices: Low prices, ordered oldest -> newest.
            close_prices: Close prices, ordered oldest -> newest.
            current_price: Reference price used as the denominator.
            period: ATR period (default 14).

        Returns:
            ATR / current_price (e.g. 0.03 means 3%), or ``None`` if the ATR
            cannot be computed or ``current_price`` <= 0.
        """
        atr = TechnicalIndicators.calculate_atr(high_prices, low_prices, close_prices, period)
        if atr is None or current_price <= 0:
            return None
        return atr / current_price

    @staticmethod
    def calculate_ema(prices: List[float], period: int = 20) -> Optional[float]:
        """Compute the exponential moving average at the most recent price.

        The EMA is seeded with ``prices[0]`` and runs over the full list.

        Args:
            prices: Price list, ordered oldest -> newest.
            period: EMA period (default 20).

        Returns:
            The latest EMA value, or ``None`` if fewer than ``period`` prices
            are available or ``period`` < 1.
        """
        if period < 1 or len(prices) < period:
            return None
        return TechnicalIndicators._ema_series(prices, period)[-1]

    @staticmethod
    def calculate_sma(prices: List[float], period: int = 20) -> Optional[float]:
        """Compute the simple moving average of the last ``period`` prices.

        Args:
            prices: Price list, ordered oldest -> newest.
            period: Window length (default 20).

        Returns:
            The SMA, or ``None`` if fewer than ``period`` prices are available
            or ``period`` < 1.
        """
        if period < 1 or len(prices) < period:
            return None
        return sum(prices[-period:]) / period

    @staticmethod
    def detect_market_regime(
        prices: List[float],
        short_period: int = 20,
        long_period: int = 50,
        *,
        volatility_period: int = 20,
        trend_threshold_pct: float = 2.0,
        volatility_threshold_pct: float = 1.0
    ) -> str:
        """Classify the market as trending or ranging.

        The market is 'trending' when the short SMA deviates from the long
        SMA by more than ``trend_threshold_pct`` percent AND the mean
        absolute bar-to-bar change over the last ``volatility_period`` prices
        exceeds ``volatility_threshold_pct`` percent of their average price.

        Args:
            prices: Price list, ordered oldest -> newest.
            short_period: Short SMA period (default 20).
            long_period: Long SMA period (default 50).
            volatility_period: Number of recent prices used to measure
                volatility (default 20).
            trend_threshold_pct: Minimum SMA divergence, in percent
                (default 2.0).
            volatility_threshold_pct: Minimum mean absolute change, in percent
                of the average price (default 1.0).

        Returns:
            'trending' or 'ranging', or 'unknown' when there is not enough
            data to decide.
        """
        if len(prices) < long_period:
            return 'unknown'

        short_ma = TechnicalIndicators.calculate_sma(prices, short_period)
        long_ma = TechnicalIndicators.calculate_sma(prices, long_period)
        if short_ma is None or long_ma is None:
            return 'unknown'

        # Volatility: mean absolute price change as a percentage of the mean price.
        recent_prices = prices[-volatility_period:] if volatility_period > 0 else []
        if len(recent_prices) < 2:
            return 'unknown'
        changes = [abs(curr - prev) for prev, curr in zip(recent_prices, recent_prices[1:])]
        # n prices yield n - 1 changes; average over the changes, not the prices.
        mean_abs_change = sum(changes) / len(changes)
        avg_price = sum(recent_prices) / len(recent_prices)
        volatility_pct = (mean_abs_change / avg_price) * 100 if avg_price > 0 else 0

        # Trending = the short MA clearly diverges from the long MA while prices move enough.
        ma_diff_pct = abs(short_ma - long_ma) / long_ma * 100 if long_ma > 0 else 0
        if ma_diff_pct > trend_threshold_pct and volatility_pct > volatility_threshold_pct:
            return 'trending'
        return 'ranging'