-
Notifications
You must be signed in to change notification settings - Fork 3
/
recorder.py
155 lines (129 loc) · 4.79 KB
/
recorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import datetime
import pandas as pd
import db.db as db
import api.oanda_api as oanda_api
import util.price_util as price_util
import db.table_defs as table_defs
class RecorderError(Exception):
    """Raised by this module when fetching data from the OANDA API fails.

    The original exception is passed as the single constructor argument.
    """
# Database connection and datetime format string taken from db.db; every
# function in this module reads/writes through `conn` and parses/formats
# datetimes with `time_format`.
conn, time_format = db.conn, db.time_format
def add_trade_record(trade, table_name):
    """Insert one trade into `table_name` unless its tradeId is already stored.

    Args:
        trade: Mapping of column name -> scalar value for a single trade; it
            must contain a 'tradeId' key.
        table_name: Target trades table; it is created first via
            create_trades_table.
    """
    create_trades_table(table_name)
    # Bind tradeId as a parameter instead of concatenating it into the SQL:
    # concatenation raised TypeError for integer ids and allowed SQL injection.
    existing = conn.execute(
        'select 1 from ' + table_name + ' where tradeId = ? limit 1;',
        (trade['tradeId'],),
    ).fetchone()
    if existing is None:
        # index=[1] lets pandas build a one-row frame from a dict of scalars.
        new_row = pd.DataFrame(trade, index=[1])
        new_row.to_sql(table_name, conn, if_exists="append", index=False)
def update_trade_data(table_name):
    """Refresh every OPEN trade in `table_name` with fresh data from the OANDA API.

    The tradeIds whose state is 'OPEN' are read from the table. Each trade is
    re-fetched with oanda_api.get_trade. The old rows are then deleted and the
    fetched rows appended in their place.

    Args:
        table_name: Trades table to refresh; it is created first via
            create_trades_table.

    Raises:
        RecorderError: if fetching any trade from the API fails. All fetches
            happen before the table is touched, so it is left unchanged.
    """
    create_trades_table(table_name)

    # tradeIds of the rows still marked OPEN.
    open_ids = list(pd.read_sql_query(
        'select tradeId from ' + table_name + " where state = 'OPEN'",
        conn,
    )['tradeId'])
    if not open_ids:
        return
    # Use str ids both for the API call (as before) and for the DELETE binding;
    # plain Python strings bind cleanly in sqlite3, whereas numpy integers
    # read back by pandas would not.
    open_ids = [str(trade_id) for trade_id in open_ids]

    # Fetch every trade first. DataFrame.append was removed in pandas 2.0,
    # so collect the rows and build the frame once.
    fetched = []
    for trade_id in open_ids:
        try:
            fetched.append(oanda_api.get_trade(trade_id))
        except Exception as e:
            raise RecorderError(e) from e

    header = list(table_defs.get_columns('trades'))
    fetched_trades = pd.DataFrame(fetched)
    # Same column layout the old append-onto-empty-frame produced:
    # the 'trades' columns first, then any unexpected keys in first-seen order.
    extra = [col for col in fetched_trades.columns if col not in header]
    fetched_trades = fetched_trades.reindex(columns=header + extra)

    placeholders = ','.join('?' * len(open_ids))
    try:
        conn.execute(
            'delete from ' + table_name
            + ' where tradeId in (' + placeholders + ');',
            open_ids,
        )
        fetched_trades.to_sql(table_name, conn, if_exists="append", index=False)
    except Exception:
        # Undo the delete so a failed insert does not lose the OPEN trades.
        # NOTE(review): effective only if conn is not in autocommit mode.
        conn.rollback()
        raise
    conn.commit()
def update_price_data(time_unit='M', time_count=5, count=60):
    """Fetch recent candles from OANDA and append the new ones to a prices table.

    The table is named ``prices_<time_unit><time_count>`` (``prices_M5`` with
    the defaults) and is created first if needed. Candles whose datetime is
    not newer than the table's newest row are skipped. Afterwards the
    indicator columns are recomputed with update_macd and update_bollinger.

    Args:
        time_unit: Unit part of the granularity sent to the API, e.g. 'M'.
        time_count: Multiplier part of the granularity, e.g. 5 -> 'M5'.
        count: Number of candles requested from the API.

    Raises:
        RecorderError: if fetching the candles, or sorting them by
            'datetime', fails.
    """
    granularity = '{0}{1}'.format(time_unit, time_count)
    table_name = 'prices_' + granularity
    create_prices_table(table_name)
    params = {'granularity': granularity, 'count': count}

    # Fetch from the API, oldest candle first.
    try:
        candles = pd.DataFrame(oanda_api.get_candles(params=params)) \
            .sort_values('datetime')
    except Exception as e:
        raise RecorderError(e) from e

    # Newest row already stored in the DB (empty frame if the table is empty).
    last_record = pd.read_sql_query(
        'select * from ' + table_name + ' order by datetime desc limit 1;',
        conn,
    )

    if not last_record.empty and not candles.empty:
        # Parse the cutoff once rather than on every loop iteration.
        last_datetime = datetime.datetime.strptime(
            last_record.iloc[0]['datetime'], time_format)
        # Drop the leading run of candles not newer than the cutoff: keep
        # everything from the first newer candle onwards, or nothing at all.
        first_new = len(candles)
        for pos, stamp in enumerate(candles['datetime']):
            if datetime.datetime.strptime(stamp, time_format) > last_datetime:
                first_new = pos
                break
        candles = candles.iloc[first_new:]

    header = table_defs.get_columns('prices')
    candles.reindex(columns=header) \
        .to_sql(table_name, conn, if_exists="append", index=False)

    update_macd(table_name)
    update_bollinger(table_name)
def update_macd(table_name, max_records=60):
    """Recompute MACD values for the newest rows of a prices table.

    Loads at most `max_records` of the newest rows (ordered by 'datetime').
    Passes them, oldest first, to price_util.calc_macd. Writes the result,
    reindexed to the 'prices' columns, back with ``if_exists="replace"``.
    pandas drops and recreates the table, so afterwards it contains only
    those rows.

    Args:
        table_name: Prices table to update, e.g. 'prices_M5'.
        max_records: How many of the newest rows to keep and recompute.
            Defaults to 60, the previously hard-coded value.
    """
    latest = pd.read_sql_query(
        'select * from ' + table_name + ' order by datetime desc limit ?;',
        conn,
        params=(int(max_records),),
    ).sort_values('datetime')
    latest = price_util.calc_macd(latest)
    header = table_defs.get_columns('prices')
    latest.reindex(columns=header) \
        .to_sql(table_name, conn, if_exists="replace", index=False)
def update_bollinger(table_name, max_records=60):
    """Recompute Bollinger band values for the newest rows of a prices table.

    Loads at most `max_records` of the newest rows (ordered by 'datetime').
    Passes them, oldest first, to price_util.calc_bollinger. Writes the
    result, reindexed to the 'prices' columns, back with
    ``if_exists="replace"``. pandas drops and recreates the table, so
    afterwards it contains only those rows.

    Args:
        table_name: Prices table to update, e.g. 'prices_M5'.
        max_records: How many of the newest rows to keep and recompute.
            Defaults to 60, the previously hard-coded value.
    """
    latest = pd.read_sql_query(
        'select * from ' + table_name + ' order by datetime desc limit ?;',
        conn,
        params=(int(max_records),),
    ).sort_values('datetime')
    latest = price_util.calc_bollinger(latest)
    header = table_defs.get_columns('prices')
    latest.reindex(columns=header) \
        .to_sql(table_name, conn, if_exists="replace", index=False)
def delete_old_trade_data(table_name='trades',
                          keep_span=datetime.timedelta(weeks=1)):
    """Delete trades opened earlier than `keep_span` before now (UTC).

    Args:
        table_name: Trades table to prune. Defaults to 'trades', the
            previously hard-coded table.
        keep_span: How far back from now trades are kept. Defaults to one
            week, the previously hard-coded span.
    """
    keep_from = (datetime.datetime.now(datetime.timezone.utc)
                 - keep_span).strftime(time_format)
    # The cutoff is bound as a parameter rather than spliced into the SQL.
    # NOTE(review): openTime is compared against the formatted string, which
    # assumes time_format yields chronologically sortable text — confirm.
    conn.execute(
        'delete from ' + table_name + ' where openTime < ?;',
        (keep_from,),
    )
    conn.commit()
def create_trades_table(table_name):
    """Run the table_defs CREATE statement for the 'trades' schema as `table_name`."""
    conn.execute(table_defs.get_create_table_sql('trades', table_name))
def create_prices_table(table_name):
    """Run the table_defs CREATE statement for the 'prices' schema as `table_name`."""
    conn.execute(table_defs.get_create_table_sql('prices', table_name))