[microNPU] Support binary elementwise with non-4D inputs (apache#9521)
Reshapes non-4D inputs to become 4D, then reshapes the output back to
the non-4D input shape.
lhutton1 authored and ylc committed Jan 13, 2022
1 parent 02ebf08 commit f070877
Showing 4 changed files with 173 additions and 14 deletions.
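For orientation, a minimal sketch of the strategy the commit message describes, written in plain Relay with hypothetical shapes and a hypothetical helper name (the real legalization emits the ethosu_binary_elementwise operator rather than relay.add):

from tvm import relay

def pad_shape_to_4d(shape):
    # Prepend 1s until the shape is 4D, e.g. [4, 1] -> [1, 1, 4, 1].
    return [1] * (4 - len(shape)) + list(shape)

ifm = relay.var("ifm", shape=(1, 4, 4), dtype="int8")
ifm2 = relay.var("ifm2", shape=(4, 1), dtype="int8")
lhs = relay.reshape(ifm, pad_shape_to_4d([1, 4, 4]))  # shape (1, 1, 4, 4)
rhs = relay.reshape(ifm2, pad_shape_to_4d([4, 1]))    # shape (1, 1, 4, 1)
out = relay.add(lhs, rhs)             # stand-in for the NPU binary elementwise
out = relay.reshape(out, (1, 4, 4))   # restore the original ifm shape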
69 changes: 64 additions & 5 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -426,6 +426,60 @@ def __init__(
         self.params_class = params_class
         self.pattern = pattern
 
+    @staticmethod
+    def reshape_input(
+        inputs: List["TensorParams"],
+    ) -> List[tvm.relay.Expr]:
+        """Reshape the inputs so that the following binary elementwise
+        operator receives 4-dimensional inputs.
+
+        Parameters
+        ----------
+        inputs: List[TensorParams]
+            The inputs to reshape.
+
+        Returns
+        -------
+        reshaped_inputs: List[tvm.relay.Expr]
+            The new reshaped inputs.
+        """
+        reshaped_inputs = []
+        for i in inputs:
+            in_shape = i.shape
+            if len(in_shape) < 4:
+                pad_size = 4 - len(in_shape)
+                new_shape = ([1] * pad_size) + in_shape
+                new_call = relay.reshape(i.tensor, new_shape)
+                reshaped_inputs.append(new_call)
+            else:
+                reshaped_inputs.append(i.tensor)
+        return reshaped_inputs
+
+    @staticmethod
+    def reshape_output(output: tvm.relay.Expr, ifm_input_shape: List[int]) -> tvm.relay.Expr:
+        """Reshape the output back to the original dimensionality.
+        Since the NPU must have the broadcastable tensor as the
+        second operand, the original shape of the first ifm must
+        be the output shape.
+
+        Parameters
+        ----------
+        output: tvm.relay.Expr
+            The output to reshape.
+        ifm_input_shape: List[int]
+            The shape of the non-reshaped ifm tensor.
+
+        Returns
+        -------
+        reshaped_output: tvm.relay.Expr
+            The reshaped output expression.
+        """
+        if len(ifm_input_shape) == 4:
+            return output
+        reshaped_output = relay.reshape(output, ifm_input_shape)
+        return reshaped_output

     def callback(
         self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
     ) -> tvm.relay.Expr:
@@ -451,9 +505,12 @@ def callback(
         # We don't yet support activation functions that need to get legalized to LUTs.
         lut = relay.const([], dtype="int8")
 
-        return ethosu_ops.ethosu_binary_elementwise(
-            ifm=params.ifm.tensor,
-            ifm2=params.ifm2.tensor,
+        inputs = [params.ifm, params.ifm2]
+        inputs = self.reshape_input(inputs)
+
+        ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
+            ifm=inputs[0],
+            ifm2=inputs[1],
             lut=lut,
             operator_type=params.operator_type,
             ifm_scale=float(params.ifm.q_params.scale_f32),
@@ -462,8 +519,8 @@
             ifm2_zero_point=int(params.ifm2.q_params.zero_point),
             ofm_scale=float(params.ofm.q_params.scale_f32),
             ofm_zero_point=int(params.ofm.q_params.zero_point),
-            ifm_channels=params.ifm.shape[3],
-            ifm2_channels=params.ifm2.shape[3],
+            ifm_channels=params.ifm.shape[-1],
+            ifm2_channels=params.ifm2.shape[-1],
             reversed_operands=params.reversed_operands,
             ofm_dtype=params.ofm.dtype,
             activation=activation,
@@ -473,6 +530,8 @@
             ifm2_layout=str(params.ifm2.layout),
             ofm_layout=str(params.ofm.layout),
         )
+        output = self.reshape_output(ethosu_binary_elementwise, params.ifm.shape)
+        return output
 
 
 class AddRewriter(BinaryElementwiseRewriter):
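A quick worked check of reshape_output's docstring rule, in plain Python with hypothetical shapes: the broadcastable tensor always ends up as the NPU's second operand, so the first ifm fixes the output shape.

# ifm [1, 4, 4] and ifm2 [4, 1] are both padded to 4D by reshape_input.
ifm_shape, ifm2_shape = [1, 4, 4], [4, 1]
padded = [[1] * (4 - len(s)) + s for s in (ifm_shape, ifm2_shape)]
assert padded == [[1, 1, 4, 4], [1, 1, 4, 1]]
# The NPU output is [1, 1, 4, 4]; reshape_output restores ifm's [1, 4, 4].
# ifm_channels now reads shape[-1], which is valid for any rank up to 4.
assert padded[0][-1] == ifm_shape[-1] == 4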
18 changes: 12 additions & 6 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -516,11 +516,12 @@ def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters
         self.activation = clip
         self.operator_type = operator_type
 
-        def can_broadcast(x, y):
-            for i in range(1, 4):
-                if x.shape[i] == y.shape[i] or y.shape[i] == 1:
-                    continue
-                return False
+        def can_broadcast(ifm, ifm2):
+            if len(ifm.shape) < len(ifm2.shape):
+                return False
+            for m, n in zip(ifm.shape[::-1], ifm2.shape[::-1]):
+                if m != n and m == 1:
+                    return False
             return True
 
         if can_broadcast(self.ifm, self.ifm2):
@@ -539,9 +540,14 @@ def is_valid(self):
"""
if np.dtype(self.ofm) == np.int32 and self.activation is not None:
return False
if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4:
# Due to identity operator requiring ofm != int32 for now
if np.dtype(self.ofm) == np.int32 and len(self.ofm.shape) < 4:
return False
if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1:
if len(self.ifm.shape) > 4 or len(self.ifm2.shape) > 4:
return False
if len(self.ifm.shape) == 4 and self.ifm.shape[0] != 1:
return False
if len(self.ifm2.shape) == 4 and self.ifm2.shape[0] != 1:
return False
if not self.valid_broadcast:
return False
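The new predicate reads well in isolation; a standalone sketch with plain-Python shape lists (a hypothetical wrapper mirroring the logic above): a mismatched dimension only rejects the match when the ifm dimension is the 1, since that would require broadcasting ifm itself rather than ifm2.

def can_broadcast(ifm_shape, ifm2_shape):
    # ifm must have at least as many dimensions as ifm2.
    if len(ifm_shape) < len(ifm2_shape):
        return False
    # Walk the shapes right-aligned; reject when ifm would need broadcasting.
    for m, n in zip(ifm_shape[::-1], ifm2_shape[::-1]):
        if m != n and m == 1:
            return False
    return True

assert can_broadcast([1, 4, 4], [4, 1])      # ifm2 broadcasts into ifm
assert not can_broadcast([4, 1], [1, 4, 4])  # ifm has fewer dims than ifm2
assert not can_broadcast([1, 1, 4], [4, 4])  # ifm itself would need broadcasting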
79 changes: 79 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -349,6 +349,7 @@ def representative_dataset():
         ([1, 2, 3, 4], [1, 2, 3, 4]),
         ([1, 2, 3, 4], [1, 1, 1, 1]),
         ([1, 1, 1, 1], [1, 2, 3, 4]),
+        ([1, 4, 4], [4, 1]),
     ],
 )
 @pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
@@ -435,6 +436,84 @@ def representative_dataset():
     infra.verify_source(compiled_models, accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape",
+    [
+        ([4], [4]),
+        ([4], [1, 2, 3, 4]),
+        ([1, 4, 4], [4, 1]),
+    ],
+)
+def test_binary_add_with_non_4d_shapes(
+    accel_type,
+    ifm_shape,
+    ifm2_shape,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, lhs, rhs):
+                return tf.math.add(lhs, rhs)
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+                yield [data.astype(np.float32), data2.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape},
+        dtype_dict={"ifm": dtype, "ifm2": dtype},
+    )
+    mod = partition_for_ethosu(mod, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+        output_tolerance=0,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+    assert len(imported_modules) == 2
+    ethosu_module = imported_modules[0]
+
+    # Verify generated C source
+    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
+    cmms = get_cs(ethosu_module)
+    cmms = bytes.fromhex(cmms)
+
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 def test_binary_add_from_constant_scalar(accel_type):
     dtype = "uint8"
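The shape pairs exercised by the new test follow ordinary numpy broadcasting, which is also what the TFLite reference output is generated against; for instance:

import numpy as np

a = np.ones((1, 4, 4), dtype=np.int8)
b = np.ones((4, 1), dtype=np.int8)
assert (a + b).shape == (1, 4, 4)  # right-aligned broadcast of a 3D/2D pair
assert (np.ones(4) + np.ones((1, 2, 3, 4))).shape == (1, 2, 3, 4)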
21 changes: 18 additions & 3 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -565,6 +565,10 @@ def verify(ext_func):
         ([1, 2, 3, 4], [1, 2, 3, 4], False),
         ([1, 2, 3, 4], [1, 1, 3, 1], False),
         ([1, 1, 3, 1], [1, 2, 3, 4], True),
+        ([1, 4, 4], [4, 1], False),
+        ([4], [4], False),
+        ([4], [1, 2, 3, 4], True),
+        ([1, 4, 4], [4, 1], False),
     ],
 )
 @pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
@@ -621,16 +625,27 @@ def verify(ext_func):
         shapes = [ifm_shape, ifm2_shape]
         ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
         op = ext_func.body
-        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
-        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+
+        has_reshaped_output = False
+        shapes_padded = [[1] * (4 - len(s)) + s for s in shapes]
+        out_padded = [1] * (4 - len(out_shape)) + out_shape
+        if op.op.name != "contrib.ethosu.binary_elementwise":
+            has_reshaped_output = True
+            op = op.args[0]
+
+        assert list(op.args[0].checked_type.shape) == shapes_padded[ifm_index]
+        assert list(op.args[1].checked_type.shape) == shapes_padded[ifm2_index]
         assert op.args[0].checked_type.dtype == dtype
-        assert list(op.checked_type.shape) == out_shape
+        assert list(op.checked_type.shape) == out_padded
         assert op.checked_type.dtype == dtype
         assert op.attrs.operator_type == operator_type
         assert op.attrs.reversed_operands == reversed_operands
         if activation_function == "RELU":
             assert str(op.attrs.activation) == "CLIP"
+
+        if has_reshaped_output:
+            assert list(ext_func.body.checked_type.shape) == out_shape
 
     if operator_type == "ADD":
         rewriter = legalize.AddRewriter()
         pattern_table = [
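To trace how verify's padded reference shapes fall out, one of the new parametrized cases worked by hand (values mirror the ([4], [1, 2, 3, 4], True) row above):

shapes = [[4], [1, 2, 3, 4]]                # ifm_shape, ifm2_shape
shapes_padded = [[1] * (4 - len(s)) + s for s in shapes]
assert shapes_padded == [[1, 1, 1, 4], [1, 2, 3, 4]]
out_shape = [1, 2, 3, 4]                    # the broadcast result
out_padded = [1] * (4 - len(out_shape)) + out_shape
assert out_padded == [1, 2, 3, 4]           # already 4D, so unchanged
# reversed_operands is True: the broadcastable [1, 1, 1, 4] must be ifm2.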
