Commit 6537bdb
asdf
ajtulloch committed Apr 5, 2019
1 parent bc02b8e commit 6537bdb
Showing 1 changed file with 79 additions and 7 deletions.
wavernn/pytorch_frame.py (79 additions, 7 deletions)
@@ -5,7 +5,9 @@
 import collections
 import tvm
 from tvm import relay
-
+from tvm import autotvm
+import itertools
+import scipy.sparse as sp
 torch.manual_seed(42)


@@ -229,6 +231,66 @@ def build_fast_wavernn_module(target="llvm"):

     Cell = collections.namedtuple('Cell', ['weight_ih', 'weight_hh', 'bias_ih', 'bias_hh'])

+    BFLOAT16 = True
+
+    BSR = collections.namedtuple(
+        'BSR',
+        ['data', 'indices', 'indptr', 'N', 'K', 'BS_R', 'BS_C', 'density'])
+
+    def random_bsr_matrix(M, N, BS_R, BS_C, density):
+        Y = np.zeros((M, N), dtype="float32")
+        assert M % BS_R == 0
+        assert N % BS_C == 0
+        nnz = int(density * M * N)
+        num_blocks = int(nnz / (BS_R * BS_C)) + 1
+        candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
+        assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
+        chosen_blocks = candidate_blocks[np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)]
+        for i in range(len(chosen_blocks)):
+            r, c = chosen_blocks[i]
+            Y[r:r + BS_R, c:c + BS_C] = np.random.randn(BS_R, BS_C).astype("float32")
+        s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
+        assert s.data.shape == (num_blocks, BS_R, BS_C)
+        assert s.indices.shape == (num_blocks, )
+        assert s.indptr.shape == (M // BS_R + 1, )
+        return s
+
+    def to_bf16(x):
+        assert x.dtype == np.float32
+        return ((x.view('<u4') + 2 ** 15) >> 16).astype("uint16")
+
+    def instantiate(param):
+        if isinstance(param, BSR):
+            param_np = random_bsr_matrix(M=param.N, N=param.K, BS_R=param.BS_R, BS_C=param.BS_C, density=param.density)
+            return [
+                (param.data.name_hint, tvm.ndarray.array(param_np.data.astype("uint16" if BFLOAT16 else "float32"))),
+                (param.indices.name_hint, tvm.ndarray.array(param_np.indices.astype("int32"))),
+                (param.indptr.name_hint, tvm.ndarray.array(param_np.indptr.astype("int32"))),
+            ]
+        else:
+            return [(
+                param.name_hint,
+                tvm.ndarray.array(
+                    np.random.randn(*param.type_annotation.concrete_shape).astype(
+                        param.type_annotation.dtype)
+                )
+            )
+            ]
+
+    def to_sparse(v, params, density=0.05, BS_R=16, BS_C=1):
+        name = v.name_hint
+        (N, K) = v.type_annotation.concrete_shape
+        nnz = int(density * N * K)
+        num_blocks = int(nnz / (BS_R * BS_C)) + 1
+        v_data = relay.var(name + "_data", shape=(num_blocks, BS_R, BS_C), dtype="uint16" if BFLOAT16 else "float32")
+        v_indices = relay.var(name + "_indices", shape=(num_blocks,), dtype="int32")
+        v_indptr = relay.var(name + "_indptr", shape=(N // BS_R + 1,), dtype="int32")
+        param_np = random_bsr_matrix(M=N, N=K, BS_R=BS_R, BS_C=BS_C, density=density)
+        params[name + "_data"] = to_bf16(param_np.data) if BFLOAT16 else param_np.data
+        params[name + "_indices"] = param_np.indices.astype("int32")
+        params[name + "_indptr"] = param_np.indptr.astype("int32")
+        return BSR(data=v_data, indices=v_indices, indptr=v_indptr, N=N, K=K, BS_R=BS_R, BS_C=BS_C, density=density)
+
     def approx_exp(x):
         x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
         x = C(127.0) + x * C(1.44268504)
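
A note on the two new helpers above: random_bsr_matrix places random BS_R x BS_C blocks in a dense matrix and round-trips it through scipy to get the BSR triple (data, indices, indptr) that its assertions check, and to_bf16 truncates float32 to bfloat16 by rounding and keeping the high 16 bits. A standalone sketch of both, not part of the commit:

    # Standalone sketch: BSR layout and the bfloat16 truncation used above.
    import numpy as np
    import scipy.sparse as sp

    # BSR stores a (num_blocks, BS_R, BS_C) block array plus CSR-style
    # block-column indices and a per-block-row pointer array.
    dense = np.zeros((4, 6), dtype="float32")
    dense[0:2, 0:1] = 1.0                     # block at block-row 0, block-col 0
    dense[2:4, 3:4] = 2.0                     # block at block-row 1, block-col 3
    bsr = sp.bsr_matrix(dense, blocksize=(2, 1))
    print(bsr.data.shape)                     # (2, 2, 1) == (num_blocks, BS_R, BS_C)
    print(bsr.indices)                        # [0 3]   block-column indices
    print(bsr.indptr)                         # [0 1 2] offsets per block row

    # to_bf16 rounds and keeps the high 16 bits of each float32; shifting a
    # uint16 back up recovers a float32 with 2-3 significant decimal digits.
    def from_bf16(y):
        return (y.astype("uint32") << 16).view("float32")

    x = np.random.randn(8).astype("float32")
    y = ((x.view("<u4") + 2 ** 15) >> 16).astype("uint16")   # same as to_bf16
    print(np.abs(from_bf16(y) / x - 1.0).max())              # <= ~2 ** -8
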
@@ -256,7 +318,7 @@ def C(x):
         return relay.expr.const(x, "float32")

     def sparse_dense(X, W, B, **kwargs):
-        return relay.nn.bias_add(relay.nn.dense(X, W), B)
+        return relay.nn.bias_add(relay.nn.sparse_dense(X, W), B)

     def dense(X, W, B, **kwargs):
         return relay.nn.bias_add(relay.nn.dense(X, W), B)
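
The one-line change in this hunk swaps the dense matmul for TVM's block-sparse kernel while keeping the bias path identical. The assumed contract, sketched outside the commit, is that sparse_dense(X, W) preserves the dense convention Y = X @ W.T for a weight W of shape (N, K):

    # Sketch of the assumed sparse_dense semantics, in NumPy/SciPy terms.
    import numpy as np
    import scipy.sparse as sp

    X = np.random.randn(2, 6).astype("float32")        # activations, (M, K)
    W = sp.random(4, 6, density=0.3, format="bsr")     # weight, (N, K), sparse
    dense_out = X @ W.toarray().T                      # what nn.dense computes
    sparse_out = (W @ X.T).T                           # same product via sparse W
    np.testing.assert_allclose(dense_out, sparse_out, rtol=1e-5)
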
@@ -274,15 +336,15 @@ def gru_cell(cell, x, h):
     xconcat_trns = dense(Rx, RI_W, RI_B) + RI_residual

     Rrnn1 = Cell(
-        weight_ih=relay.var("rnn1_weight_ih", shape=(3 * rnn_dims, rnn_dims), dtype="float32"),
-        weight_hh=relay.var("rnn1_weight_hh", shape=(3 * rnn_dims, rnn_dims), dtype="float32"),
+        weight_ih=to_sparse(relay.var("rnn1_weight_ih", shape=(3 * rnn_dims, rnn_dims), dtype="float32"), params),
+        weight_hh=to_sparse(relay.var("rnn1_weight_hh", shape=(3 * rnn_dims, rnn_dims), dtype="float32"), params),
         bias_ih=relay.var("rnn1_bias_ih", shape=(3 * rnn_dims, ), dtype="float32"),
         bias_hh=relay.var("rnn1_bias_hh", shape=(3 * rnn_dims, ), dtype="float32"),
     )
     h1 = gru_cell(Rrnn1, xconcat_trns, Rh1)
     xres = xconcat_trns + h1

-    Rfc1_W = relay.var("fc1_W", shape=(fc_dims, rnn_dims), dtype="float32")
+    Rfc1_W = to_sparse(relay.var("fc1_W", shape=(fc_dims, rnn_dims), dtype="float32"), params)
     Rfc1_B = relay.var("fc1_B", shape=(fc_dims,), dtype="float32")

     x_fc = relay.nn.relu(sparse_dense(xres, Rfc1_W, Rfc1_B) + Rfc1_residual)
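
Routing the GRU and FC weights through to_sparse is where the size win comes from. Rough, illustrative arithmetic (rnn_dims=512 is an assumption here; the file sets its own value elsewhere):

    # Illustrative arithmetic only; rnn_dims=512 is assumed for the example.
    rnn_dims = 512
    N, K = 3 * rnn_dims, rnn_dims                    # one GRU weight matrix
    density, BS_R, BS_C = 0.05, 16, 1                # to_sparse defaults
    num_blocks = int(density * N * K / (BS_R * BS_C)) + 1   # 2458 blocks
    dense_bytes = N * K * 4                          # float32 weights: 3,145,728
    data_bytes = num_blocks * BS_R * BS_C * 2        # bfloat16 blocks:    78,656
    index_bytes = 4 * (num_blocks + N // BS_R + 1)   # int32 indices/indptr
    print(dense_bytes / (data_bytes + index_bytes))  # ~35x smaller
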
@@ -295,7 +357,16 @@ def gru_cell(cell, x, h):
     outputs = relay.expr.Tuple([x_prob, h1])
     func = relay.Function(relay.ir_pass.free_vars(outputs), outputs)
     func = relay.ir_pass.infer_type(func)
-    graph, lib, params = relay.build_module.build(func, target=target, params=params)
+    TARGET = tvm.target.create(target)
+    log_filename = "lpcnet_no_bf16_autotvm_skl.log"
+
+    with autotvm.apply_history_best(log_filename):
+        with relay.build_config(opt_level=3):
+            func = relay.optimize(func, target=TARGET, params=params)
+            print(func.astext(show_meta_data=False))
+            func = relay.ir_pass.infer_type(func)
+            graph, lib, new_params = relay.build_module.build(
+                func, target=TARGET, params=params)
     return (graph, lib, params)
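
autotvm.apply_history_best reads records from a prior tuning run and applies the best measured schedule for each task at compile time, falling back to default schedules where the log has no entry. A hypothetical loop that would produce such a log, using the func/TARGET/params from the builder above; the tuner choice, trial count, and op list are assumptions, not part of this commit:

    # Hypothetical tuning loop producing "lpcnet_no_bf16_autotvm_skl.log".
    from tvm import autotvm, relay

    tasks = autotvm.task.extract_from_program(
        func, target=TARGET, params=params, ops=(relay.op.nn.dense,))
    for task in tasks:
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(
            n_trial=1000,
            measure_option=autotvm.measure_option(
                builder=autotvm.LocalBuilder(),
                runner=autotvm.LocalRunner(number=10)),
            callbacks=[autotvm.callback.log_to_file("lpcnet_no_bf16_autotvm_skl.log")])
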


@@ -383,7 +454,7 @@ def test_relay_cpp_frame():
     np.testing.assert_allclose(h1_ref, h1_new, rtol=1e-4, atol=1e-4)


-    (graph, lib, params) = build_fast_wavernn_module("llvm -mcpu=core-avx2 -target=x86_64-linux-gnu")
+    (graph, lib, params) = build_fast_wavernn_module("llvm -mcpu=skylake-avx512 -target=x86_64-linux-gnu")
     with open(
             "fast_wavernn_rnn_dims_{rnn_dims}_fc_dims_{fc_dims}_feat_dims_{feat_dims}_aux_dims_{aux_dims}_graph.json".format(**globals()),
             "w") as f:
@@ -395,3 +466,4 @@ def test_relay_cpp_frame():
         f.write(relay.save_param_dict(params))

     lib.save("fast_wavernn_rnn_dims_{rnn_dims}_fc_dims_{fc_dims}_feat_dims_{feat_dims}_aux_dims_{aux_dims}_lib.o".format(**globals()))
+    lib.export_library("fast_wavernn_rnn_dims_{rnn_dims}_fc_dims_{fc_dims}_feat_dims_{feat_dims}_aux_dims_{aux_dims}_lib.so".format(**globals()))
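
For completeness, artifacts exported this way would be loaded back with the graph runtime of the same TVM vintage. A sketch, not part of the commit; the params filename is assumed to mirror the graph/lib naming, and "..." stands in for the dims baked into the real names:

    # Hypothetical loader for the artifacts written above.
    import tvm
    from tvm.contrib import graph_runtime

    lib = tvm.module.load("fast_wavernn_..._lib.so")
    graph = open("fast_wavernn_..._graph.json").read()
    params = bytearray(open("fast_wavernn_..._params.bin", "rb").read())

    module = graph_runtime.create(graph, lib, tvm.cpu(0))
    module.load_params(params)
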
