add dot product attempt #2
driazati committed Jan 10, 2022
1 parent f6f252f commit 9570ab6
Showing 9 changed files with 128 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/_ffi/base.py
@@ -253,6 +253,7 @@ def c2pyerror(err_msg):
    trace_mode = False
    stack_trace = []
    message = []
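    # NOTE: this early return looks like a debugging short-circuit; the
    # trace-parsing loop below never runs while it is in place.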
    return err_msg, err_type
    for line in arr:
        if trace_mode:
            if line.startswith(" "):
5 changes: 5 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -2852,6 +2852,10 @@ def einsum(self, inputs, input_types):
        equation, data = inputs
        return _op.einsum(data, equation)

    def dot(self, inputs, _):
        lhs, rhs = inputs
        return _op.dot(lhs, rhs)

    # Operator mappings
    def create_convert_map(self):
        self.convert_map = {
@@ -3076,6 +3080,7 @@ def create_convert_map(self):
"aten::bucketize": self.bucketize,
"aten::roll": self.roll,
"aten::einsum": self.einsum,
"aten::dot": self.dot,
}

def update_convert_map(self, custom_map):
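For reference, the traced TorchScript graph that this mapping handles contains an aten::dot node; for the test script at the bottom of this commit it looks roughly like this (a sketch; exact type annotations vary by PyTorch version):

graph(%x : Float(2)):
  %1 : Float() = aten::dot(%x, %x)
  return (%1)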
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_math.py
@@ -20,3 +20,5 @@

# einsum
_reg.register_strategy("einsum", strategy.einsum_strategy)

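# dot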
_reg.register_strategy("dot", strategy.dot_strategy)
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1819,3 +1819,24 @@ def einsum_strategy(attrs, inputs, out_type, target):
name="einsum.generic",
)
return strategy


def wrap_compute_dot(topi_compute):
    """Wrap dot topi compute"""

    def _compute_dot(attrs, inputs, _):
        return [topi_compute(inputs[0], inputs[1])]

    return _compute_dot


@override_native_generic_func("dot_strategy")
def dot_strategy(attrs, inputs, out_type, target):
    """dot generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_dot(topi.dot),
        wrap_topi_schedule(topi.generic.schedule_dot),
        name="dot.generic",
    )
    return strategy
15 changes: 15 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -1136,6 +1136,21 @@ def einsum(data, equation):
    return _make.einsum(Tuple(data), equation)


def dot(lhs, rhs):
    """Compute the dot product of two 1D tensors.

    Parameters
    ----------
    lhs : relay.Expr
        The left input tensor.
    rhs : relay.Expr
        The right input tensor.

    Returns
    -------
    result : relay.Expr
        The output tensor from the dot op.
    """
    return _make.dot(lhs, rhs)

def stack(data, axis):
"""Join a sequence of arrays along a new axis.
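A minimal usage sketch for the new binding (hypothetical: it assumes the function is re-exported as relay.dot, the same way its neighbors in this module are):

from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
y = relay.var("y", shape=(4,), dtype="float32")
out = relay.dot(x, y)  # hypothetical re-export; resolves to _make.dot above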
8 changes: 8 additions & 0 deletions python/tvm/topi/einsum.py
@@ -42,3 +42,11 @@ def einsum(subscripts, *operand):
"""

return cpp.einsum(subscripts, operand)


def dot(lhs, rhs):
    # TODO: Move this out of einsum.py
    print("running topi.dot")
    print(lhs)
    print(rhs)
    exit(0)
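The body above is a debugging placeholder. A real compute for a 1-D dot product might look like the following TE sketch (an assumption, not part of this commit; dot_compute is a hypothetical name):

from tvm import te

def dot_compute(lhs, rhs):
    """Sketch: reduce over the single shared axis of two 1-D tensors."""
    k = te.reduce_axis((0, lhs.shape[0]), name="k")
    return te.compute((1,), lambda _: te.sum(lhs[k] * rhs[k], axis=k), name="T_dot")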
6 changes: 6 additions & 0 deletions python/tvm/topi/generic/math.py
@@ -32,3 +32,9 @@ def schedule_einsum(outs):
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


def schedule_dot(outs):
    print("SCHEDULING DOT")
    print(outs)
    exit(0)
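Also a stub. A working generic schedule would presumably mirror schedule_einsum above, e.g. this sketch under that assumption:

def schedule_dot(outs):
    """Sketch: fall back to the default schedule, as schedule_einsum does."""
    return _default_schedule(outs, False)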
32 changes: 32 additions & 0 deletions src/relay/op/tensor/math.cc
@@ -111,5 +111,37 @@ on the operands)doc" TVM_ADD_FILELINE)
.set_attr<FTVMCompute>("FTVMCompute", EinsumCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

bool DotProductRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3) << "Expected 3 types (two inputs, one output) for dot product, found "
                             << types.size();

  const auto* lhs = types[0].as<TensorTypeNode>();
  // TODO: Is this the correct way to verify? The tutorial still returns 'false'
  ICHECK(lhs != nullptr) << "Expected tensor for input 0";
  const auto* rhs = types[1].as<TensorTypeNode>();
  ICHECK(rhs != nullptr) << "Expected tensor for input 1";

  // TODO: How to unify tensor dtypes? e.g. [float32] dot [float64]
  reporter->Assign(types[2], TensorType(lhs->shape, lhs->dtype));

  return true;
}

RELAY_REGISTER_OP("dot")
.describe(R"doc(Compute the dot product of two 1D tensors)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("lhs", "tensor", "left hand side of dot product")
.add_argument("rhs", "tensor", "right hand side of dot product")
.set_support_level(11)
.add_type_rel("dot", DotProductRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque);

Expr MakeDot(Expr lhs, Expr rhs) {
  static const auto& op = Op::Get("dot");
  return Call(op, {lhs, rhs}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.dot").set_body_typed(MakeDot);
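// The Python stub in python/tvm/relay/op/tensor.py reaches this op through the
// relay.op._make.dot global registered above.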

} // namespace relay
} // namespace tvm
38 changes: 38 additions & 0 deletions test.py
@@ -0,0 +1,38 @@
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor


def fn(x):
    return x.dot(x)


def dump_mod(mod):
    s = str(mod).split("\n")
    print("\n".join(s[113:]))


print("Tracing fn")
model = torch.jit.trace(fn, (torch.ones(2),))
info = [(i.debugName(), tuple(i.type().sizes())) for i in model.graph.inputs()]

# Generate module
print("Importing to relay")
mod, params = relay.frontend.from_pytorch(model, input_infos=info)
dump_mod(mod)

# Execute
print("Lowering to llvm")
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=0):
    lib = relay.build(mod, target=target, params=params)
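# NOTE: as committed, relay.build invokes the topi.dot stub, which prints and calls
# exit(0), so everything below only runs once dot has a real compute and schedule.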
m = graph_executor.GraphModule(lib["default"](dev))
i = torch.ones(2)
m.set_input("x", tvm.nd.array(i))
m.run()

# Get outputs
tvm_output = m.get_output(0)
print(tvm_output)
