
Commit

remove dynamic behavior from standard reshape
Matthew Brookhart committed Jun 26, 2020
1 parent a9f25aa commit 643b115
Showing 5 changed files with 9 additions and 82 deletions.
python/tvm/relay/_parser.py (2 changes: 1 addition & 1 deletion)
@@ -114,7 +114,7 @@ def convert(self, v):
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
if self.operator in (op.strided_slice):
if self.operator in (op.strided_slice,):
x = self.operator(*args)
elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
x = self.operator(*args, dtype=attrs["dtype"])
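For illustration only (not part of the commit): the hunk above fixes a classic Python pitfall. Parentheses without a trailing comma do not create a tuple, so the old membership test ran against the bare operator object and raised a TypeError. A minimal standalone sketch:

def f():
    pass

print((f) is f)     # True: parentheses alone do not make a tuple
print(type((f,)))   # <class 'tuple'>: the trailing comma does

try:
    f in (f)        # membership test against a bare function object
except TypeError as err:
    print("TypeError:", err)   # argument of type 'function' is not iterable

print(f in (f,))    # True: membership test against a one-element tuple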
python/tvm/relay/op/_transform.py (77 changes: 3 additions & 74 deletions)
@@ -272,82 +272,11 @@ def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
out[infer_idx] = old_size // new_size
return out

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
out = output_tensor((ndim,), "int64")
data_shape = allocate((len(data.shape),), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
dst_idx = 0
infer_idx = -1
copy = False
skip = 0
for i in const_range(len(newshape)):
if skip > 0:
skip -= 1
elif newshape[i] > 0:
out[dst_idx] = int64(newshape[i])
src_idx += 1
dst_idx += 1
elif newshape[i] == 0:
out[dst_idx] = data_shape[src_idx]
src_idx += 1
dst_idx += 1
elif newshape[i] == -1:
assert infer_idx < 0, "One and only one dim can be inferred"
out[dst_idx] = int64(1)
infer_idx = i
dst_idx += 1
elif newshape[i] == -2:
copy = True
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i+1] == -1:
assert newshape[i+2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
out[dst_idx+1] = int64(newshape[i+2])
else:
out[dst_idx] = int64(newshape[i+1])
if newshape[i+2] == -1:
out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
else:
out[dst_idx+1] = int64(newshape[i+2])
assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
"Product of split dims doesn't match to input dim"
src_idx += 1
dst_idx += 2
skip = 2
else:
assert False, "Invalid special values in new shape"
if len(data_shape.shape) > 0:
# if data is not constant, we can then handle -1 and -2
if copy:
for i in range(src_idx, data_shape.shape[0]):
out[dst_idx] = data_shape[i]
dst_idx += 1
if infer_idx >= 0:
old_size = int64(1)
for i in const_range(data_shape.shape[0]):
old_size *= data_shape[i]
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size // new_size
return out

@_reg.register_shape_func("reshape", True)
@_reg.register_shape_func("reshape", False)
def reshape_shape_func(attrs, inputs, out_ndims):
if attrs.newshape is None:
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
newshape = get_const_tuple(attrs.newshape)
return [_reshape_shape_func_input_shape(inputs[0],
convert(attrs.newshape),
convert(newshape),
out_ndims[0])]

@script
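Not part of the commit: with the data-dependent shape function removed above, reshape's output shape is computed purely from the input shape and the constant newshape. The hypothetical helper below is a plain-Python sketch of the special newshape values (0, -1, -2, -3, -4) that the remaining _reshape_shape_func_input_shape interprets, checked against the shapes exercised in tests/python/relay/test_any.py further down.

def infer_reshape(data_shape, newshape):
    # Hypothetical helper: mirrors the newshape semantics, for illustration only.
    out, src, infer_at, i = [], 0, None, 0
    while i < len(newshape):
        d = newshape[i]
        if d > 0:                 # literal output dimension
            out.append(d)
            src += 1
        elif d == 0:              # copy the corresponding input dimension
            out.append(data_shape[src])
            src += 1
        elif d == -1:             # infer this dimension from the remainder
            assert infer_at is None, "only one dim can be inferred"
            infer_at = len(out)
            out.append(1)
        elif d == -2:             # copy all remaining input dimensions
            out.extend(data_shape[src:])
            src = len(data_shape)
        elif d == -3:             # fuse the next two input dimensions
            out.append(data_shape[src] * data_shape[src + 1])
            src += 2
        elif d == -4:             # split one input dim into the next two values
            a, b = newshape[i + 1], newshape[i + 2]
            if a == -1:
                a = data_shape[src] // b
            if b == -1:
                b = data_shape[src] // a
            assert a * b == data_shape[src], "split dims must multiply to the input dim"
            out.extend([a, b])
            src += 1
            i += 2
        else:
            raise ValueError("invalid special value in newshape")
        i += 1
    if infer_at is not None:
        total = 1
        for d in data_shape:
            total *= d
        known = 1
        for d in out:
            known *= d
        out[infer_at] = total // known
    return out

# Expected shapes taken from tests/python/relay/test_any.py below:
assert infer_reshape((2, 3, 4), (0, -1)) == [2, 12]
assert infer_reshape((2, 3, 4), (0, -2)) == [2, 3, 4]
assert infer_reshape((6, 3, 4), (-4, -1, 2, -3)) == [3, 2, 12]
assert infer_reshape((6, 3, 4), (-4, 2, -1, -2)) == [2, 3, 3, 4]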
python/tvm/relay/op/transform.py (1 change: 0 additions & 1 deletion)
@@ -20,7 +20,6 @@

from . import _make
from ..expr import TupleWrapper, const
from ...tir import expr as _expr


def cast(data, dtype):
src/relay/transforms/dynamic_to_static.cc (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
attrs->newshape = ToVector(shape->data);
attrs->reverse = false;
static const Op& reshape = Op::Get("reshape");
return Call(reshape, call_node->args, Attrs(attrs), {});
return Call(reshape, {call_node->args[0]}, Attrs(attrs), {});
}
}
return post;
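A sketch, not part of the commit: the C++ hunk above rewrites a dyn.reshape whose newshape is constant into the static reshape op, and now passes only the data tensor ({call_node->args[0]}), since static reshape no longer takes a shape argument. The Python snippet below shows how that rewrite can be exercised, assuming the pass is exposed as relay.transform.DynamicToStatic() as in TVM releases from this period.

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
y = relay.dyn.reshape(x, relay.const(np.array([2, 12]), "int64"))
mod = tvm.IRModule.from_expr(relay.Function([x], y))

mod = relay.transform.InferType()(mod)
mod = relay.transform.DynamicToStatic()(mod)
print(mod)   # the dyn.reshape call should now print as a static reshape with newshape=[2, 12]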
tests/python/relay/test_any.py (9 changes: 4 additions & 5 deletions)
@@ -167,12 +167,11 @@ def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape
newshape_var = relay.var('newshape', shape=(len(newshape),), dtype='int64')
params.append(newshape_var)
args.append(np.array(newshape, dtype='int64'))
newshape = newshape_var

y = relay.reshape(relu_x, newshape=newshape)
y = relay.dyn.reshape(relu_x, newshape_var)
else:
y = relay.reshape(relu_x, newshape=newshape)
mod = tvm.IRModule()
mod["main"] = relay.Function(params, y)

for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(*args).asnumpy()
@@ -184,9 +183,9 @@ def test_any_reshape():
# Variable newshape only supports that output rank is the same as newshape
verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24), variable_newshape)
verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12), variable_newshape)
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4), variable_newshape)
verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))

def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
x = relay.var('x', shape=x_shape, dtype=dtype)
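For illustration (not in the commit): after this change the test builds the runtime-newshape case through relay.dyn.reshape directly instead of passing a variable to relay.reshape. A standalone sketch of that path, mirroring the test's VM execution (the tensor names and shapes here are illustrative):

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32")
newshape = relay.var("newshape", shape=(2,), dtype="int64")
out = relay.dyn.reshape(relay.nn.relu(data), newshape)

mod = tvm.IRModule()
mod["main"] = relay.Function([data, newshape], out)

ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(np.zeros((2, 3, 4), "float32"), np.array([2, 12], "int64"))
print(result.asnumpy().shape)   # (2, 12)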
