-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
27 lines (24 loc) · 1001 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Pad (or truncate) a batch of sequences to a common length with NumPy.

    Args:
        inputs: Sequence of array-likes (lists or ndarrays). The rank of the
            first item determines how many axes the padding spec covers.
        length: Target size of the leading ``seq_dims`` axes. ``None`` uses
            the per-axis maximum over all items. A non-indexable value (e.g.
            an int) is wrapped into a one-element list; otherwise it must be
            indexable and provide one size per padded axis.
        value: Constant used to fill the padded positions.
        seq_dims: Number of leading axes to pad/truncate.
        mode: ``'post'`` pads after the data, ``'pre'`` pads before it.
            Truncation always keeps the leading elements of each axis,
            regardless of ``mode``.

    Returns:
        ``np.ndarray`` built from the padded items; an empty array when
        ``inputs`` is empty.

    Raises:
        ValueError: If ``mode`` is neither ``'post'`` nor ``'pre'``.
    """
    # Validate once up front rather than per item and per axis inside the loop.
    if mode not in ('post', 'pre'):
        raise ValueError('"mode" argument must be "post" or "pre".')

    # Nothing to pad: avoid np.max on an empty list and inputs[0] below.
    if len(inputs) == 0:
        return np.array([])

    if length is None:
        length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
    elif not hasattr(length, '__getitem__'):
        length = [length]

    # One slice per padded axis; indexing with it truncates over-long items.
    slices = tuple(np.s_[:length[i]] for i in range(seq_dims))
    # Axes beyond the first `seq_dims` keep (0, 0), i.e. are never padded.
    pad_width = [(0, 0) for _ in np.shape(inputs[0])]

    outputs = []
    for x in inputs:
        # Convert first: a plain (nested) list cannot be indexed by a tuple.
        x = np.asarray(x)[slices]
        for i in range(seq_dims):
            missing = length[i] - x.shape[i]
            pad_width[i] = (0, missing) if mode == 'post' else (missing, 0)
        outputs.append(
            np.pad(x, pad_width, 'constant', constant_values=value)
        )

    return np.array(outputs)