forked from wbbhcb/stock_market
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Account.py
306 lines (268 loc) · 13.8 KB
/
Account.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import numpy as np
import pandas as pd
class Account:
    """Single-account stock backtester.

    Open positions are stored as parallel lists (``stock_id``, ``stock_name``,
    ``buy_date``, ``stock_num``, ``stock_price``, ``hold_day``, ``cost``) that
    share one index per held stock.  Every buy/sell is also logged as a row
    in the ``info`` DataFrame.  Fees: commission at ``buy_rate``/``sell_rate``
    with a floor of ``buy_min``/``sell_min``, plus ``stamp_duty`` on sells.
    """

    def __init__(self, money_init, start_date='', end_date='', stop_loss_rate=-0.03, stop_profit_rate=0.05,
                 max_hold_period=5):
        """
        :param money_init: starting cash
        :param start_date: backtest start date (stored only)
        :param end_date: backtest end date (stored only)
        :param stop_loss_rate: return vs. buy price at or below which a position is stopped out
        :param stop_profit_rate: return vs. buy price at or above which profit is taken
        :param max_hold_period: holding-day count at which a position is closed at the close price
        """
        self.cash = money_init  # cash
        self.stock_value = 0  # market value of held stocks
        self.market_value = money_init  # total value (cash + stocks)
        # Open positions; all of these lists share the same index.
        self.stock_name = []  # names of held stocks
        self.stock_id = []  # ids of held stocks
        self.buy_date = []  # buy date of each position
        self.stock_num = []  # shares still held
        self.stock_price = []  # buy price of each position
        self.start_date = start_date
        self.end_date = end_date
        self.stock_asset = []  # holdings (not used inside this class)
        self.buy_rate = 0.0003  # buy commission rate
        self.buy_min = 5  # minimum buy commission
        self.sell_rate = 0.0003  # sell commission rate
        self.sell_min = 5  # minimum sell commission
        self.stamp_duty = 0.001  # stamp duty rate, charged on sells only
        self.max_hold_period = max_hold_period  # maximum holding period
        self.hold_day = []  # holding days of each position
        self.cost = []  # cash actually spent per position (amount + commission)
        self.stop_loss_rate = stop_loss_rate  # stop-loss threshold
        self.stop_profit_rate = stop_profit_rate  # stop-profit threshold
        self.victory = 0  # number of winning trades
        self.defeat = 0  # number of losing trades
        # End-of-day snapshots, seeded with the initial state.
        self.cash_all = [money_init]
        self.stock_value_all = [0.0]
        self.market_value_all = [money_init]
        # Drawdown tracking: peak total value and the lowest total value since that peak.
        self.max_market_value = money_init
        self.min_after_max_makret_value = money_init  # (misspelled name kept for compatibility)
        self.max_retracement = 0  # maximum drawdown ratio
        # Last known close per held stock, used to value positions on days without a quote.
        self._last_close = {}
        self.info = pd.DataFrame(columns=['ts_code', 'name', 'buy_price', 'buy_date', 'buy_num', 'sell_price',
                                          'sell_date', 'profit'])

    def _buy_cost(self, stock_price, buy_num):
        """Return ``(amount, commission)`` for buying ``buy_num`` shares at ``stock_price``."""
        amount = stock_price * buy_num
        return amount, max(amount * self.buy_rate, self.buy_min)

    # Stock purchase
    def buy_stock(self, buy_date, stock_name, stock_id, stock_price, buy_num):
        """Open a new position, shrinking the order to what the cash can pay for.

        :param buy_date: buy date
        :param stock_name: name of the stock
        :param stock_id: id of the stock
        :param stock_price: buy price
        :param buy_num: requested number of shares (a multiple of 100)
        :return: None.  Nothing is bought if the stock is already held, or if
                 not even one lot (100 shares) plus commission is affordable.
        """
        if stock_id in self.stock_id:
            # One entry per held stock: adding to an existing position would
            # desynchronise the parallel position lists, so it is skipped.
            return
        tmp_money, service_change = self._buy_cost(stock_price, buy_num)
        # Drop one lot at a time until amount + commission fits in the cash.
        while buy_num > 0 and self.cash - tmp_money - service_change < 0:
            buy_num = buy_num - 100
            tmp_money, service_change = self._buy_cost(stock_price, buy_num)
        if buy_num <= 0:
            return
        self.cash = self.cash - tmp_money - service_change
        self.stock_id.append(stock_id)
        self.stock_name.append(stock_name)
        self.buy_date.append(buy_date)
        self.stock_price.append(stock_price)
        self.hold_day.append(1)
        self.stock_num.append(buy_num)
        self.cost.append(tmp_money + service_change)
        self._last_close[stock_id] = stock_price
        row = len(self.info)
        self.info.loc[row, 'ts_code'] = stock_id
        self.info.loc[row, 'name'] = stock_name
        self.info.loc[row, 'buy_price'] = stock_price
        self.info.loc[row, 'buy_date'] = buy_date
        self.info.loc[row, 'buy_num'] = buy_num
        info = str(buy_date) + ' 买入 ' + stock_name + ' (' + stock_id + ') ' \
            + str(int(buy_num)) + '股,股价:' + str(stock_price) + ',花费:' + str(round(tmp_money, 2)) + ',手续费:' \
            + str(round(service_change, 2)) + ',剩余现金:' + str(round(self.cash, 2))
        print(info)

    def sell_stock(self, sell_date, stock_name, stock_id, sell_price, sell_num, flag):
        """Sell shares of a held stock, update cash and win/loss counters, and log the trade.

        :param sell_date: sell date
        :param stock_name: name of the stock
        :param stock_id: id of the stock
        :param sell_price: sell price
        :param sell_num: number of shares to sell (at most the shares held)
        :param flag: sell reason: 0 = holding period reached, 1 = stop-profit, 2 = stop-loss
        :raises TypeError: if the stock is not held (exception type kept for compatibility)
        :raises ValueError: if ``flag`` is unknown or ``sell_num`` exceeds the holding
        """
        if stock_id not in self.stock_id:
            raise TypeError('该股票未买入')
        labels = {0: ' 到期卖出', 1: ' 止盈卖出', 2: ' 止损卖出'}
        # Validate before touching any state.
        if flag not in labels:
            raise ValueError('unknown sell flag: %r' % (flag,))
        idx = self.stock_id.index(stock_id)
        held = self.stock_num[idx]
        if sell_num > held:
            raise ValueError('cannot sell %s shares of %s, only %s held' % (sell_num, stock_id, held))
        tmp_money = sell_num * sell_price
        commission = max(tmp_money * self.sell_rate, self.sell_min)
        stamp_duty = self.stamp_duty * tmp_money
        self.cash = self.cash + tmp_money - commission - stamp_duty
        service_change = stamp_duty + commission  # total fees of this sell
        if sell_num == held:
            # Full exit: drop the position from every parallel list.
            cost = self.cost[idx]
            del self.stock_num[idx]
            del self.stock_id[idx]
            del self.stock_name[idx]
            del self.buy_date[idx]
            del self.stock_price[idx]
            del self.hold_day[idx]
            del self.cost[idx]
            self._last_close.pop(stock_id, None)
        else:
            # Partial exit: charge a proportional share of the buy cost.
            cost = self.cost[idx] * sell_num / held
            self.cost[idx] = self.cost[idx] - cost
            self.stock_num[idx] = held - sell_num
        profit = tmp_money - service_change - cost
        info = str(sell_date) + labels[flag] + stock_name + ' (' + stock_id + ') ' \
            + str(int(sell_num)) + '股,股价:' + str(sell_price) + ',收入:' + str(round(tmp_money, 2)) + ',手续费:' \
            + str(round(service_change, 2)) + ',剩余现金:' + str(round(self.cash, 2))
        # Stop-profit always counts as a win and stop-loss as a loss; a
        # holding-period exit is judged by its realised profit.
        won = profit > 0 if flag == 0 else flag == 1
        if won:
            info = info + ',最终盈利:' + str(round(profit, 2))
            self.victory += 1
        else:
            info = info + ',最终亏损:' + str(round(profit, 2))
            self.defeat += 1
        print(info)
        open_rows = (self.info['ts_code'] == stock_id) & self.info['sell_date'].isna()
        self.info.loc[open_rows, 'sell_date'] = sell_date
        self.info.loc[open_rows, 'sell_price'] = sell_price
        self.info.loc[open_rows, 'profit'] = profit

    # Buy trigger placeholder; not implemented yet.
    def buy_trigger(self):
        pass

    def sell_trigger(self, stock_id, day, all_df, index_df):
        """Decide whether a held stock should be sold on ``day``.

        :param stock_id: id of a held stock
        :param day: backtest date
        :param all_df: quotes with trade_date, ts_code, open, high, low, close columns
        :param index_df: unused; kept for interface compatibility
        :return: ``(should_sell, kind, price)``. ``kind`` is 0 for a holding-period
                 exit (at the close), 1 for stop-profit and 2 for stop-loss.
                 When not selling, returns ``(False, 3, 0)``.
        """
        bar = all_df.loc[(all_df['trade_date'] == day) & (all_df['ts_code'] == stock_id)]
        if bar.empty:
            # No quote that day (e.g. trading suspended): the position cannot be sold.
            return False, 3, 0
        low = bar['low'].values[0]
        high = bar['high'].values[0]
        open_price = bar['open'].values[0]
        close = bar['close'].values[0]
        pos = self.stock_id.index(stock_id)
        buy_price = self.stock_price[pos]
        # 1) The open already crosses a threshold: sell at the open.
        open_rate = (open_price - buy_price) / buy_price
        if open_rate <= self.stop_loss_rate:
            return True, 2, open_price
        if open_rate >= self.stop_profit_rate:
            return True, 1, open_price
        # 2) Intraday: the order of the low and the high within the day is
        #    unknown, so the low (stop-loss) is checked first as the worst case.
        if (low - buy_price) / buy_price <= self.stop_loss_rate:
            # Assume the stop price cannot be filled exactly: the fill is an
            # extra 0.01 (one percentage point) below the stop-loss rate.
            # NOTE(review): an older comment called this 0.01%; confirm intent.
            return True, 2, buy_price * (1 + self.stop_loss_rate - 0.01)
        if (high - buy_price) / buy_price >= self.stop_profit_rate:
            return True, 1, buy_price * (1 + self.stop_profit_rate)
        # 3) Holding period reached: sell at the close.
        if self.hold_day[pos] >= self.max_hold_period:
            return True, 0, close
        return False, 3, 0

    def update(self, day, all_df):
        """End-of-day bookkeeping.

        Increments every position's holding days, revalues positions at the
        day's close (or the last known price when there is no quote), appends
        the daily snapshots and updates the maximum drawdown.
        """
        day_df = all_df.loc[all_df['trade_date'] == day]
        closes = []
        for j, stock_id in enumerate(self.stock_id):
            self.hold_day[j] = self.hold_day[j] + 1
            quote = day_df.loc[day_df['ts_code'] == stock_id, 'close'].values
            if len(quote) > 0:
                self._last_close[stock_id] = quote[0]
            closes.append(self._last_close.get(stock_id, self.stock_price[j]))
        self.stock_value = np.sum(np.array(self.stock_num) * np.array(closes))
        self.market_value = self.cash + self.stock_value
        self.market_value_all.append(self.market_value)
        self.stock_value_all.append(self.stock_value)
        self.cash_all.append(self.cash)
        if self.market_value > self.max_market_value:
            # A new peak starts a new drawdown window with zero drawdown.
            self.max_market_value = self.market_value
            self.min_after_max_makret_value = self.market_value
        elif self.market_value < self.min_after_max_makret_value:
            self.min_after_max_makret_value = self.market_value
        # Drawdown = (peak - lowest value since the peak) / peak.
        if self.max_market_value > 0:
            retracement = (self.max_market_value - self.min_after_max_makret_value) / self.max_market_value
            self.max_retracement = max(self.max_retracement, retracement)

    def BackTest(self, buy_df, all_df, index_df, buy_price='close'):
        """Run the daily backtest loop.

        For every date in ``index_df['trade_date']`` (ascending): buy that day's
        candidates from ``buy_df`` in descending ``label_prob`` order, then check
        the sell conditions of positions not bought that day, then ``update``.

        :param buy_df: candidates; needs trade_date, ts_code, name, label_prob and the ``buy_price`` column
        :param all_df: daily quotes; needs trade_date, ts_code, open, high, low, close
        :param index_df: index data; only its trade_date column is used, as the calendar
        :param buy_price: column of ``buy_df`` used as the fill price
        """
        for day in np.sort(index_df['trade_date']):
            # Slice the day's quotes once instead of rescanning all_df for
            # every stock inside sell_trigger/update.
            day_df = all_df.loc[all_df['trade_date'] == day]
            candidates = (buy_df.loc[buy_df['trade_date'] == day]
                          .sort_values('label_prob', ascending=False)
                          .reset_index())
            # Buy first, then sell.
            self._buy_candidates(day, candidates, buy_price)
            self._sell_positions(day, day_df, index_df)
            self.update(day, day_df)
        self._cast_info_columns()

    def _buy_candidates(self, day, candidates, buy_price):
        """Buy candidates in order, each sized at up to 20% of the total value."""
        for name, code, price in zip(candidates['name'].to_numpy(),
                                     candidates['ts_code'].to_numpy(),
                                     candidates[buy_price].to_numpy()):
            money = min(self.market_value * 0.2, self.cash)
            if money < 5000:  # stop buying once less than 5000 RMB is available
                break
            lots = (money / price) // 100
            if lots == 0:
                continue
            self.buy_stock(day, name, code, price, lots * 100)

    def _sell_positions(self, day, day_df, index_df):
        """Sell whole positions whose exit condition fires; positions bought on ``day`` are skipped."""
        # Iterate backwards so deleting a sold position keeps lower indices valid.
        for j in range(len(self.stock_id) - 1, -1, -1):
            if self.buy_date[j] == day:
                continue  # presumably the T+1 rule: same-day buys cannot be sold
            stock_id = self.stock_id[j]
            is_sell, sell_kind, sell_price = self.sell_trigger(stock_id, day, day_df, index_df)
            if is_sell:
                self.sell_stock(day, self.stock_name[j], stock_id, sell_price, self.stock_num[j], sell_kind)

    def _cast_info_columns(self):
        """Best-effort cast of the trade log's date and quantity columns to int."""
        for col in ('buy_date', 'sell_date', 'buy_num'):
            try:
                self.info[col] = self.info[col].astype(int)
            except (TypeError, ValueError):
                # e.g. NaN sell_date for positions still open, or non-numeric dates.
                pass