Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CMSIS-NN] Aligned scale computation with TFLM to fix numerical mismatch #10817

Merged
merged 8 commits into from
Apr 6, 2022
10 changes: 5 additions & 5 deletions src/relay/backend/contrib/cmsisnn/generate_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ class GenerateConstantsMutator : public MixedModeMutator {
// Obtain input and output scales from Relay's Requantization
int64_t out_channels = conv2d_attrs->channels.as<IntImmNode>()->value;
float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
auto input_scales = tvm::relay::qnn::GetFloatVectorFromConstant(requantize_call->args[1]);
ICHECK(input_scales.size() == static_cast<size_t>(out_channels));
auto input_scale = GetScalarFromConstant<float>(conv2d_call->args[4]);
auto filter_scales = tvm::relay::qnn::GetFloatVectorFromConstant(conv2d_call->args[5]);

// Calculate requantization multiplier and shift
Device dev{DLDeviceType::kDLCPU, 0};
Expand All @@ -134,10 +134,10 @@ class GenerateConstantsMutator : public MixedModeMutator {
int32_t* multiplier = static_cast<int32_t*>(multiplier_nda->data);
int32_t* shift = static_cast<int32_t*>(shift_nda->data);
for (int i = 0; i < out_channels; ++i) {
double quantized_multiplier =
static_cast<double>(input_scales[i]) / static_cast<double>(output_scale);
double effective_output_scale =
static_cast<double>(input_scale) * filter_scales[i] / static_cast<double>(output_scale);
std::tie(*(multiplier + i), *(shift + i)) =
tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale);
}

// Create constants from requantization multiplier and shift
Expand Down
40 changes: 39 additions & 1 deletion tests/python/contrib/test_cmsisnn/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@
from utils import (
skip_if_no_reference_system,
make_module,
create_conv2d_tflite_relay_models,
get_range_for_dtype_str,
get_same_padding,
get_conv2d_qnn_params,
make_qnn_relu,
assert_partitioned_function,
assert_no_external_function,
generate_ref_data_tflite,
)


Expand Down Expand Up @@ -282,7 +284,6 @@ def test_conv2d_asymmetric_padding_int8(
)
orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

# validate pattern matching
assert_partitioned_function(orig_mod, cmsisnn_mod)

Expand All @@ -304,6 +305,43 @@ def test_conv2d_asymmetric_padding_int8(
)


@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
@pytest.mark.parametrize("strides, dilation", [((3, 2), (1, 1))])
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
@pytest.mark.parametrize("activation", ["NONE", "RELU"])
def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding, activation):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_USMP_CORSTONE300_RUNNER

dtype = "int8"
tflite_model, relay_mod, params = create_conv2d_tflite_relay_models(
ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype
)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, params)

# validate pattern matching
assert_partitioned_function(relay_mod, cmsisnn_mod)

# validate CMSIS-NN output against TFLite output
input_map, output_map, output_tolerance = generate_ref_data_tflite(tflite_model)
compile_and_run(
AOTTestModel(
module=cmsisnn_mod,
inputs=input_map,
outputs=output_map,
params=params,
output_tolerance=output_tolerance,
),
test_runner,
interface_api,
use_unpacked_api,
)


@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/10314")
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
Expand Down
132 changes: 131 additions & 1 deletion tests/python/contrib/test_cmsisnn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""CMSIS-NN functions for testing networks"""

import platform

import math
import numpy as np
import pytest
Expand Down Expand Up @@ -226,3 +225,134 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype):
)
if fused_activation_fn == "RELU":
return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax)


def generate_random_input_data(seed, shape, dtype):
"""
Generates randomized input numpy arrays based on shape and dtype
"""
random_state = np.random.RandomState(seed)
if dtype == np.float32:
return random_state.uniform(-1, 1, size).astype(dtype)
else:
low = np.iinfo(dtype).min
high = np.iinfo(dtype).max + 1
return random_state.randint(low, high, shape, dtype)


def generate_ref_data_tflite(model):
"""
This method uses TFLite reference kernels to generate reference output.
Random input generator is used to get the input data.
It returns randomized inputs and reference outputs.
"""
import tensorflow as tf
from distutils.version import LooseVersion

output_tolerance = None
if tf.__version__ < LooseVersion("2.5.0"):
output_tolerance = 1
interpreter = tf.lite.Interpreter(model_content=model)
else:
from tensorflow.lite.python.interpreter import OpResolverType

output_tolerance = 0
interpreter = tf.lite.Interpreter(
model_content=model,
experimental_op_resolver_type=OpResolverType.BUILTIN_REF,
experimental_preserve_all_tensors=False,
)

interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Generate predictable randomized input
seed = 0
input_data = {}
for input_detail in input_details:
input_values = generate_random_input_data(
seed, input_detail["shape"], input_detail["dtype"]
)
interpreter.set_tensor(input_detail["index"], input_values)
input_data.update({input_detail["name"]: input_values})

interpreter.invoke()

# Obtain the expected output from interpreter
expected_output_data = {}
for output_detail in output_details:
expected_output_data.update(
{output_detail["name"]: interpreter.get_tensor(output_detail["index"])}
)

return input_data, expected_output_data, output_tolerance


def create_conv2d_tflite_model(ifm_shape, kernel_shape, strides, dilation, padding, activation):
""" This method prepares TFlite graph with a single Conv2d layer """
import tensorflow as tf

class Model(tf.Module):
@tf.function
def tf_function(self, x):
# Use tf.nn API to create the model
tf_strides = [1, strides[0], strides[1], 1]
op = tf.nn.conv2d(
x,
filters=tf.constant(
np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
dtype=tf.float32,
),
strides=tf_strides,
padding=padding,
dilations=dilation,
)
if activation:
op = tf.nn.relu(op)
return op

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, dtype=tf.float32)
)

def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model


def create_conv2d_tflite_relay_models(
ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype
):
"""
This method creates a conv2d TFLite layer and prepared TFLite model from it.
Converts that into the Relay module and params.
Returns TFLite model, Relay module and params.
"""
pytest.importorskip("tflite")
import tflite.Model

serialized_tflite_model = create_conv2d_tflite_model(
ifm_shape, kernel_shape, strides, dilation, padding, activation
)

tflite_model = tflite.Model.Model.GetRootAsModel(serialized_tflite_model, 0)

relay_module, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)

return serialized_tflite_model, relay_module, params