diff --git a/example/warpctc/toy_ctc.py b/example/warpctc/toy_ctc.py index 08a48d2b7f4c..46bab5776018 100644 --- a/example/warpctc/toy_ctc.py +++ b/example/warpctc/toy_ctc.py @@ -36,10 +36,10 @@ def gen_rand(): buf = str(num) while len(buf) < 4: buf = "0" + buf - ret = np.array([]) + ret = [] for i in range(80): c = int(buf[i // 20]) - ret = np.concatenate([ret, gen_feature(c)]) + ret.append(gen_feature(c)) return buf, ret def get_label(buf): @@ -56,7 +56,7 @@ def __init__(self, count, batch_size, num_label, init_states): self.num_label = num_label self.init_states = init_states self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] - self.provide_data = [('data', (batch_size, 10 * 80))] + init_states + self.provide_data = [('data', (batch_size, 80, 10))] + init_states self.provide_label = [('label', (self.batch_size, 4))] def __iter__(self):