Skip to content

Commit

Permalink
add dyn tile to dynamic_to_static
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Jul 1, 2020
1 parent b759177 commit dfbb78e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace relay {

class DynamicToStaticMutator : public MixedModeMutator {
public:
DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")) {}
DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {}

private:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Expand All @@ -46,6 +46,14 @@ class DynamicToStaticMutator : public MixedModeMutator {
static const Op& reshape = Op::Get("reshape");
return Call(reshape, {call_node->args[0]}, Attrs(attrs), {});
}
} else if (call_node->op == dyn_tile_op_) {
if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
auto attrs = make_object<TileAttrs>();
CHECK_EQ(reps->data->ndim, 1);
attrs->reps = ToVector(reps->data);
static const Op& op = Op::Get("tile");
return Call(op, {call_node->args[0]}, Attrs(attrs), {});
}
}
return post;
}
Expand All @@ -58,6 +66,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
}

const Op& dyn_reshape_op_;
const Op& dyn_tile_op_;
};

Expr DynamicToStatic(Function f, IRModule m) {
Expand Down
22 changes: 22 additions & 0 deletions tests/python/relay/test_pass_dynamic_to_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,30 @@ def verify_reshape(shape, newshape):
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))

def test_dynamic_to_static_tile():
def verify_tile(shape, reps, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(reps, "float32"))
z = relay.tile(x, relay.shape_of(y))
func = run_infer_type(relay.Function([x, y], z))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

zz = func2.body
assert isinstance(zz, relay.Call)
assert zz.op == relay.op.get("tile")
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=reps).astype("float32")
ref_res = np.tile(x_data, reps)
verify_func(func2, [x_data, y_data], ref_res)

verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
verify_tile((4, 7), (4, 2), (16, 14))

if __name__=="__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_quad_reshape()
test_dynamic_to_static_tile()

0 comments on commit dfbb78e

Please sign in to comment.