Skip to content

Commit

Permalink
[GH-76] Add StochRSI and wave trend (#106)
Browse files Browse the repository at this point in the history
Add StochRSI.  The default window is set to 14.
Use `StockDataFrame.RSI` to change it.

Add wave trand that includes two lines.  The default window of the ema
is set to `10` and `21`.
Use `StockDataFrame.WAVE_TREND_1` and `StockDataFrame.WAVE_TREND_2` to
change them.
  • Loading branch information
jealous authored Jan 3, 2022
1 parent c8262c7 commit 2b2d9ee
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 10 deletions.
66 changes: 63 additions & 3 deletions stockstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ class StockDataFrame(pd.DataFrame):

CCI = 14

RSI = 14

VR = 26

WAVE_TREND_1 = 10
WAVE_TREND_2 = 21

KAMA_SLOW = 34
KAMA_FAST = 5

Expand Down Expand Up @@ -337,7 +342,7 @@ def _get_rsv(cls, df, window):

# noinspection PyUnresolvedReferences
@classmethod
def _get_rsi(cls, df, window):
def _get_rsi(cls, df, window=None):
""" Calculate the RSI (Relative Strength Index) within N periods
calculated based on the formula at:
Expand All @@ -347,6 +352,11 @@ def _get_rsi(cls, df, window):
:param window: number of periods
:return: None
"""
if window is None:
window = cls.RSI
column_name = 'rsi'
else:
column_name = 'rsi_{}'.format(window)
window = cls.get_int_positive(window)

change = cls._delta(df['close'], -1)
Expand All @@ -356,9 +366,56 @@ def _get_rsi(cls, df, window):
n_ema = cls._smma(close_nm, window)

rs_column_name = 'rs_{}'.format(window)
rsi_column_name = 'rsi_{}'.format(window)
df[rs_column_name] = rs = p_ema / n_ema
df[rsi_column_name] = 100 - 100 / (1.0 + rs)
df[column_name] = 100 - 100 / (1.0 + rs)

@classmethod
def _get_stochrsi(cls, df, window=None):
""" Calculate the Stochastic RSI
calculated based on the formula at:
https://www.investopedia.com/terms/s/stochrsi.asp
:param df: data
:param window: number of periods
:return: None
"""
if window is None:
window = cls.RSI
column_name = 'stochrsi'
else:
column_name = 'stochrsi_{}'.format(window)
window = cls.get_int_positive(window)

rsi = df['rsi_{}'.format(window)]
rsi_min = cls._mov_min(rsi, window)
rsi_max = cls._mov_max(rsi, window)

cv = (rsi - rsi_min) / (rsi_max - rsi_min)
df[column_name] = cv * 100

@classmethod
def _get_wave_trend(cls, df):
""" Calculate LazyBear's Wavetrend
Check the algorithm described below:
https://medium.com/@samuel.mcculloch/lets-take-a-look-at-wavetrend-with-crosses-lazybear-s-indicator-2ece1737f72f
n1: period of EMA on typical price
n2: period of EMA
:param df: data frame
:return: None
"""
n1 = cls.WAVE_TREND_1
n2 = cls.WAVE_TREND_2

tp = cls._middle(df)
esa = cls._ema(tp, n1)
d = cls._ema((tp - esa).abs(), n1)
ci = (tp - esa) / (0.015 * d)
tci = cls._ema(ci, n2)
df["wt1"] = tci
df["wt2"] = cls._sma(tci, 4)

@classmethod
def _smma(cls, series, window):
Expand Down Expand Up @@ -1108,6 +1165,8 @@ def _get_cross(df, key):
def __init_not_exist_column(cls, df, key):
handlers = {
('change',): cls._get_change,
('rsi',): cls._get_rsi,
('stochrsi',): cls._get_stochrsi,
('rate',): cls._get_rate,
('middle',): cls._get_middle,
('boll', 'boll_ub', 'boll_lb'): cls._get_boll,
Expand All @@ -1128,6 +1187,7 @@ def __init_not_exist_column(cls, df, key):
('chop',): cls._get_chop,
('log-ret',): cls._get_log_ret,
('mfi',): cls._get_mfi,
('wt1', 'wt2'): cls._get_wave_trend,
}
for names, handler in handlers.items():
if key in names:
Expand Down
33 changes: 26 additions & 7 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def test_column_min(self):
def test_column_shift_positive(self):
stock = self.get_stock_20day()
close_s = stock['close_2_s']
print(close_s)
assert_that(close_s.loc[20110118], equal_to(12.48))
assert_that(close_s.loc[20110119], equal_to(12.48))
assert_that(close_s.loc[20110120], equal_to(12.48))
Expand Down Expand Up @@ -413,13 +412,26 @@ def test_rsv_nan_value(self):
assert_that(df['rsv_9'][0], equal_to(0.0))

def test_get_rsi(self):
self._supor.get('rsi_6')
self._supor.get('rsi_12')
self._supor.get('rsi_24')
rsi = self._supor.get('rsi')
rsi_6 = self._supor.get('rsi_6')
rsi_12 = self._supor.get('rsi_12')
rsi_14 = self._supor.get('rsi_14')
rsi_24 = self._supor.get('rsi_24')
idx = 20160817
assert_that(self._supor.loc[idx, 'rsi_6'], near_to(71.3114))
assert_that(self._supor.loc[idx, 'rsi_12'], near_to(63.1125))
assert_that(self._supor.loc[idx, 'rsi_24'], near_to(61.3064))
assert_that(rsi_6.loc[idx], near_to(71.3114))
assert_that(rsi_12.loc[idx], near_to(63.1125))
assert_that(rsi_24.loc[idx], near_to(61.3064))
assert_that(rsi.loc[idx], near_to(rsi_14.loc[idx]))

def test_get_stoch_rsi(self):
stock = self.get_stock_90day()
stoch_rsi = stock['stochrsi']
stoch_rsi_6 = stock['stochrsi_6']
stoch_rsi_14 = stock['stochrsi_14']
idx = 20110331
assert_that(stoch_rsi.loc[idx], near_to(67.0955))
assert_that(stoch_rsi_6.loc[idx], near_to(27.5693))
assert_that(stoch_rsi_14.loc[idx], near_to(stoch_rsi.loc[idx]))

def test_get_wr(self):
self._supor.get('wr_10')
Expand Down Expand Up @@ -576,3 +588,10 @@ def test_column_conflict(self):
idx = 20110331
assert_that(res['close_26_ema'].loc[idx], near_to(13.2488))
assert_that(res['macd'].loc[idx], near_to(0.1482))

def test_wave_trend(self):
stock = self.get_stock_90day()
wt1, wt2 = stock['wt1'], stock['wt2']
idx = 20110331
assert_that(wt1.loc[idx], near_to(38.9610))
assert_that(wt2.loc[idx], near_to(31.6997))

0 comments on commit 2b2d9ee

Please sign in to comment.