[microNPU] Add support for TFLite PAD (#13732)
A standalone nn.pad Relay operator is legalized to an Ethos-U depthwise_conv2d operator.
For ethosu_depthwise_conv2d the hardware only supports padding up to (31, 31, 32, 32)
for (top, left, bottom, right), so a pad is legalized for the NPU only when its size is within these limits.
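
Below is a minimal numpy sketch (not part of the commit) of why an identity depthwise convolution realizes the pad: with a 1x1 unit-weight kernel, zero bias, and unit scales, the convolution copies every input element through unchanged, and the hardware fills the padded border with the IFM zero point, which is exactly what the quantized nn.pad produces. The helper name pad_via_identity_dwconv is illustrative only.

import numpy as np

def pad_via_identity_dwconv(ifm, pad, zero_point):
    """Emulate the rewrite: a 1x1 depthwise conv with unit weights and
    zero bias whose padding attribute fills the border with zero_point."""
    top, left, bottom, right = pad  # order produced by PadParams.extract_padding
    n, h, w, c = ifm.shape
    out = np.full((n, h + top + bottom, w + left + right, c), zero_point, dtype=ifm.dtype)
    # The 1x1 unit-weight kernel copies each input element through
    # unchanged; only the border values are new.
    out[:, top : top + h, left : left + w, :] = ifm
    return out

x = np.random.randint(-128, 128, size=(1, 4, 4, 3), dtype=np.int8)
reference = np.pad(x, [(0, 0), (1, 1), (2, 2), (0, 0)], constant_values=5)
assert (pad_via_identity_dwconv(x, (1, 2, 1, 2), 5) == reference).all()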
Aleksei-grovety authored Jan 9, 2023
1 parent 5db453e commit 6b65a59
Showing 4 changed files with 175 additions and 0 deletions.
57 changes: 57 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1355,6 +1355,62 @@ def callback(self, pre, post, node_map):
return ethosu_fc


class PadRewriter(DFPatternCallback):
"""Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d
operator"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.PadParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.PadParams(post.op.body)
params.ifm.tensor = post.args[0]
channels_map = {
"NHWC": 3,
}
w_h, w_w = (1, 1)
# OHWI format for the ethosu_depthwise_conv2d kernel weights
weight_shape = (params.ifm.shape[-1], w_h, w_w, 1)
weights = relay.const(np.full(weight_shape, 1), params.ifm.dtype)
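        # Unit weights together with the zero biases and unit scales packed
        # below make the depthwise convolution an identity mapping; only its
        # padding attribute (filled with the IFM zero point) adds new values.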
scale_bias = vela_api.pack_biases(
biases=np.zeros(params.ifm.shape[-1]),
ifm_scale=params.ifm.q_params.scale_f32,
ifm_dtype=np.dtype(params.ifm.dtype),
weight_scales=np.array(1.0, dtype=np.float32),
ofm_scale=params.ofm.q_params.scale_f32,
is_activation_tanh_or_sigmoid=False,
)

return ethosu_ops.ethosu_depthwise_conv2d(
ifm=post.args[0],
weight=weights,
scale_bias=relay.const(scale_bias, "uint8"),
lut=relay.const([], "int8"),
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point.item()),
weight_zero_point=0,
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point.item()),
kernel_shape=(w_h, w_w),
ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]],
strides=(1, 1),
padding=params.padding,
dilation=(1, 1),
activation="NONE",
clip_min=0,
clip_max=0,
upscale="NONE",
ifm_layout=str(params.ifm.layout),
ofm_layout=str(params.ofm.layout),
ofm_dtype=str(params.ofm.dtype),
)


@util.create_npu_function_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -1375,6 +1431,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
FullyConnectedRewriter(),
MaxPoolingRewriter(),
AvgPoolingRewriter(),
PadRewriter(),
AddRewriter(),
SubRewriter(),
MulRewriter(),
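A minimal sketch (not part of the commit) of how a DFPatternCallback such as PadRewriter is applied to a Relay function; LegalizeEthosU above runs the same kind of rewrite once per rewriter in its list. It assumes func is an NPU Relay function containing an ethos-u.pad2d composite call.

from tvm.relay.dataflow_pattern import rewrite

# Replace every ethos-u.pad2d composite call in `func` with an
# ethosu_depthwise_conv2d call; everything else is left untouched.
func = rewrite(PadRewriter(), func)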
10 changes: 10 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
@@ -143,6 +143,16 @@ class QDenseArgs(Enum):
WEIGHTS_SCALE = 5


class QPad2DArgs(Enum):
"""
This is a helper enum to obtain the correct index
of nn.pad arguments.
"""

IFM = 0
IFM_ZERO_POINT = 1


def is_npu_func(func: relay.Function) -> bool:
"""Check if the given function is an NPU function."""
return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u"
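A minimal sketch (not part of the commit, assuming a TVM build with the ethos-u backend) of what QPad2DArgs indexes: the two arguments of the matched nn.pad call, whose second argument is the pad-value constant carrying the IFM zero point.

from tvm import relay
from tvm.relay.backend.contrib.ethosu.util import QPad2DArgs

x = relay.var("x", shape=(1, 8, 8, 3), dtype="int8")
zero_point = relay.const(5, "int8")  # pad value == IFM zero point
call = relay.nn.pad(x, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)), pad_value=zero_point)

ifm = call.args[QPad2DArgs.IFM.value]  # the input tensor
pad_value = call.args[QPad2DArgs.IFM_ZERO_POINT.value]  # the zero-point constant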
85 changes: 85 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1772,6 +1772,86 @@ def hard_swish_pattern():
return quantize


class PadParams:
"""
    This class will parse a call to an ethos-u.pad2d composite function
and extract the parameter information.
"""

composite_name = "ethos-u.pad2d"
# The ethos-u.pad2d composite function will be transformed to the
# ethosu_depthwise_conv2d operator.
# For the ethosu_depthwise_conv2d the hardware only supports padding
    # up to the following limits, so we define these padding bounds
padding_bounds = [31, 31, 32, 32]

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import QPad2DArgs

        # nn.pad has no 'layout' attribute, so assume NHWC
layout = "NHWC"
self.ifm = TensorParams(
tensor=func_body.args[QPad2DArgs.IFM.value],
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
)

self.padding = self.extract_padding(func_body)
self.ofm = TensorParams(
tensor=func_body,
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
)

@staticmethod
def extract_padding(
padding: relay.Call,
) -> Optional[Tuple[int, int, int, int]]:
"""
        Here we check whether a separate padding operation can be rewritten
        as an NPU depthwise convolution. If the padding specified by the
        separate nn.pad operation is not supported by the NPU depthwise
        convolution, None is returned, and the nn.pad is not offloaded to the NPU.
"""
pad_width = padding.attrs["pad_width"]
if len(pad_width) != 4:
return None
if list(pad_width[0]) != [0, 0] or list(pad_width[3]) != [0, 0]:
return None
        return (
            pad_width[1][0],
            pad_width[2][0],
            pad_width[1][1],
            pad_width[2][1],
        )

def is_valid(self):
"""
        This function checks whether the pad operation has attributes
        compatible with the NPU depthwise convolution
"""
tensor_params = [self.ifm, self.ofm]
if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]):
return False
if self.ifm.dtype != self.ofm.dtype:
return False
if not check_batch_size(self.ifm):
return False
if not self.padding or not check_padding(self.padding, self.padding_bounds):
return False
if not check_dimensions(self.ifm) or not check_dimensions(self.ofm):
return False
return True


def pad_pattern():
"""Create pattern for pad"""
pattern = is_op("nn.pad")(wildcard(), is_constant())
return pattern


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
@@ -1805,6 +1885,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
qnn_avgpool2d_pattern(),
lambda pat: AvgPool2DParams(pat).is_valid(),
),
(
PadParams.composite_name,
pad_pattern(),
lambda pat: PadParams(pat).is_valid(),
),
(
AddParams.composite_name,
qnn_add_pattern(),
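A minimal illustration (not part of the commit) of the order extract_padding produces, using plain Python lists in place of the Relay pad_width attribute. Each pad_width row holds (before, after) for one NHWC axis, and only the H and W rows may be non-zero:

# pad_width rows: (before, after) for each of the N, H, W, C axes.
pad_width = [[0, 0], [1, 2], [3, 4], [0, 0]]
top, bottom = pad_width[1]
left, right = pad_width[2]
# extract_padding returns (top, left, bottom, right), the order checked
# against padding_bounds = [31, 31, 32, 32].
assert (top, left, bottom, right) == (1, 3, 2, 4)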
23 changes: 23 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -258,6 +258,29 @@ def depthwise_conv2d(x):
infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], "ethos-u55-256")


@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
@pytest.mark.parametrize("padding", [(0, 1, 0, 0), (1, 1, 1, 1), (1, 1, 5, 5)])
@pytest.mark.parametrize("const_value", [0, 5, 125, -5])
def test_tflite_separate_pad(
ifm_shape,
padding,
const_value,
):

np.random.seed(0)

@tf.function
def pad2d(x):
return tf.pad(
x,
[[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
"CONSTANT",
const_value,
)

infra.compare_tvm_with_tflite(pad2d, [ifm_shape], "ethos-u55-256")


@pytest.mark.parametrize(
"accel_type",
ACCEL_TYPES,
