diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f299041f10dc2..5ef8f21c82715 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2852,6 +2852,10 @@ def einsum(self, inputs, input_types):
         equation, data = inputs
         return _op.einsum(data, equation)
 
+    def dot(self, inputs, _):
+        lhs, rhs = inputs
+        return _op.dot(lhs, rhs)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -3076,6 +3080,7 @@ def create_convert_map(self):
             "aten::bucketize": self.bucketize,
             "aten::roll": self.roll,
             "aten::einsum": self.einsum,
+            "aten::dot": self.dot,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/python/tvm/relay/op/_math.py b/python/tvm/relay/op/_math.py
index ff74fafcef75b..bcf4065776cdc 100644
--- a/python/tvm/relay/op/_math.py
+++ b/python/tvm/relay/op/_math.py
@@ -20,3 +20,6 @@
 
 # einsum
 _reg.register_strategy("einsum", strategy.einsum_strategy)
+
+# dot
+_reg.register_strategy("dot", strategy.dot_strategy)
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index d7f0dda92c6d3..666a2508212d6 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1819,3 +1819,24 @@ def einsum_strategy(attrs, inputs, out_type, target):
         name="einsum.generic",
     )
     return strategy
+
+
+def wrap_compute_dot(topi_compute):
+    """Wrap a topi dot compute as an FTVMCompute implementation."""
+
+    def _compute_dot(attrs, inputs, _):
+        return [topi_compute(inputs[0], inputs[1])]
+
+    return _compute_dot
+
+
+@override_native_generic_func("dot_strategy")
+def dot_strategy(attrs, inputs, out_type, target):
+    """dot generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_dot(topi.dot),
+        wrap_topi_schedule(topi.generic.schedule_dot),
+        name="dot.generic",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 0c930dd1153c2..3a5a20502c6c4 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1136,6 +1136,24 @@ def einsum(data, equation):
     return _make.einsum(Tuple(data), equation)
 
 
+def dot(lhs, rhs):
+    """Compute the dot product of two 1D tensors.
+
+    Parameters
+    ----------
+    lhs : relay.Expr
+        The left hand side 1D input tensor.
+    rhs : relay.Expr
+        The right hand side 1D input tensor.
+
+    Returns
+    -------
+    result : relay.Expr
+        A scalar (rank-0 tensor) holding the dot product.
+    """
+    return _make.dot(lhs, rhs)
+
+
 def stack(data, axis):
     """Join a sequence of arrays along a new axis.
 
diff --git a/python/tvm/topi/einsum.py b/python/tvm/topi/einsum.py
index f1f426ec81738..35fd698d6965f 100644
--- a/python/tvm/topi/einsum.py
+++ b/python/tvm/topi/einsum.py
@@ -42,3 +42,22 @@ def einsum(subscripts, *operand):
     """
 
     return cpp.einsum(subscripts, operand)
+
+
+def dot(lhs, rhs):
+    """Compute the dot product of two 1D tensors.
+
+    Parameters
+    ----------
+    lhs : tvm.te.Tensor
+        The left hand side 1D input tensor.
+    rhs : tvm.te.Tensor
+        The right hand side 1D input tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        A 0-D tensor holding the dot product of lhs and rhs.
+    """
+    # Reuse the einsum kernel: "i,i->" contracts two 1D tensors to a scalar.
+    return einsum("i,i->", lhs, rhs)
diff --git a/python/tvm/topi/generic/math.py b/python/tvm/topi/generic/math.py
index 3af6cd16a3741..fdd637b9056f6 100644
--- a/python/tvm/topi/generic/math.py
+++ b/python/tvm/topi/generic/math.py
@@ -32,3 +32,19 @@ def schedule_einsum(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_dot(outs):
+    """Schedule for dot operator.
+
+    Parameters
+    ----------
+    outs : Array of Tensor
+        The computation graph description of dot.
+
+    Returns
+    -------
+    sch : Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/src/relay/op/tensor/math.cc b/src/relay/op/tensor/math.cc
index 246fba62cc66f..82b94a264af0b 100644
--- a/src/relay/op/tensor/math.cc
+++ b/src/relay/op/tensor/math.cc
@@ -111,5 +111,41 @@ on the operands)doc" TVM_ADD_FILELINE)
     .set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+bool DotProductRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3) << "Expected 2 input types and 1 output type, found " << types.size();
+
+  const auto* lhs = types[0].as<TensorTypeNode>();
+  const auto* rhs = types[1].as<TensorTypeNode>();
+  if (lhs == nullptr || rhs == nullptr) {
+    // The inputs have not been resolved to tensor types yet; defer inference.
+    return false;
+  }
+  ICHECK_EQ(lhs->shape.size(), 1) << "dot expects a 1D tensor for input 0";
+  ICHECK_EQ(rhs->shape.size(), 1) << "dot expects a 1D tensor for input 1";
+  ICHECK(reporter->AssertEQ(lhs->shape[0], rhs->shape[0]))
+      << "dot expects both inputs to have the same length";
+
+  // The dot product of two 1D tensors is a scalar (rank-0 tensor).
+  reporter->Assign(types[2], TensorType({}, lhs->dtype));
+  return true;
+}
+
+RELAY_REGISTER_OP("dot")
+    .describe(R"doc(Compute the dot product of two 1D tensors)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("lhs", "tensor", "left hand side of dot product")
+    .add_argument("rhs", "tensor", "right hand side of dot product")
+    .set_support_level(11)
+    .add_type_rel("dot", DotProductRel)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);
+
+Expr MakeDot(Expr lhs, Expr rhs) {
+  static const Op& op = Op::Get("dot");
+  return Call(op, {lhs, rhs}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.dot").set_body_typed(MakeDot);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index a64fa0bb91aa0..ea70640170690 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4086,6 +4086,15 @@ def test_fn(equation):
     verify_model(test_fn("ij,jk"), [x, y])
     verify_model(test_fn("ij,jk,km->im"), [x, y, z])
 
+
+@tvm.testing.uses_gpu
+def test_dot():
+    def test_fn(x):
+        return x.dot(x)
+
+    x = torch.ones([2])
+    verify_model(test_fn, [x])
+
 
 if __name__ == "__main__":
     pytest.main([__file__])