
Commit

[Relax][ONNX] Add Multiple ONNX Frontend Support for Clip / Equal / Shape / Not / Tanh (#3)

* Rebase w/ Equal, Not, Tanh, Sqrt, Relu, Clip, Conv, Pow, Erf.

* Fix cumsum but still needs work.
zxybazh authored Jan 30, 2023
1 parent e9488b6 commit 957f266
Showing 2 changed files with 382 additions and 18 deletions.
164 changes: 159 additions & 5 deletions python/tvm/relax/frontend/onnx_frontend.py
@@ -125,7 +125,7 @@ def get_converter(cls, opset):
return getattr(cls, "_impl_v{}".format(version))
raise NotImplementedError(
"opset version {} of {} not implemented".format(version, cls.__name__)
)

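For context, each converter class exposes `_impl_v{N}` classmethods and `get_converter` selects one for the model's opset; the elided body above implements that lookup. A minimal sketch of the usual pattern (class and variable names are illustrative, not the exact code hidden by the diff):

```python
# Sketch: pick the newest _impl_v{N} whose version does not exceed the
# requested opset, else raise the NotImplementedError shown above.
class OpConverterSketch:
    @classmethod
    def get_converter(cls, opset):
        versions = sorted(
            int(name.split("_impl_v")[1])
            for name in dir(cls)
            if name.startswith("_impl_v")
        )
        supported = [v for v in versions if v <= opset]
        if not supported:
            raise NotImplementedError(
                "opset version {} of {} not implemented".format(opset, cls.__name__)
            )
        return getattr(cls, "_impl_v{}".format(supported[-1]))
```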

class MatMul(OnnxOpConverter):
@@ -135,41 +135,50 @@ class MatMul(OnnxOpConverter):
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.matmul, inputs[0], inputs[1])


class Div(OnnxOpConverter):
"""Converts an onnx Div node into an equivalent Relax expression."""

@classmethod
def _impl_v14(cls, bb, inputs, attr):
return bb.emit_te(topi.divide, inputs[0], inputs[1])


class Sigmoid(OnnxOpConverter):
"""Converts an onnx Sigmoid node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.sigmoid, inputs[0])


class Softmax(OnnxOpConverter):
"""Converts an onnx Softmax node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
axis = attr.get("axis", -1)
return bb.emit_te(topi.nn.softmax, inputs[0], axis=axis)


class Transpose(OnnxOpConverter):
"""Converts an onnx Transpose node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
perm = attr.get("perm", None)
return bb.emit_te(topi.transpose, inputs[0], axes=perm)


class Unsqueeze(OnnxOpConverter):
"""Converts an onnx Unsqueeze node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
input = inputs[0]
axes = inputs[1]

if isinstance(axes, relax.Constant):
constant_axes = list(axes.data.numpy())
constant_axes = list(map(int, constant_axes))
constant_axes = sorted(constant_axes)
@@ -179,6 +188,7 @@ def _impl_v13(cls, bb, inputs, attr):

raise NotImplementedError("Unsqueeze with dynamic axes is not supported.")
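
The constant-axes path inserts size-1 dimensions one at a time in ascending order, so each later axis index already accounts for the earlier insertions. A numpy sketch of the same behavior (helper name illustrative):

```python
import numpy as np

# Insert size-1 dims in ascending axis order, mirroring the sorted
# constant_axes loop in the converter above.
def unsqueeze_ref(data, axes):
    for axis in sorted(int(a) for a in axes):
        data = np.expand_dims(data, axis)
    return data

assert unsqueeze_ref(np.zeros((3, 4)), [0, 3]).shape == (1, 3, 4, 1)
```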


class Concat(OnnxOpConverter):
"""Convert an onnx Concat node into an equivalent Relax expression."""

@@ -207,6 +217,8 @@ def _impl_v13(cls, bb, inputs, attr):
class Cast(OnnxOpConverter):
"""Convert an onnx Cast node into an equivalent Relax expression."""

"""Convert an onnx Cast node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
to_type = get_type(attr["to"])
@@ -216,6 +228,8 @@ def _impl_v13(cls, bb, inputs, attr):
class Gather(OnnxOpConverter):
"""Convert an onnx Gather node into an equivalent Relax expression."""

"""Convert an onnx Gather node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
# TODO: This assumes positive-only indices.
@@ -255,16 +269,20 @@ def _impl_v13(cls, bb, inputs, attr):
class Reshape(OnnxOpConverter):
"""Convert an onnx Reshape node into an equivalent Relax expression."""

"""Convert an onnx Reshape node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
from tvm.script import relax as R

data = inputs[0]
# TODO: We assume new_shape is a constant; tensor inputs to reshape need to
# be enabled for full support.
new_shape = inputs[1].data.numpy()

# Convert -1 dims in new_shape into positive equivalent.
if -1 in new_shape:
data_shape = [dim.value for dim in data.shape.values]
total_elements = np.prod(data_shape)
new_product = 1
@@ -277,14 +295,15 @@ def _impl_v13(cls, bb, inputs, attr):
if dim == -1:
new_shape[i] = int(total_elements / new_product)

return bb.emit_te(topi.reshape, data, new_shape)
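
The -1 handling above infers the unknown dimension from the total element count. A small numpy sketch of that arithmetic (helper name is illustrative):

```python
import numpy as np

# The unknown dim is the total element count divided by the product of the
# known dims, as in the loop above.
def resolve_reshape(data_shape, new_shape):
    new_shape = list(new_shape)
    if -1 in new_shape:
        known = int(np.prod([d for d in new_shape if d != -1]))
        new_shape[new_shape.index(-1)] = int(np.prod(data_shape)) // known
    return new_shape

assert resolve_reshape((2, 3, 4), (4, -1)) == [4, 6]
```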


class Gelu(OnnxOpConverter):
"""Operator converter for Gelu from Microsoft onnxruntime contrib opset.
gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
"""

@classmethod
def _impl_v1(cls, bb, inputs, attr):
x = inputs[0]
@@ -297,15 +316,17 @@ def _impl_v1(cls, bb, inputs, attr):

# Compute gelu
term1 = bb.emit_te(topi.multiply, half, x)
erf = bb.emit_te(topi.erf, bb.emit_te(topi.divide, x, sqrt2))
term2 = bb.emit_te(topi.add, one, erf)
return bb.emit_te(topi.multiply, term1, term2)
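
As a quick scalar sanity check of the identity in the docstring, using only the standard library (illustrative only):

```python
import math

# gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
def gelu_ref(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

print(round(gelu_ref(1.0), 4))  # 0.8413
```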


class BiasGelu(OnnxOpConverter):
"""Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
"""

@classmethod
def _impl_v1(cls, bb, inputs, attr):
x = inputs[0]
@@ -317,12 +338,134 @@ def _impl_v1(cls, bb, inputs, attr):
inp = bb.emit_te(topi.add, x, b)
return Gelu._impl_v1(bb, [inp], attr)


class Where(OnnxOpConverter):
"""Convert an onnx Where node into an equivalent Relax expression."""

@classmethod
def _impl_v16(cls, bb, inputs, attr):
return bb.emit_te(topi.where, *inputs)


class Clip(OnnxOpConverter):
"""Converts an onnx Clip node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
results = inputs[0]
if len(inputs) >= 2:
results = bb.emit_te(topi.maximum, results, inputs[1])
if len(inputs) >= 3:
results = bb.emit_te(topi.minimum, results, inputs[2])
return results
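
The lowering mirrors ONNX Clip's optional min/max inputs as a maximum/minimum chain; a numpy illustration of the same semantics:

```python
import numpy as np

# Optional lower/upper bounds applied as a maximum/minimum chain, exactly
# as in the converter above; absent bounds simply drop out.
def clip_ref(x, lo=None, hi=None):
    if lo is not None:
        x = np.maximum(x, lo)
    if hi is not None:
        x = np.minimum(x, hi)
    return x

assert (clip_ref(np.array([-2.0, 0.5, 3.0]), 0.0, 1.0) == [0.0, 0.5, 1.0]).all()
```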


class Equal(OnnxOpConverter):
"""Converts an onnx Equal node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.equal, inputs[0], inputs[1])


class Shape(OnnxOpConverter):
"""Converts an onnx Equal node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.shape, inputs[0])


class Not(OnnxOpConverter):
"""Converts an onnx Not node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.bitwise_not, inputs[0])


class Tanh(OnnxOpConverter):
"""Converts an onnx Tanh node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.tanh, inputs[0])


class Sqrt(OnnxOpConverter):
"""Converts an onnx Sqrt node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.sqrt, inputs[0])


class Relu(OnnxOpConverter):
"""Converts an onnx Relu node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.nn.relu, inputs[0])


class Pow(OnnxOpConverter):
"""Converts an onnx Pow node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.power, inputs[0], inputs[1])


class Conv(OnnxOpConverter):
"""Convert an onnx Conv node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
# auto_pad and group are not supported yet.
assert "auto_pad" not in attr
assert "group" not in attr
# Lower Conv as conv2d followed by a broadcast bias add.
return bb.emit_te(
topi.add,
bb.emit_te(
topi.nn.conv2d,
inputs[0],
inputs[1],
strides=attr.get("strides", 1),
padding=attr.get("pads", 0),
dilation=attr.get("dilations", 1),
),
bb.emit_te(topi.expand_dims, inputs[2], axis=1, num_newaxis=2),
)
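
The bias path relies on expanding the 1-D bias of shape (O,) to (O, 1, 1) so it broadcasts over the NCHW conv output; a numpy sketch of that shape arithmetic:

```python
import numpy as np

# A (O,) bias gains two trailing unit axes (axis=1, num_newaxis=2 in
# topi.expand_dims terms) and then broadcasts against (N, O, H, W).
bias = np.arange(8.0)                    # (O,) == (8,)
expanded = bias[:, None, None]           # (O, 1, 1)
out = np.zeros((2, 8, 5, 5)) + expanded  # broadcasts to (2, 8, 5, 5)
assert out.shape == (2, 8, 5, 5)
```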


class Erf(OnnxOpConverter):
"""Converts an onnx Erf node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
return bb.emit_te(topi.erf, inputs[0])


class CumSum(OnnxOpConverter):
"""Converts an onnx CumSum node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr):
assert getattr(attr, "reverse", 0) == 0, "reverse is not supported yet"
if len(inputs) > 1:
# TODO: axis arrives here as a relax expression, but topi.cumsum expects
# an integer, so dynamic axis support still needs work.
axis = inputs[1]
else:
axis = None
return bb.emit_te(
topi.cumsum,
data=inputs[0],
axis=axis,
exclusive=attr.get("exclusive", None),
)
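
For reference, the `exclusive` flag forwarded above corresponds to shifting the inclusive scan right by one and seeding with zero; a numpy illustration:

```python
import numpy as np

# Exclusive cumsum: each output excludes its own element.
x = np.array([1, 2, 3, 4])
inclusive = np.cumsum(x)                           # [1, 3, 6, 10]
exclusive = np.concatenate(([0], inclusive[:-1]))  # [0, 1, 3, 6]
```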


def _get_convert_map(opset):
return {
"MatMul": MatMul.get_converter(opset),
@@ -341,6 +484,17 @@ def _get_convert_map(opset):
"Gelu": Gelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
"Where": Where.get_converter(opset),
"Clip": Clip.get_converter(opset),
"Equal": Equal.get_converter(opset),
"Shape": Shape.get_converter(opset),
"Not": Not.get_converter(opset),
"Tanh": Tanh.get_converter(opset),
"Sqrt": Sqrt.get_converter(opset),
"Relu": Relu.get_converter(opset),
"Conv": Conv.get_converter(opset),
"Pow": Pow.get_converter(opset),
"Erf": Erf.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
}
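
Once registered in the convert map, these ops are reachable through `from_onnx` (defined later in this file). A minimal usage sketch, assuming the `onnx` package is installed and that `shape` accepts a name-to-shape dict; that dict form is an assumption based on common frontend conventions, not confirmed by this diff:

```python
# Build a one-node ONNX model and convert it with this frontend.
from onnx import helper, TensorProto

from tvm.relax.frontend.onnx_frontend import from_onnx

node = helper.make_node("Tanh", ["x"], ["y"])
graph = helper.make_graph(
    [node],
    "tanh_test",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
mod = from_onnx(model, shape={"x": [1, 4]}, dtype="float32")  # assumed shape format
```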


@@ -630,4 +784,4 @@ def from_onnx(model, shape=None, dtype="float32", opset=None):
)

# Use the graph proto as a scope so that ops can access other nodes if needed.
return g.from_onnx(graph, opset)