diff --git a/stockstats.py b/stockstats.py index 8d9c474..a4101b0 100644 --- a/stockstats.py +++ b/stockstats.py @@ -85,8 +85,13 @@ class StockDataFrame(pd.DataFrame): CCI = 14 + RSI = 14 + VR = 26 + WAVE_TREND_1 = 10 + WAVE_TREND_2 = 21 + KAMA_SLOW = 34 KAMA_FAST = 5 @@ -337,7 +342,7 @@ def _get_rsv(cls, df, window): # noinspection PyUnresolvedReferences @classmethod - def _get_rsi(cls, df, window): + def _get_rsi(cls, df, window=None): """ Calculate the RSI (Relative Strength Index) within N periods calculated based on the formula at: @@ -347,6 +352,11 @@ def _get_rsi(cls, df, window): :param window: number of periods :return: None """ + if window is None: + window = cls.RSI + column_name = 'rsi' + else: + column_name = 'rsi_{}'.format(window) window = cls.get_int_positive(window) change = cls._delta(df['close'], -1) @@ -356,9 +366,56 @@ def _get_rsi(cls, df, window): n_ema = cls._smma(close_nm, window) rs_column_name = 'rs_{}'.format(window) - rsi_column_name = 'rsi_{}'.format(window) df[rs_column_name] = rs = p_ema / n_ema - df[rsi_column_name] = 100 - 100 / (1.0 + rs) + df[column_name] = 100 - 100 / (1.0 + rs) + + @classmethod + def _get_stochrsi(cls, df, window=None): + """ Calculate the Stochastic RSI + + calculated based on the formula at: + https://www.investopedia.com/terms/s/stochrsi.asp + + :param df: data + :param window: number of periods + :return: None + """ + if window is None: + window = cls.RSI + column_name = 'stochrsi' + else: + column_name = 'stochrsi_{}'.format(window) + window = cls.get_int_positive(window) + + rsi = df['rsi_{}'.format(window)] + rsi_min = cls._mov_min(rsi, window) + rsi_max = cls._mov_max(rsi, window) + + cv = (rsi - rsi_min) / (rsi_max - rsi_min) + df[column_name] = cv * 100 + + @classmethod + def _get_wave_trend(cls, df): + """ Calculate LazyBear's Wavetrend + Check the algorithm described below: + https://medium.com/@samuel.mcculloch/lets-take-a-look-at-wavetrend-with-crosses-lazybear-s-indicator-2ece1737f72f + + n1: period of EMA on typical price + n2: period of EMA + + :param df: data frame + :return: None + """ + n1 = cls.WAVE_TREND_1 + n2 = cls.WAVE_TREND_2 + + tp = cls._middle(df) + esa = cls._ema(tp, n1) + d = cls._ema((tp - esa).abs(), n1) + ci = (tp - esa) / (0.015 * d) + tci = cls._ema(ci, n2) + df["wt1"] = tci + df["wt2"] = cls._sma(tci, 4) @classmethod def _smma(cls, series, window): @@ -1108,6 +1165,8 @@ def _get_cross(df, key): def __init_not_exist_column(cls, df, key): handlers = { ('change',): cls._get_change, + ('rsi',): cls._get_rsi, + ('stochrsi',): cls._get_stochrsi, ('rate',): cls._get_rate, ('middle',): cls._get_middle, ('boll', 'boll_ub', 'boll_lb'): cls._get_boll, @@ -1128,6 +1187,7 @@ def __init_not_exist_column(cls, df, key): ('chop',): cls._get_chop, ('log-ret',): cls._get_log_ret, ('mfi',): cls._get_mfi, + ('wt1', 'wt2'): cls._get_wave_trend, } for names, handler in handlers.items(): if key in names: diff --git a/test.py b/test.py index 704c9ac..153084d 100644 --- a/test.py +++ b/test.py @@ -165,7 +165,6 @@ def test_column_min(self): def test_column_shift_positive(self): stock = self.get_stock_20day() close_s = stock['close_2_s'] - print(close_s) assert_that(close_s.loc[20110118], equal_to(12.48)) assert_that(close_s.loc[20110119], equal_to(12.48)) assert_that(close_s.loc[20110120], equal_to(12.48)) @@ -413,13 +412,26 @@ def test_rsv_nan_value(self): assert_that(df['rsv_9'][0], equal_to(0.0)) def test_get_rsi(self): - self._supor.get('rsi_6') - self._supor.get('rsi_12') - self._supor.get('rsi_24') + rsi = self._supor.get('rsi') + rsi_6 = self._supor.get('rsi_6') + rsi_12 = self._supor.get('rsi_12') + rsi_14 = self._supor.get('rsi_14') + rsi_24 = self._supor.get('rsi_24') idx = 20160817 - assert_that(self._supor.loc[idx, 'rsi_6'], near_to(71.3114)) - assert_that(self._supor.loc[idx, 'rsi_12'], near_to(63.1125)) - assert_that(self._supor.loc[idx, 'rsi_24'], near_to(61.3064)) + assert_that(rsi_6.loc[idx], near_to(71.3114)) + assert_that(rsi_12.loc[idx], near_to(63.1125)) + assert_that(rsi_24.loc[idx], near_to(61.3064)) + assert_that(rsi.loc[idx], near_to(rsi_14.loc[idx])) + + def test_get_stoch_rsi(self): + stock = self.get_stock_90day() + stoch_rsi = stock['stochrsi'] + stoch_rsi_6 = stock['stochrsi_6'] + stoch_rsi_14 = stock['stochrsi_14'] + idx = 20110331 + assert_that(stoch_rsi.loc[idx], near_to(67.0955)) + assert_that(stoch_rsi_6.loc[idx], near_to(27.5693)) + assert_that(stoch_rsi_14.loc[idx], near_to(stoch_rsi.loc[idx])) def test_get_wr(self): self._supor.get('wr_10') @@ -576,3 +588,10 @@ def test_column_conflict(self): idx = 20110331 assert_that(res['close_26_ema'].loc[idx], near_to(13.2488)) assert_that(res['macd'].loc[idx], near_to(0.1482)) + + def test_wave_trend(self): + stock = self.get_stock_90day() + wt1, wt2 = stock['wt1'], stock['wt2'] + idx = 20110331 + assert_that(wt1.loc[idx], near_to(38.9610)) + assert_that(wt2.loc[idx], near_to(31.6997))