Merge pull request apache#31 from heliqi/paddle_frontend
paddle frontend
jiangjiajun authored Sep 14, 2021
2 parents 76194aa + 84fcb97 commit 68d9328
Showing 1 changed file with 27 additions and 9 deletions.

python/tvm/relay/frontend/paddlepaddle.py
@@ -80,6 +80,26 @@ def _infer_value(x, params):
     return x
 
 
+def _convert_dtype_value(val):
+    """converts a Paddle type id to a string."""
+
+    convert_dtype_map = {
+        21: "int8",
+        20: "uint8",
+        6: "float64",
+        5: "float32",
+        4: "float16",
+        3: "int64",
+        2: "int32",
+        1: "int16",
+        0: "bool",
+    }
+    if val not in convert_dtype_map:
+        msg = "Paddle data type value %d is not handled yet." % (val)
+        raise NotImplementedError(msg)
+    return convert_dtype_map[val]
+
+
 def convert_unary_op(g, op, block):
     """Operator converter for all the activation."""
@@ -162,14 +182,12 @@ def convert_assign(g, op, block):
 def convert_assign_value(g, op, block):
     """Operator converter for assign_value."""
 
-    keys = ["fp32_values", "int32_values", "int64_values"]
+    keys = ["bool_values", "fp32_values", "int32_values", "int64_values"]
     for key in keys:
         value = np.array(op.attr(key))
         if value is not None and value.size >= 1:
             break
     shape = op.attr("shape")
-    dtype = block.var(op.output("Out")[0]).dtype
-    dtype = str(dtype).strip().split(".")[1]
     value = value.reshape(shape)
     out = _op.const(value)
     g.add_node(op.output("Out")[0], out)
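
Adding "bool_values" to the scan matters: an assign_value op that only carries boolean data would otherwise leave value as an empty array, and the reshape would fail. A runnable sketch of the scan, with a hypothetical FakeOp standing in for the real Paddle op:

import numpy as np

class FakeOp:  # hypothetical stand-in for a Paddle assign_value op
    def __init__(self, attrs):
        self._attrs = attrs

    def attr(self, key):
        return self._attrs.get(key, [])

op = FakeOp({"bool_values": [1, 0, 1, 0], "shape": [2, 2]})
for key in ["bool_values", "fp32_values", "int32_values", "int64_values"]:
    value = np.array(op.attr(key))
    if value is not None and value.size >= 1:
        break
print(value.reshape(op.attr("shape")))  # [[1 0] [1 0]]; without
# "bool_values" in the list, value stays empty and reshape raises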
@@ -306,8 +324,8 @@ def convert_interpolate(g, op, block):
 def convert_cast(g, op, block):
     """Operator converter for cast."""
 
-    dtype = block.var(op.output("Out")[0]).dtype
-    dtype = str(dtype).strip().split(".")[1]
+    dtype = op.attr("out_dtype")
+    dtype = _convert_dtype_value(dtype)
     x = g.get_node(op.input("X")[0])
     out = _op.cast(x, dtype=dtype)
     g.add_node(op.output("Out")[0], out)
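
The converter now reads the target dtype from the op's own out_dtype attribute (an integer id fed through _convert_dtype_value) rather than parsing the output variable's dtype string from the block. A sketch of the Relay expression this ends up emitting, assuming the id 3 decodes to "int64" under the mapping above (requires tvm installed):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
out = relay.cast(x, dtype="int64")  # dtype string produced by _convert_dtype_value
print(relay.Function([x], out))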
@@ -576,8 +594,8 @@ def convert_fill_constant(g, op, block):
 
     value = op.attr("value")
     shape = block.var(op.output("Out")[0]).shape
-    dtype = block.var(op.output("Out")[0]).dtype
-    dtype = str(dtype).strip().split(".")[1]
+    dtype = op.attr("dtype")
+    dtype = _convert_dtype_value(dtype)
     value = _expr.const(value).astype(dtype)
     if op.input("ValueTensor"):
         shape = g.get_node(op.input("ValueTensor")[0])
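
A rough NumPy analogue of what convert_fill_constant builds: the scalar value attribute broadcast to shape, with the dtype decoded from the op's integer dtype attribute (5 decodes to "float32" under the mapping above):

import numpy as np

value, shape, dtype = 1.5, [2, 3], "float32"
out = np.full(shape, value, dtype=dtype)
print(out.dtype, out.shape)  # float32 (2, 3)

Reading the attribute directly presumably also sidesteps cases where the output variable's metadata in the block is missing or not yet typed.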
@@ -598,9 +616,9 @@ def convert_fill_constant_batch_size_like(g, op, block):
     shape = op.attr("shape")
     input_dim_idx = op.attr("input_dim_idx")
     output_dim_idx = op.attr("output_dim_idx")
+    dtype = op.attr("dtype")
 
-    dtype = block.var(op.output("Out")[0]).dtype
-    dtype = str(dtype).strip().split(".")[1]
+    dtype = _convert_dtype_value(dtype)
     input_shape = shape_of(x)
     batch = _op.strided_slice(input_shape, begin=[input_dim_idx], end=[input_dim_idx + 1]).astype(
         "int32"
