From cf5f31721ef489bb82bf4a7bb2d493713d48447d Mon Sep 17 00:00:00 2001 From: Aseem Wadhwa Date: Thu, 1 Oct 2020 12:58:16 -0700 Subject: [PATCH] Sync for 4.0b4 release (#950) * sync for 4.0b4 release * fix extra space character in build.sh and add a simple prediction test for smoke testing on older macs --- .gitlab-ci.yml | 77 +- README.md | 2 +- coremlpython/CoreMLPython.h | 10 +- coremlpython/CoreMLPython.mm | 33 +- coremltools/_deps/__init__.py | 4 +- coremltools/converters/_converters_entry.py | 3 +- .../converters/mil/backend/nn/op_mapping.py | 544 ++++++-- .../backend/nn/passes/commingle_loop_vars.py | 23 +- coremltools/converters/mil/converter.py | 4 +- .../frontend/tensorflow/basic_graph_ops.py | 2 + .../mil/frontend/tensorflow/load.py | 1 + .../converters/mil/frontend/tensorflow/ops.py | 128 +- .../mil/frontend/tensorflow/parse.py | 1 + .../backfill_make_list_elem_type.py | 22 +- .../tensorflow/test/test_composite_ops.py | 68 + .../tensorflow/test/test_custom_ops.py | 143 +- .../mil/frontend/tensorflow/test/test_ops.py | 1223 +++++++++-------- .../frontend/tensorflow/test/test_parse.py | 2 +- .../frontend/tensorflow/test/testing_utils.py | 12 +- .../tensorflow/tf_graph_pass/__init__.py | 1 + .../tf_graph_pass/insert_get_tuple.py | 2 +- .../tf_graph_pass/quantization_pass.py | 70 + .../mil/frontend/tensorflow2/load.py | 77 +- .../tensorflow2/test/test_v2_composite_ops.py | 30 + .../frontend/tensorflow2/test/test_v2_ops.py | 42 +- .../tensorflow2/test/test_v2_ops_tf_keras.py | 127 +- .../tensorflow2/test/testing_utils.py | 32 +- .../rewrite_control_flow_functions.py | 31 +- .../mil/frontend/torch/converter.py | 17 +- .../converters/mil/frontend/torch/load.py | 4 +- .../converters/mil/frontend/torch/ops.py | 253 ++-- .../mil/frontend/torch/test/test_api.py | 33 + .../torch/test/test_internal_graph.py | 51 +- .../mil/frontend/torch/test/test_torch_ops.py | 221 ++- .../mil/frontend/torch/test/testing_utils.py | 51 +- .../mil/frontend/torch/torch_op_registry.py | 4 +- coremltools/converters/mil/mil/block.py | 6 + coremltools/converters/mil/mil/builder.py | 5 +- coremltools/converters/mil/mil/operation.py | 17 +- .../converters/mil/mil/ops/defs/_utils.py | 9 +- .../mil/mil/ops/defs/control_flow.py | 101 +- .../converters/mil/mil/ops/defs/conv.py | 43 + .../mil/mil/ops/defs/elementwise_binary.py | 2 +- .../mil/mil/ops/defs/elementwise_unary.py | 49 +- .../mil/mil/ops/defs/image_resizing.py | 17 +- .../mil/mil/ops/defs/normalization.py | 95 +- .../converters/mil/mil/ops/defs/reduction.py | 13 +- .../mil/mil/ops/defs/tensor_operation.py | 27 +- .../mil/mil/ops/tests/test_const.py | 50 + .../mil/mil/ops/tests/test_control_flow.py | 113 +- .../converters/mil/mil/ops/tests/test_conv.py | 1 - .../mil/ops/tests/test_elementwise_unary.py | 101 +- .../mil/mil/ops/tests/test_image_resizing.py | 18 +- .../mil/mil/ops/tests/test_linear.py | 5 +- .../mil/mil/ops/tests/test_normalization.py | 233 +++- .../mil/mil/ops/tests/test_scatter_gather.py | 39 + .../mil/ops/tests/test_tensor_operation.py | 40 +- .../ops/tests/test_tensor_transformation.py | 100 +- .../mil/mil/ops/tests/testing_utils.py | 8 +- .../converters/mil/mil/passes/common_pass.py | 1 + .../mil/passes/loop_invariant_elimination.py | 27 +- .../mil/mil/passes/noop_elimination.py | 131 +- .../mil/mil/passes/pad_conv_connect.py | 131 ++ .../mil/mil/passes/reduce_transposes.py | 6 +- .../mil/mil/passes/test_noop_elimination.py | 225 +++ .../mil/mil/passes/test_pad_conv_pass.py | 126 ++ .../mil/passes/test_reduce_transposes_pass.py | 28 + 
.../converters/mil/mil/types/__init__.py | 1 + .../converters/mil/mil/types/type_mapping.py | 11 +- coremltools/converters/mil/mil/var.py | 39 +- coremltools/converters/mil/testing_reqs.py | 3 +- coremltools/converters/mil/testing_utils.py | 25 +- coremltools/converters/onnx/_operators.py | 6 +- coremltools/converters/onnx/_operators_nd.py | 8 +- coremltools/models/neural_network/builder.py | 19 +- coremltools/test/api/test_api_examples.py | 8 +- .../test/neural_network/test_nn_builder.py | 21 +- .../test_simple_nn_inference.py | 43 + coremltools/version.py | 2 +- mlmodel/format/Model.proto | 21 + mlmodel/src/NeuralNetworkBuffer.cpp | 212 +-- mlmodel/src/NeuralNetworkBuffer.hpp | 80 +- .../src/Validation/InterfaceValidators.cpp | 4 +- reqs/test_tf2.pip | 4 +- scripts/build.sh | 2 +- scripts/build_docs.sh | 2 +- setup.py | 1 + 87 files changed, 4056 insertions(+), 1571 deletions(-) create mode 100644 coremltools/converters/mil/frontend/tensorflow/test/test_composite_ops.py create mode 100644 coremltools/converters/mil/frontend/tensorflow/tf_graph_pass/quantization_pass.py create mode 100644 coremltools/converters/mil/frontend/tensorflow2/test/test_v2_composite_ops.py create mode 100644 coremltools/converters/mil/mil/ops/tests/test_const.py create mode 100644 coremltools/converters/mil/mil/passes/pad_conv_connect.py create mode 100644 coremltools/converters/mil/mil/passes/test_pad_conv_pass.py create mode 100644 coremltools/test/neural_network/test_simple_nn_inference.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 7bee8606c..13cfd9cb1 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,7 +22,7 @@ check_python_flake8: ######################################################################## # -# linux - Build & Test +# linux - Build # ######################################################################## @@ -44,9 +44,33 @@ build_wheel_linux_py27: variables: PYTHON: "2.7" +build_wheel_linux_py35: + <<: *build_linux + image: registry.gitlab.com/zach_nation/coremltools/build-image-ubuntu-14.04:1.0.3 + variables: + PYTHON: "3.5" + +build_wheel_linux_py36: + <<: *build_linux + image: registry.gitlab.com/zach_nation/coremltools/build-image-ubuntu-14.04:1.0.3 + variables: + PYTHON: "3.6" + +build_wheel_linux_py37: + <<: *build_linux + image: registry.gitlab.com/zach_nation/coremltools/build-image-ubuntu-14.04:1.0.3 + variables: + PYTHON: "3.7" + +build_wheel_linux_py38: + <<: *build_linux + image: registry.gitlab.com/zach_nation/coremltools/build-image-ubuntu-14.04:1.0.3 + variables: + PYTHON: "3.8" + ######################################################################### ## -## macOS - Build & Test +## macOS - Build ## ######################################################################### @@ -91,6 +115,12 @@ build_wheel_macos_py38: variables: PYTHON: "3.8" +######################################################################### +## +## macOS - Test +## +######################################################################### + .test_macos_pkg: &test_macos_pkg stage: test script: @@ -272,6 +302,45 @@ test_macos11_py37_mil: - coremltools/converters/mil/mil/**/*.{py} - coremltools/converters/mil/backend/**/*.{py} +######################################################################### +## +## macOS - Smoke Test on older versions +## +######################################################################### + +test_macos11_py38_coremltools_smoke_test: + <<: *test_macos_pkg + tags: + - macos11 + dependencies: + - build_wheel_macos_py38 + variables: + WHEEL_PATH: 
build/dist/*cp38*10_16*
+    TEST_PACKAGE: coremltools.test.neural_network.test_simple_nn_inference
+    PYTHON: "3.8"
+
+test_macos15_py38_coremltools_smoke_test:
+  <<: *test_macos_pkg
+  tags:
+    - macos10.15
+  dependencies:
+    - build_wheel_macos_py38
+  variables:
+    WHEEL_PATH: build/dist/*cp38*10_15*
+    TEST_PACKAGE: coremltools.test.neural_network.test_simple_nn_inference
+    PYTHON: "3.8"
+
+test_macos14_py38_coremltools_smoke_test:
+  <<: *test_macos_pkg
+  tags:
+    - macos10.14
+  dependencies:
+    - build_wheel_macos_py38
+  variables:
+    WHEEL_PATH: build/dist/*cp38*10_14*
+    TEST_PACKAGE: coremltools.test.neural_network.test_simple_nn_inference
+    PYTHON: "3.8"
+
 #########################################################################
 ##
 ## Make docs
@@ -314,6 +383,10 @@ collect_artifacts:
     echo "Collect artifacts (wheels and documentation)"
   dependencies:
     - build_wheel_linux_py27
+    - build_wheel_linux_py35
+    - build_wheel_linux_py36
+    - build_wheel_linux_py37
+    - build_wheel_linux_py38
     - build_wheel_macos_py27
     - build_wheel_macos_py35
     - build_wheel_macos_py36
diff --git a/README.md b/README.md
index 42951c5ab..a04a5ea8b 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ With coremltools, you can do the following:
 To get the latest version of coremltools:

 ```shell
-pip install coremltools==4.0b3
+pip install coremltools==4.0b4
 ```

 For the latest changes please see the [release notes](https://github.com/apple/coremltools/releases/).
diff --git a/coremlpython/CoreMLPython.h b/coremlpython/CoreMLPython.h
index 1ca240080..6643a23d3 100644
--- a/coremlpython/CoreMLPython.h
+++ b/coremlpython/CoreMLPython.h
@@ -51,10 +51,14 @@ namespace CoreML {
             std::unique_ptr<NNBuffer::NeuralNetworkBuffer> nnBuffer;

         public:
-            NeuralNetworkBufferInformation(const std::string& bufferFilePath, NNBuffer::bufferMode mode);
+            NeuralNetworkBufferInformation(const std::string& bufferFilePath, NNBuffer::BufferMode mode);
             ~NeuralNetworkBufferInformation();
-            std::vector<float> getBuffer(const u_int64_t offset);
-            u_int64_t addBuffer(const std::vector<float>& buffer);
+
+            template <class T>
+            u_int64_t addBuffer(const std::vector<T>& buffer);
+
+            template <class T>
+            std::vector<T> getBuffer(const u_int64_t offset);
         };
     }
 }
diff --git a/coremlpython/CoreMLPython.mm b/coremlpython/CoreMLPython.mm
index b334608d0..ed07b466a 100644
--- a/coremlpython/CoreMLPython.mm
+++ b/coremlpython/CoreMLPython.mm
@@ -134,7 +134,7 @@
 /*
  * NeuralNetworkBuffer - NeuralNetworkBuffer
  */
-NeuralNetworkBufferInformation::NeuralNetworkBufferInformation(const std::string &bufferFilePath, NNBuffer::bufferMode mode)
+NeuralNetworkBufferInformation::NeuralNetworkBufferInformation(const std::string &bufferFilePath, NNBuffer::BufferMode mode)
     : nnBuffer(std::make_unique<NNBuffer::NeuralNetworkBuffer>(bufferFilePath, mode))
 {
 }
@@ -149,18 +149,20 @@
  * Writes given buffer into file
  * Returns offset from the beginning of buffer
  */
-inline u_int64_t NeuralNetworkBufferInformation::addBuffer(const std::vector<float> &buffer) {
-    return nnBuffer->addBuffer(buffer);
+template <class T>
+inline u_int64_t NeuralNetworkBufferInformation::addBuffer(const std::vector<T>& buffer) {
+    return nnBuffer->AddBuffer(buffer);
 }

 /*
  * NeuralNetworkBufferInformation - getBuffer
  * Reads buffer from given offset and of given size and writes to data
  */
-inline std::vector<float> NeuralNetworkBufferInformation::getBuffer(const u_int64_t offset) {
+template <class T>
+inline std::vector<T> NeuralNetworkBufferInformation::getBuffer(const u_int64_t offset) {
     // TODO: Explore Pybind11 Opaque to pass vector by reference
-    std::vector<float> buffer;
-    nnBuffer->getBuffer(offset, buffer);
+    std::vector<T> buffer;
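+    // GetBuffer reads the blob stored at `offset` into `buffer`, resizing it
+    // in place; pybind11 then copies the returned vector into a Python list
+    // of the requested element type.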
+    nnBuffer->GetBuffer(offset, buffer);
     return buffer;
 }
@@ -180,13 +182,18 @@
         .def("print", &NeuralNetworkShapeInformation::print);

     py::class_<NeuralNetworkBufferInformation> netBuffer(m, "_NeuralNetworkBuffer");
-    netBuffer.def(py::init<const std::string&, NNBuffer::bufferMode>())
-        .def("add_buffer", &NeuralNetworkBufferInformation::addBuffer)
-        .def("get_buffer", &NeuralNetworkBufferInformation::getBuffer);
-    py::enum_<NNBuffer::bufferMode>(netBuffer, "mode")
-        .value("write", NNBuffer::bufferMode::write)
-        .value("append", NNBuffer::bufferMode::append)
-        .value("read", NNBuffer::bufferMode::read)
+    netBuffer.def(py::init<const std::string&, NNBuffer::BufferMode>())
+        .def("add_buffer_float", &NeuralNetworkBufferInformation::addBuffer<float>)
+        .def("add_buffer_int", &NeuralNetworkBufferInformation::addBuffer<int>)
+        .def("add_buffer_bool", &NeuralNetworkBufferInformation::addBuffer<bool>)
+        .def("get_buffer_float", &NeuralNetworkBufferInformation::getBuffer<float>)
+        .def("get_buffer_int", &NeuralNetworkBufferInformation::getBuffer<int>)
+        .def("get_buffer_bool", &NeuralNetworkBufferInformation::getBuffer<bool>);
+
+    py::enum_<NNBuffer::BufferMode>(netBuffer, "mode")
+        .value("write", NNBuffer::BufferMode::Write)
+        .value("append", NNBuffer::BufferMode::Append)
+        .value("read", NNBuffer::BufferMode::Read)
         .export_values();

     return m.ptr();
diff --git a/coremltools/_deps/__init__.py b/coremltools/_deps/__init__.py
index d7de145e9..3a5fe3885 100644
--- a/coremltools/_deps/__init__.py
+++ b/coremltools/_deps/__init__.py
@@ -86,10 +86,10 @@ def __get_sklearn_version(version):
 _HAS_TF = True
 _HAS_TF_1 = False
 _HAS_TF_2 = False
-_TF_1_MIN_VERSION = "1.0.0"
+_TF_1_MIN_VERSION = "1.12.0"
 _TF_1_MAX_VERSION = "1.15.0"
 _TF_2_MIN_VERSION = "2.1.0"
-_TF_2_MAX_VERSION = "2.2.0"
+_TF_2_MAX_VERSION = "2.3.0"
 try:
     import tensorflow
diff --git a/coremltools/converters/_converters_entry.py b/coremltools/converters/_converters_entry.py
index 9ffaa1238..4988b6cc4 100644
--- a/coremltools/converters/_converters_entry.py
+++ b/coremltools/converters/_converters_entry.py
@@ -327,7 +327,8 @@ def _flatten_list(_inputs):
     if convert_to == 'mil':
         return proto_spec  # Returns the MIL program

-    model = coremltools.models.MLModel(proto_spec, useCPUOnly=True)
+    useCPUOnly = kwargs.get("useCPUOnly", True)
+    model = coremltools.models.MLModel(proto_spec, useCPUOnly=useCPUOnly)

     if minimum_deployment_target is not None:
         check_deployment_compatibility(
diff --git a/coremltools/converters/mil/backend/nn/op_mapping.py b/coremltools/converters/mil/backend/nn/op_mapping.py
index 5cf42bd4a..284e722a2 100644
--- a/coremltools/converters/mil/backend/nn/op_mapping.py
+++ b/coremltools/converters/mil/backend/nn/op_mapping.py
@@ -16,6 +16,9 @@
 from coremltools.converters.mil.mil.types import np_dtype_to_py_type
 from coremltools.converters.mil.mil import types
 from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry
+from coremltools.models.neural_network.quantization_utils import (
+    _convert_array_to_nbit_quantized_bytes,
+)
 from tqdm import tqdm as _tqdm
 from .mil_to_nn_mapping_registry import *
@@ -189,14 +192,16 @@ def _try_convert_global_pool(const_context, builder, op, mode):
     if is_variadic(rank) or rank not in {4, 5}:
         return False
     keep_dims = op.keep_dims.val
+    if keep_dims is False:
+        return False
+
     if op.axes is not None:
         axes = op.axes.val
         axes = sorted([rank + axis if axis < 0 else axis for axis in axes])
-    if keep_dims is False:
-        return False
-    if rank == 4 and tuple(axes) != (2, 3):
+
+    if tuple(op.outputs[0].shape[:-2]) != tuple(op.inputs["x"].shape[:-2]):
         return False
-    if rank == 5 and tuple(axes) != (2, 3, 4):
+    if not all([s == 1 for s in op.outputs[0].shape[-2:]]):
         return False
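+    # At this point the reduction keeps the rank, leaves all leading dims
+    # unchanged, and collapses the last two (spatial) dims to 1, so it can be
+    # emitted as an NN global pooling layer.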
builder.add_pooling( name=op.name, @@ -329,6 +334,17 @@ def batch_norm(const_context, builder, op): channels = op.x.shape[1] gamma = _np.array([1.0] * channels) if op.gamma is None else op.gamma.val beta = _np.array([0.0] * channels) if op.beta is None else op.beta.val + + x_name = make_input(const_context, builder, op.x) + out_name = op.outputs[0].name + + if op.x.rank == 3: + x_name = op.name + "_expanded" + builder.add_expand_dims( + name=x_name, input_name=op.x.name, output_name=x_name, axes=[-1], + ) + out_name += "_batch_norm" + builder.add_batchnorm( name=op.name, channels=channels, @@ -336,13 +352,22 @@ def batch_norm(const_context, builder, op): beta=beta, mean=op.mean.val, variance=op.variance.val, - input_name=make_input(const_context, builder, op.x), - output_name=op.outputs[0].name, + input_name=x_name, + output_name=out_name, compute_mean_var=False, instance_normalization=False, epsilon=op.epsilon.val, ) + # Squeeze added `Width` dimension for 1d case + if op.x.rank == 3: + x_name = op.name + "_squeeze" + builder.add_squeeze( + name=x_name, + input_name=out_name, + output_name=op.outputs[0].name, + axes=[-1], + ) @register_mil_to_nn_mapping def const(const_context, builder, op): @@ -350,8 +375,7 @@ def const(const_context, builder, op): pass -@register_mil_to_nn_mapping -def conv(const_context, builder, op): +def conv_helper(const_context, builder, op): # v2 x: (n, C_in/groups, spatial_dims) x_name = make_input(const_context, builder, op.x) out_name = op.outputs[0].name @@ -440,6 +464,18 @@ def conv(const_context, builder, op): dilations = dilations + [1] strides = strides + [1] + if weights is not None and weights.dtype == 'uint8': + nbits = op.nbits.val + weights = _convert_array_to_nbit_quantized_bytes(weights.flatten(), nbits).tobytes() + quantization_type = op.quantization_type.val + quant_bias = op.quant_bias.val + quant_scale = op.quant_scale.val + else: + quantization_type = None + nbits = None + quant_bias = None + quant_scale = None + if is_conv1d or is_conv2d: builder.add_convolution( name=out_name, @@ -458,6 +494,10 @@ def conv(const_context, builder, op): input_name=input_names, output_name=out_name, dilation_factors=dilations, + quantization_type=quantization_type, + nbits=nbits, + quant_bias=quant_bias, + quant_scale=quant_scale, **pad # Python 2.7.16 will fail with a syntax error if a comma is included after `**pad` ) @@ -497,6 +537,15 @@ def conv(const_context, builder, op): **pad # Python 2.7.16 will fail with a syntax error if a comma is included after `**pad` ) +@register_mil_to_nn_mapping +def conv(const_context, builder, op): + conv_helper(const_context, builder, op) + + +@register_mil_to_nn_mapping() +def conv_quantized(const_context, builder, op): + conv_helper(const_context, builder, op) + @register_mil_to_nn_mapping def cumsum(const_context, builder, op): @@ -606,7 +655,60 @@ def _add_elementwise_binary( return add_const(const_context, builder, op.y.name, op.y.val) - if mode in ["add", "multiply", "max", "min", "ave"]: + if mode in {"add", "multiply", "subtract"} and op.x.rank <= 5 and op.y.rank <= 5: + shape_x = _np.array([1] * (5 - op.x.rank) + list(op.x.shape)) + shape_y = _np.array([1] * (5 - op.y.rank) + list(op.y.shape)) + + internal_x = internal_y = None + if all(shape_x == 1): + internal_y = op.x + internal_x = op.y + elif all(shape_y == 1): + internal_x = op.x + internal_y = op.y + + for indices in ([1], [2], [3, 4], [2, 3, 4], [1, 2, 3, 4]): + if indices == [1, 2, 3, 4] and mode == "multiply": + # INTERNAL_MUL_XYKN not implemented + 
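+                    # (A full-tensor broadcast multiply would need the missing
+                    # INTERNAL_MUL_XYKN kernel, so this pattern falls through
+                    # to add_multiply_broadcastable below.)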
continue + if all(shape_x[indices] == shape_y[indices]): + if all([True if i in indices else s == 1 for i, s in enumerate(shape_x)]): + internal_y = op.x + internal_x = op.y + break + if all([True if i in indices else s == 1 for i, s in enumerate(shape_y)]): + internal_x = op.x + internal_y = op.y + break + + if internal_x is not None: + if mode in {"add", "multiply"}: + builder.add_elementwise( + name=name, + input_names=make_input(const_context, builder, [internal_x, internal_y]), + output_name=output_name, + mode=mode.upper(), + ) + elif mode == "subtract": + builder.add_activation( + name="_neg_y_" + name, + input_name=make_input(const_context, builder, op.y), + output_name="_neg_y_" + output_name, + non_linearity="LINEAR", + params=[-1, 0]) + if op.x == internal_y: + internal_x = "_neg_y_" + output_name + else: + internal_y = "_neg_y_" + output_name + builder.add_elementwise( + name=name, + input_names=make_input(const_context, builder, [internal_x, internal_y]), + output_name=output_name, + mode="ADD", + ) + return + + if mode in {"add", "multiply", "max", "min"}: if op.x.shape == op.y.shape: builder.add_elementwise( name=name, @@ -618,11 +720,8 @@ def _add_elementwise_binary( add_func = getattr(builder, "add_" + mode + "_broadcastable", None) if add_func is None: - _logging.error( - "Elementwise binary broadcastable method {} not found in builder.".format( - mode - ) - ) + msg = "Element-wise binary method {} not found in builder." + raise ValueError(msg.format(mode)) add_func( name=name, @@ -643,7 +742,7 @@ def _add_elementwise_binary( add_func = getattr(builder, "add_" + mode, None) if add_func is None: - msg = "Elementwise binary method {} not found in builder." + msg = "Element-wise binary method {} not found in builder." raise ValueError(msg.format(mode)) add_func( @@ -803,7 +902,7 @@ def greater_equal(const_context, builder, op): @register_mil_to_nn_mapping def inverse(const_context, builder, op): - _add_elementwise_unary(const_context, builder, op, "inverse") + _add_elementwise_unary(const_context, builder, op, "inverse", epsilon=op.epsilon.val) @register_mil_to_nn_mapping @@ -818,7 +917,7 @@ def less_equal(const_context, builder, op): @register_mil_to_nn_mapping def log(const_context, builder, op): - _add_elementwise_unary(const_context, builder, op, "log") + _add_elementwise_unary(const_context, builder, op, "log", epsilon=op.epsilon.val) @register_mil_to_nn_mapping @@ -883,7 +982,7 @@ def round(const_context, builder, op): @register_mil_to_nn_mapping def rsqrt(const_context, builder, op): - _add_elementwise_unary(const_context, builder, op, "rsqrt") + _add_elementwise_unary(const_context, builder, op, "rsqrt", epsilon=op.epsilon.val) @register_mil_to_nn_mapping @@ -1251,7 +1350,7 @@ def topk(const_context, builder, op): builder.add_topk( name=op.name, input_names=make_input(const_context, builder, [op.x]), - output_names=[op.name + ":0", op.name + ":1"], + output_names=[output.name for output in op.outputs], k=op.k.val, axis=op.axis.val, use_bottom_k=op.ascending.val, @@ -1266,18 +1365,28 @@ def l2_pool(const_context, builder, op): @register_mil_to_nn_mapping def linear(const_context, builder, op): out_channels, in_channels = op.weight.shape - has_bias = op.bias.val is not None - builder.add_inner_product( - name=op.name, - W=op.weight.val, - b=op.bias.val, - input_channels=in_channels, - output_channels=out_channels, - has_bias=has_bias, - input_name=make_input(const_context, builder, op.x), - output_name=op.outputs[0].name, - ) - + if op.x.rank and op.x.rank <= 3 and 
op.x.rank > 0:
+        has_bias = op.bias.val is not None
+        builder.add_inner_product(
+            name=op.name,
+            W=op.weight.val,
+            b=op.bias.val,
+            input_channels=in_channels,
+            output_channels=out_channels,
+            has_bias=has_bias,
+            input_name=make_input(const_context, builder, op.x),
+            output_name=op.outputs[0].name,
+        )
+    else:
+        builder.add_batched_mat_mul(
+            name=op.name,
+            input_names=make_input(const_context, builder, [op.x]),
+            output_name=op.outputs[0].name,
+            W=op.weight.val.T,
+            bias=op.bias.val,
+            weight_matrix_rows=in_channels,
+            weight_matrix_columns=out_channels,
+        )

 @register_mil_to_nn_mapping
 def matmul(const_context, builder, op):
@@ -2135,9 +2244,6 @@ def pad(const_context, builder, op):
     mode = nn_mode_mapping.get(mode, mode)

     if pad is not None and op.x.rank > 1 and _np.all(pad[:-4] == 0):
-        # check and map mode
-        if mode == "symmetric":
-            mode = "reflection"
         pad = pad[-4:]
         left, right = pad[2], pad[3]
         top, bottom = pad[0], pad[1]
@@ -2166,7 +2272,6 @@
             input_names=make_input(const_context, builder, [op.x]),
             output_name=op.outputs[0].name,
             value=constant_val,
-            pad_to_given_output_size_mode=False,
             pad_amounts=pad,
         )
     else:
@@ -2203,18 +2308,19 @@ def l2_norm(const_context, builder, op):

 @register_mil_to_nn_mapping
 def layer_norm(const_context, builder, op):
-    input_shape_full = list(op.x.shape)
-    input_shape = [-1 if is_symbolic(s) else s for s in input_shape_full]
-    axes = None if op.axes is None else op.axes.val
-    normalized_shape = input_shape[-len(axes) :]
-    gamma = _np.ones(normalized_shape) if op.gamma is None else op.gamma.val
-    beta = _np.zeros(normalized_shape) if op.beta is None else op.beta.val
-    if (
-        len(input_shape) in [2, 3]
-        and len(axes) == 1
-        and axes[0] == len(input_shape) - 1
-        and input_shape.count(-1) < 2
-    ):
+
+    rank = op.x.rank
+    input_shape = [-1 if is_symbolic(dim) else dim for dim in list(op.x.shape)]
+    axes = list(range(op.x.rank)) if op.axes.val is None else op.axes.val
+    axes = [axis + rank if axis < 0 else axis for axis in axes]
+    epsilon = op.epsilon.val
+
+    if rank in [2, 3] and len(axes) == 1 and axes[0] == rank - 1 and input_shape.count(-1) < 2 and input_shape[-1] != -1:
+
+        normalized_shape = input_shape[-len(axes):]
+        gamma = _np.ones(normalized_shape) if op.gamma is None else op.gamma.val
+        beta = _np.zeros(normalized_shape) if op.beta is None else op.beta.val
+
         builder.add_reshape_static(
             name=op.name + "_reshape",
             input_name=make_input(const_context, builder, op.x),
@@ -2228,7 +2334,7 @@
             output_name=op.x.name + "_mvn",
             across_channels=True,
             normalize_variance=True,
-            epsilon=op.epsilon.val,
+            epsilon=epsilon,
         )

         builder.add_scale(
@@ -2249,14 +2355,100 @@
             output_shape=input_shape,
         )
     else:
-        builder.add_layer_normalization(
-            name=op.name,
+        mean_name = op.name + "_mean"
+        builder.add_reduce_mean(
+            name=mean_name,
             input_name=make_input(const_context, builder, op.x),
+            output_name=mean_name,
+            axes=axes,
+            keepdims=True,
+            reduce_all=False,
+        )
+
+        sub_mean_name = op.name + "_sub_mean"
+        builder.add_subtract_broadcastable(
+            name=sub_mean_name,
+            input_names=[op.x.name, mean_name],
+            output_name=sub_mean_name,
+        )
+
+        square_name = op.name + '_square'
+        builder.add_unary(
+            name=square_name,
+            input_name=sub_mean_name,
+            output_name=square_name,
+            mode="power",
+            alpha=2.0,
+        )
+
+        square_sum_name = op.name + '_square_sum'
+        builder.add_reduce_sum(
+            name=square_sum_name,
+            input_name=square_name,
+            output_name=square_sum_name,
+            axes=axes,
+            keepdims=True,
+            reduce_all=False,
+        )
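+        # The rest of this branch expands layer norm into primitive NN layers,
+        # computing
+        #   y = (x - mean) / sqrt(sum((x - mean)^2) / n + eps) * gamma + beta
+        # where n = prod(normalized_shape).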
+        normalized_shape = [op.x.shape[i] if i in axes else 1 for i in range(rank)]
+        if not any_symbolic(normalized_shape):
+            div_prod_name = op.name + '_div_constant'
+            add_const(const_context, builder, div_prod_name, _np.prod(normalized_shape))
+        else:
+            raise NotImplementedError("dynamic shape input not supported for layer_norm")
+
+        div_square_sum_name = op.name + '_div_square_sum'
+        builder.add_divide_broadcastable(
+            name=div_square_sum_name,
+            input_names=[square_sum_name, div_prod_name],
+            output_name=div_square_sum_name
+        )
+
+        epsilon_const_name = op.name + '_epsilon'
+        add_const(const_context, builder, epsilon_const_name, epsilon)
+        add_epsilon_name = op.name + '_add_epsilon'
+        builder.add_elementwise(
+            name=add_epsilon_name,
+            input_names=[div_square_sum_name, epsilon_const_name],
+            output_name=add_epsilon_name,
+            mode="ADD",
+        )
+
+        sqrt_name = op.name + '_sqrt'
+        builder.add_unary(
+            name=sqrt_name,
+            input_name=add_epsilon_name,
+            output_name=sqrt_name,
+            mode="sqrt",
+        )
+
+        div_name = op.name + '_divide'
+        builder.add_divide_broadcastable(
+            name=div_name,
+            input_names=[sub_mean_name, sqrt_name],
+            output_name=div_name
+        )
+
+        gamma = _np.ones(normalized_shape) if op.gamma is None else _np.reshape(op.gamma.val, normalized_shape)
+        beta = _np.zeros(normalized_shape) if op.beta is None else _np.reshape(op.beta.val, normalized_shape)
+
+        gamma_name = op.name + '_gamma'
+        beta_name = op.name + '_beta'
+        add_const(const_context, builder, gamma_name, gamma)
+        add_const(const_context, builder, beta_name, beta)
+
+        mul_name = op.name + '_mul'
+        builder.add_multiply_broadcastable(
+            name=mul_name,
+            input_names=[div_name, gamma_name],
+            output_name=mul_name,
+        )
+
+        builder.add_add_broadcastable(
+            name=op.name,
+            input_names=[mul_name, beta_name],
             output_name=op.outputs[0].name,
-            normalized_shape=normalized_shape,
-            gamma=gamma,
-            beta=beta,
-            eps=op.epsilon.val,
         )
@@ -2564,11 +2756,12 @@ def cond(const_context, builder, op):

 @register_mil_to_nn_mapping
 def while_loop(const_context, builder, op):
-    block = op.blocks[0]
+    cond_block = op.blocks[0]
+    body_block = op.blocks[1]

     # Assume that all loop vars aren't loop invariant (invariant loop vars
     # should've been optimized away in graph passes).
-    for v_in, vx_in in zip(op.loop_vars, block.inputs):
+    for v_in, vx_in in zip(op.loop_vars, cond_block.inputs):
         assert v_in.name != vx_in.name, "Loop invariant detected in {}".format(op)
         builder.add_copy(
             name=vx_in.name + "_input_copy",
@@ -2588,36 +2781,37 @@
         disable_rank5_shape_mapping=True,
         use_float_arraytype=True,
     )
-    cond_builder.rank_dict = {k.name: builder.rank_dict[k.name] for k in block.inputs}
+    cond_builder.rank_dict = {k.name: builder.rank_dict[k.name] for k in cond_block.inputs}
     convert_ops(
         const_context,
         cond_builder,
-        block.operations_for_vars(block.outputs[:1]),
-        block.outputs[:1],
+        cond_block.operations,
+        cond_block.outputs,
     )

-    loop_layer.loop.conditionVar = block.outputs[0].name
+    loop_layer.loop.conditionVar = cond_block.outputs[0].name

-    # while_loop body produces [cond_var] + loop_vars
+    # while_loop body produces loop_vars
     body_builder = neural_network.NeuralNetworkBuilder(
         nn_spec=loop_layer.loop.bodyNetwork,
         disable_rank5_shape_mapping=True,
         use_float_arraytype=True,
     )
-    body_builder.rank_dict = {k.name: builder.rank_dict[k.name] for k in block.inputs}
+    body_builder.rank_dict = {k.name: builder.rank_dict[k.name] for k in body_block.inputs}
     convert_ops(
         const_context,
         body_builder,
-        block.operations_for_vars(block.outputs[1:]),
-        block.outputs[1:],
+        body_block.operations,
+        body_block.outputs,
     )

     # Also assume all outputs are different from loop inputs (i.e., no loop
     # invariant.)
-    for vx_in, vx_out in zip(block.inputs, block.outputs[1:]):
+    for vx_in, vx_out in zip(body_block.inputs, body_block.outputs):
         if vx_in.name == vx_out.name:
             msg = "Loop invariant var {} detected in block {}"
-            _logging.warning(msg.format(vx_in.name, block.name))
+            _logging.warning(msg.format(vx_in.name, body_block.name))
             continue
         body_builder.add_copy(
             name=vx_in.name + "_ret_copy",
@@ -2854,21 +3048,27 @@ def custom_op(const_context, builder, op):
 def make_list(const_context, builder, op):
     # op.elem_shape is technically optional but SSA passes ensure it's
     # always there
+    # A symbolic value in op.elem_shape means a runtime-determined dimension.
+    # Ex: if op.elem_shape = [i0, 128], the first dimension is runtime-determined.
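+    # (Unlike .val, .sym_val is available even when some dimensions are
+    # symbolic, so the symbolic dims can be detected and defaulted below.)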
+    elem_shape = op.elem_shape.sym_val

-    # Set a default initial size
+    # Set an initial size
     size = op.init_length.val
-    if size is not None and has_static_elem_shape:
+
+    # set the dynamic dimensions to 1 for initialization
+    # Ex: op.elem_shape = [i0, 128] will result in [1, 128]
+    elem_shape = [1 if is_symbolic(dim) else dim for dim in elem_shape]
+
+    if size is not None:
         array_size = size if size > 0 else 1
-        array_shape = [array_size] + list(elem_shape)
+        array_shape = [array_size] + elem_shape
         add_const(
             const_context,
             builder,
             op.outputs[0].name,
             val=_np.zeros(array_shape, dtype="float"),
         )
-    elif has_static_elem_shape:
+    else:
         if len(elem_shape) > 0:
             node_es_name = op.name + "_element_shape"
             add_const(
@@ -2878,7 +3078,7 @@
                 val=_np.array(elem_shape, dtype="float"),
             )

-            # Concatenate list length (the input, should be a constant vector of size 1) with element shape
+            # Concatenate the list length (the input, a constant vector of size 1) with the element shape
             node_arr_shape_name = op.name + "_arr_shape"
             layer = builder.add_concat_nd(
                 name=node_arr_shape_name,
@@ -2887,20 +3087,33 @@
                 axis=0,
             )
         else:
-            node_es_name = op.init_length.name
+            raise ValueError("elem_shape should have length > 0.")

         builder.add_fill_dynamic(
             name=op.name, input_name=node_arr_shape_name, output_name=op.outputs[0].name
         )
-    else:
-        raise ValueError("TensorArray cannot determine element shapes statically")

-def _realloc_list(const_context, builder, ls_var, index_var):
+def _realloc_list(const_context, builder, ls_var, index_var, value_var, mode):
+    # We do two things in this helper function:
+    # (1)
+    # Check whether the tensorarray needs to be re-initialized:
+    # this happens when elem_shape is runtime-determined and the runtime shape
+    # is not equal to the default shape. Ex: elem_shape = [i0, 10] (initialized
+    # with [1, 10]) and at runtime we get [2, 10].

+    # (2)
     # If index_var >= len(ls_var), reallocate the array and copy over existing
     # contents
+
     # index_var: str or Var
     # ls_var: Var

+    # check if elem_shape is runtime-determined
+    elem_shape = tuple(value_var.shape)
+    has_dynamic_shape = any([is_symbolic(i) for i in elem_shape])
+
+    # get the fill shape of the tensor array:
+    # [length, elem_dim1, elem_dim2, ...]
     full_shape_name = ls_var.name + "_full_shape"
     builder.add_get_shape(
         name=full_shape_name,
@@ -2908,7 +3121,7 @@
         output_name=full_shape_name,
     )

-    # slice shape [length, elem_size1, ...] to get current length
+    # slice shape [length, elem_dim1, elem_dim2, ...] to get current length
     curr_len_name = ls_var.name + "_length"
     builder.add_slice_static(
         name=curr_len_name,
@@ -2921,6 +3134,107 @@
         strides=[1],
     )

+    value_elem_shape_name = ls_var.name + '_value_elem_shape'
+    if has_dynamic_shape:
+        # get elem_shape from value if it is runtime-determined;
+        # this is similar to what the backfill_make_list_elem_type tf graph pass does.
+        # if mode == "list_write", elem_shape equals value.shape;
+        # if mode == "list_scatter", elem_shape equals value.shape[1:]
+        if mode == "list_write":
+            builder.add_get_shape(
+                name=value_elem_shape_name,
+                input_name=make_input(const_context, builder, value_var),
+                output_name=value_elem_shape_name,
+            )
+        elif mode == "list_scatter":
+            raw_value_elem_shape_name = ls_var.name + '_raw_value_elem_shape'
+            builder.add_get_shape(
+                name=raw_value_elem_shape_name,
+                input_name=make_input(const_context, builder, value_var),
+                output_name=raw_value_elem_shape_name,
+            )

+            builder.add_slice_static(
+                name=value_elem_shape_name,
+                input_name=raw_value_elem_shape_name,
+                output_name=value_elem_shape_name,
+                begin_ids=[1],
+                end_ids=[-1],
+                begin_masks=[False],
+                end_masks=[True],
+                strides=[1],
+            )
+    else:
+        add_const(const_context, builder, value_elem_shape_name, _np.array(elem_shape))


+    # if elem_shape is runtime-determined, check whether the array needs to be re-initialized

+    if has_dynamic_shape:
+        # slice shape [length, elem_dim1, elem_dim2, ...] to get list elem_shape
+        curr_elem_shape_name = ls_var.name + "_ls_elem_shape"
+        builder.add_slice_static(
+            name=curr_elem_shape_name,
+            input_name=full_shape_name,
+            output_name=curr_elem_shape_name,
+            begin_ids=[1],
+            end_ids=[-1],
+            begin_masks=[False],
+            end_masks=[True],
+            strides=[1],
+        )

+        # test whether the runtime elem_shapes of the list and the value are equal
+        not_equal_name = ls_var.name + '_elem_shape_not_equal'
+        builder.add_not_equal(
+            name=not_equal_name,
+            input_names=[curr_elem_shape_name, value_elem_shape_name],
+            output_name=not_equal_name,
+        )

+        reduce_any_name = ls_var.name + '_reduce_any'
+        builder.add_reduce_sum(
+            name=reduce_any_name,
+            input_name=not_equal_name,
+            output_name=reduce_any_name,
+            axes=[0],
+            keepdims=False,
+            reduce_all=True,
+        )

+        # if the two elem_shapes differ, re-initialize the list with the elem_shape from the value
+        re_initialize_condition_name = ls_var.name + "_condition_re_initialize"
+        layer = builder.add_branch(name=re_initialize_condition_name, input_name=reduce_any_name)
+        true_builder = neural_network.NeuralNetworkBuilder(
+            nn_spec=layer.branch.ifBranch,
+            disable_rank5_shape_mapping=True,
+            use_float_arraytype=True,
+        )

+        re_initialize_shape_name = ls_var.name + "_re_initialize_shape"
+        true_builder.add_concat_nd(
+            name=re_initialize_shape_name,
+            input_names=[curr_len_name, value_elem_shape_name],
+            output_name=re_initialize_shape_name,
+            axis=0,
+        )

+        re_initialize_name = ls_var.name + "_re_initialize"
+        true_builder.add_fill_dynamic(
+            name=re_initialize_name,
+            input_name=re_initialize_shape_name,
+            output_name=re_initialize_name,
+            value=0.0,
+        )

+        true_builder.add_copy(
+            name=ls_var.name + "_re_initialize_assign",
+            input_name=re_initialize_name,
+            output_name=ls_var.name
+        )

+    # after re-initializing the list, check whether it needs to be reallocated:
+    # check if the index > curr_length
     is_growing_name = ls_var.name + "_is_growing"
     builder.add_greater_than(
         name=is_growing_name,
@@ -2929,9 +3243,6 @@
         use_greater_than_equal=True,
     )

-    elem_shape_name = ls_var.name + "_elem_shape"
-    add_const(const_context, builder, elem_shape_name, _np.array(ls_var.elem_shape))
-
     condition_name = ls_var.name + "_condition"
     layer = builder.add_branch(name=condition_name, input_name=is_growing_name)

@@ -2963,7 +3274,7 @@
     alloc_shape_name = ls_var.name +
"_alloc_shape" true_builder.add_concat_nd( name=alloc_shape_name, - input_names=[alloc_length_name1, elem_shape_name], + input_names=[alloc_length_name1, value_elem_shape_name], output_name=alloc_shape_name, axis=0, ) @@ -2994,10 +3305,10 @@ def _realloc_list(const_context, builder, ls_var, index_var): @register_mil_to_nn_mapping def list_write(const_context, builder, op): - _realloc_list(const_context, builder, op.ls, op.index) + _realloc_list(const_context, builder, op.ls, op.index, op.value, "list_write") # expanded_value_name is [1, op.value] - expanded_value_name = op.value.name + "_expanded" + expanded_value_name = op.ls.name + '_' + op.value.name + "_expanded" builder.add_expand_dims( name=expanded_value_name, input_name=make_input(const_context, builder, op.value), @@ -3034,9 +3345,7 @@ def list_scatter(const_context, builder, op): input_name=make_input(const_context, builder, op.indices), output_name=max_idx_name, ) - - _realloc_list(const_context, builder, op.ls, max_idx_name) - + _realloc_list(const_context, builder, op.ls, max_idx_name, op.value, "list_scatter") builder.add_scatter( name=op.name, input_names=make_input(const_context, builder, [op.ls, op.indices, op.value]), @@ -3087,61 +3396,8 @@ def list_length(const_context, builder, op): strides=[1], ) - @register_mil_to_nn_mapping -def isfinite(const_context, builder, op): - int_max = _np.iinfo(_np.int64).max - int_min = -_np.iinfo(_np.int64).max - 1 - const_name_max = op.name + "_const_name_max" - const_name_min = op.name + "_const_name_min" - if any_symbolic(op.x.shape): - shape_name = op.name + "_shape" - builder.add_get_shape( - name=shape_name, - input_name=make_input(const_context, builder, op.x), - output_name=shape_name, - ) - builder.add_fill_dynamic( - name=const_name_max, - input_name=shape_name, - output_name=const_name_max, - value=int_max, - ) - builder.add_fill_dynamic( - name=const_name_min, - input_name=shape_name, - output_name=const_name_min, - value=int_min, - ) - else: - shape = [1] if op.x.shape == () else op.x.shape - builder.add_fill_static( - name=const_name_max, - output_name=const_name_max, - output_shape=shape, - value=int_max, - ) - builder.add_fill_static( - name=const_name_min, - output_name=const_name_min, - output_shape=shape, - value=int_min, - ) - smaller_than_name = op.name + "_smaller" - greater_than_name = op.name + "_greater" - builder.add_less_than( - name=smaller_than_name, - input_names=make_input(const_context, builder, [op.x, const_name_max]), - output_name=smaller_than_name, - ) - builder.add_greater_than( - name=greater_than_name, - input_names=make_input(const_context, builder, [op.x, const_name_min]), - output_name=greater_than_name, - ) - builder.add_logical( - name=op.name, - input_names=[smaller_than_name, greater_than_name], - output_name=op.outputs[0].name, - mode="AND", - ) +def _const_symbolic(const_context, builder, op): + # do nothing + pass + diff --git a/coremltools/converters/mil/backend/nn/passes/commingle_loop_vars.py b/coremltools/converters/mil/backend/nn/passes/commingle_loop_vars.py index a0b2a7eb5..d3d5f565f 100644 --- a/coremltools/converters/mil/backend/nn/passes/commingle_loop_vars.py +++ b/coremltools/converters/mil/backend/nn/passes/commingle_loop_vars.py @@ -20,19 +20,18 @@ def commingle_loop_vars_block(block): if op.op_type != "while_loop": continue - block = op.blocks[0] + for block in op.blocks: + for v_out, vx_in in zip(op.outputs, block.inputs): + # Disable check as v_out is not visible in block. 
+ block.replace_uses_of_var_after_op( + anchor_op=None, + old_var=vx_in, + new_var=v_out, + no_check_var_visibility=True, + ) - for v_out, vx_in in zip(op.outputs, block.inputs): - # Disable check as v_out is not visible in block. - block.replace_uses_of_var_after_op( - anchor_op=None, - old_var=vx_in, - new_var=v_out, - no_check_var_visibility=True, - ) - - # replace block inputs - block._block_inputs = op.outputs + # replace block inputs + block._block_inputs = op.outputs @register_pass(namespace="nn_backend") diff --git a/coremltools/converters/mil/converter.py b/coremltools/converters/mil/converter.py index 0e566ff87..ccf0128b6 100644 --- a/coremltools/converters/mil/converter.py +++ b/coremltools/converters/mil/converter.py @@ -28,7 +28,7 @@ class MILFrontend: name = "mil" def __call__(self, model, *args, **kwargs): - if "inputs" in kwargs: + if "inputs" in kwargs and kwargs["inputs"] is not None: inputs = kwargs["inputs"] if not isinstance(inputs, (list, tuple)): raise ValueError( @@ -93,6 +93,7 @@ def __call__(self, *args, **kwargs): return load(*args, **kwargs) + @ConverterRegistry.frontend class CustomFrontend: name = "custom" @@ -129,6 +130,7 @@ def _convert( msg.format(convert_from, list(converter_registry.frontends.keys())) ) frontend_converter = frontend_converter_type() + prog = frontend_converter(model, **kwargs) common_pass(prog) diff --git a/coremltools/converters/mil/frontend/tensorflow/basic_graph_ops.py b/coremltools/converters/mil/frontend/tensorflow/basic_graph_ops.py index 85274ec66..f3bac886e 100644 --- a/coremltools/converters/mil/frontend/tensorflow/basic_graph_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/basic_graph_ops.py @@ -280,6 +280,8 @@ def visit(node): vis[node.name] = False elif "global" in node.op: vis[node.name] = False + elif "FakeQuant" in node.op: + vis[node.name] = False elif node.name in assume_variable_nodes: vis[node.name] = False else: diff --git a/coremltools/converters/mil/frontend/tensorflow/load.py b/coremltools/converters/mil/frontend/tensorflow/load.py index 6f1893efc..1394022f7 100644 --- a/coremltools/converters/mil/frontend/tensorflow/load.py +++ b/coremltools/converters/mil/frontend/tensorflow/load.py @@ -194,6 +194,7 @@ def _program_from_tf_ssa(self): delete_asserts, functionalize_loops, constant_propagation, + quantization_pass, cond_to_where, remove_variable_nodes, fuse_dilation_conv, diff --git a/coremltools/converters/mil/frontend/tensorflow/ops.py b/coremltools/converters/mil/frontend/tensorflow/ops.py index bccfe410b..4ad12a0a6 100644 --- a/coremltools/converters/mil/frontend/tensorflow/ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/ops.py @@ -13,7 +13,23 @@ from .convert_utils import convert_graph from .tf_op_registry import register_tf_op from coremltools.converters.mil.mil import types -from coremltools.converters.mil.mil.types.symbolic import is_symbolic +from coremltools.converters.mil.mil.types.symbolic import is_symbolic, any_symbolic + + +def _adjust_min_max(min, max, num_bits=8): + if (min <= max) and (max <= 0): + min = (min - max) * 1.0 + max = 0.0 + elif (min >= 0) and (max >= min): + max = (max - min) * 1.0 + min = 0.0 + else: + scale = (max - min) / (2 ** num_bits - 1) + min_adj = scale * round(min / scale) + max_adj = max + min_adj - min + min = min_adj + max = max_adj + return min, max def _is_scalar(type_): @@ -732,9 +748,66 @@ def DepthwiseConv2dNative(context, node): context.add(node.name, x) +@register_tf_op +def FakeQuantWithMinMaxVars(context, node): + w = context[node.inputs[0]] 
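+    # The lines below emulate TF's fake-quant: adjust (min, max) so that zero
+    # is exactly representable, then clip -> quantize -> round -> dequantize.
+    # E.g. num_bits=8, min=0.0, max=6.0: scale = 6/255, so w = 2.0 maps to
+    # round(2.0 * 255 / 6) = 85 and back to 85 * 6/255 = 2.0.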
+ min = context[node.inputs[1]].val + max = context[node.inputs[2]].val + num_bits = node.attr['num_bits'] + narrow_range = node.attr['narrow_range'] + + min, max = _adjust_min_max(min, max, num_bits) + + if narrow_range: + scale = (max-min) / (2 ** (num_bits) - 2) + bias = min - scale + else: + scale = (max-min) / (2 ** (num_bits) - 1) + bias = min + + w = mb.clip(x=w, alpha=min, beta=max) + w = mb.sub(x=w, y=bias) + x = mb.real_div(x=w, y=scale) + x = mb.round(x=x) + x = mb.mul(x=x, y=scale) + x = mb.add(x=x, y=bias, name=node.name) + context.add(node.name, x) + + @register_tf_op def Conv2D(context, node): - W_hwio = context[node.inputs[1]] + if "quantize" in node.attr: + quantization_type = "linear" + min = node.attr['quantize_min'] + max = node.attr['quantize_max'] + nbits = node.attr['num_bits'] + narrow_range = node.attr['narrow_range'] + + w = context[node.inputs[1]].sym_val + + min, max = _adjust_min_max(min, max, nbits) + + if narrow_range: + quant_scale = (max - min) / (2 ** (nbits) - 2) + quant_bias = (min-quant_scale) + else: + quant_scale = (max - min) / (2 ** (nbits) - 1) + quant_bias = (min) + + w_clip = _np.clip(w, min, max) + w_round = _np.round((w_clip-quant_bias)/quant_scale) + W_hwio = w_round.astype(_np.uint8) + + if not isinstance(quant_scale, list) and not isinstance(quant_scale, tuple): + quant_bias = [quant_bias] + quant_scale = [quant_scale] + else: + quantization_type = None + nbits = None + quant_scale = None + quant_bias = None + W_hwio = context[node.inputs[1]] + W_oihw = mb.transpose(x=W_hwio, perm=[3, 2, 0, 1]) data_format = node.attr.get("data_format", "NHWC") HW_dilations = _conv2d3d_strides_or_dilations( @@ -759,7 +832,21 @@ def Conv2D(context, node): pad_val = pad_val[4:] # Only the last op should have the same name as node.name conv_name = node.name + "x" if data_format == "NHWC" else node.name - if pad_type == "custom": + + if quantization_type is not None: + x = mb.conv_quantized( + x=x, + weight=W_oihw, + pad_type=pad_type, + strides=HW_strides, + dilations=HW_dilations, + name=conv_name, + quantization_type=quantization_type, + nbits=nbits, + quant_scale=quant_scale, + quant_bias=quant_bias, + ) + elif pad_type == "custom": x = mb.conv( x=x, weight=W_oihw, @@ -896,7 +983,6 @@ def EuclideanNorm(context, node): x = mb.reduce_l2_norm(x=x, axes=axes, keep_dims=keep_dims, name=node.name) context.add(node.name, x) - @register_tf_op def ExpandDims(context, node): x = context[node.inputs[0]] @@ -1198,7 +1284,7 @@ def Sqrt(context, node): @register_tf_op def Square(context, node): x = context[node.inputs[0]] - x = mb.square(x=x, name=node.name) + x = mb.mul(x=x, y=x, name=node.name) context.add(node.name, x) @@ -1435,6 +1521,8 @@ def MirrorPad(context, node): raise ValueError("TF `paddings` in Pad op must be const.") mode = node.attr.get("mode", "reflect").lower() + if mode == "symmetric": + mode = "reflect" in_rank = len(x.sym_type.get_shape()) if in_rank > 5 or in_rank < 2: @@ -1488,6 +1576,8 @@ def Pad(context, node): pad = context[node.inputs[1]] mode = node.attr.get("mode", "constant").lower() + if mode == "symmetric": + mode = "reflect" constant_val = node.attr.get("constant_val", 0.0) in_rank = len(x.sym_type.get_shape()) @@ -2299,7 +2389,10 @@ def Unpack(context, node): num_splits = node.attr.get("num", None) if num_splits is None: num_splits = x.shape[axis] - y = mb.split(x=x, num_splits=num_splits, axis=axis, name=node.name + "_unsqueezed") + if num_splits == 1: + y = [x] + else: + y = mb.split(x=x, num_splits=num_splits, axis=axis, name=node.name 
+ "_unsqueezed") output_vars = [] for i in range(num_splits): output_vars.append( @@ -2383,7 +2476,15 @@ def ZerosLike(context, node): @register_tf_op def IsFinite(context, node): x = context[node.inputs[0]] - x = mb.isfinite(x=x, name=node.name) + if any_symbolic(x.shape): + x_shape = mb.shape(x=x) + else: + x_shape = [1] if x.shape == () else x.shape + max_tensor = mb.fill(shape=x_shape, value=_np.finfo(_np.float32).max) + min_tensor = mb.fill(shape=x_shape, value=_np.finfo(_np.float32).min) + less_then = mb.less_equal(x=x, y=max_tensor) + greater_than = mb.greater_equal(x=x, y=min_tensor) + x = mb.logical_and(x=less_then, y=greater_than, name=node.name) context.add(node.name, x) @@ -2468,12 +2569,19 @@ def TensorArrayV3(context, node): elem_shape = node.attr.get("element_shape", None) size = node.attr.get("size", None) if size is None: - size = context[node.inputs[0]] + size = context[node.inputs[0]].val + if size is None: + msg = 'TensorArrayV3 size must be compile-time known. Got {}' + raise ValueError(msg.format(size)) + init_length = size + if init_length == 0: + # Dynamic list. Use 1 as init_length + init_length = 1 builtin_dtype = node.attr["dtype"] dtype_str = types.builtin_to_string(builtin_dtype) if elem_shape is not None: ls = mb.make_list( - init_length=size, + init_length=init_length, dtype=dtype_str, elem_shape=elem_shape, dynamic_length=dynamic_length, @@ -2481,7 +2589,7 @@ def TensorArrayV3(context, node): ) else: ls = mb.tf_make_list( - init_length=size, + init_length=init_length, dtype=dtype_str, dynamic_length=dynamic_length, name=node.name, diff --git a/coremltools/converters/mil/frontend/tensorflow/parse.py b/coremltools/converters/mil/frontend/tensorflow/parse.py index cad219946..ae6da409e 100644 --- a/coremltools/converters/mil/frontend/tensorflow/parse.py +++ b/coremltools/converters/mil/frontend/tensorflow/parse.py @@ -19,6 +19,7 @@ def parse_type(t): mapping = { + DataType.DT_HALF: types.fp16, DataType.DT_FLOAT: types.float, DataType.DT_DOUBLE: types.double, DataType.DT_INT32: types.int32, diff --git a/coremltools/converters/mil/frontend/tensorflow/ssa_passes/backfill_make_list_elem_type.py b/coremltools/converters/mil/frontend/tensorflow/ssa_passes/backfill_make_list_elem_type.py index d3bf48515..cc173a143 100644 --- a/coremltools/converters/mil/frontend/tensorflow/ssa_passes/backfill_make_list_elem_type.py +++ b/coremltools/converters/mil/frontend/tensorflow/ssa_passes/backfill_make_list_elem_type.py @@ -12,6 +12,7 @@ from coremltools.converters.mil.mil.passes.pass_registry import register_pass from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.var import ListVar @register_pass(namespace="tensorflow") @@ -53,11 +54,17 @@ def backfill_make_list_elem_type_block(block): raise ValueError(msg.format(op.name, op.enclosing_block)) with block: + # elem_shape can be runtime-detemrined, which cannot be inferred here at this point, + # so we add an internal _const_symbolic node to cover both static and dynamic cases. + elem_shape_var = mb._const_symbolic( + mode="immediate_value", + val=elem_type.get_shape(), + before_op=op, + ) new_list = mb.make_list( init_length=op.init_length, dynamic_length=op.dynamic_length, - # elem_shape cannot be symbolic by definition of list. 
-                elem_shape=elem_type.get_shape(),
+                elem_shape=elem_shape_var,
                 dtype=op.inputs["dtype"],
                 before_op=op,
                 name=op.name,
@@ -105,6 +112,17 @@ def infer_elem_type(list_var):
             block_var = block.inputs[idx]
             elem_type = infer_elem_type(block_var)
             if elem_type is not None:
+
+                def _set_types_for_block_inputs(block):
+                    block_var = block.inputs[idx]
+                    new_block_var = ListVar(name=block_var.name, elem_type=elem_type,
+                                            init_length=block_var.sym_type.T[1],
+                                            dynamic_length=block_var.sym_type.T[2])
+                    block._replace_var(block_var, new_block_var)
+
+                _set_types_for_block_inputs(o.blocks[0])  # condition block
+                _set_types_for_block_inputs(o.blocks[1])  # body block
+
                 return elem_type
             # otherwise continue to other block_var (a list_var can be
             # passed into while_loop twice).
diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_composite_ops.py b/coremltools/converters/mil/frontend/tensorflow/test/test_composite_ops.py
new file mode 100644
index 000000000..e04c7da14
--- /dev/null
+++ b/coremltools/converters/mil/frontend/tensorflow/test/test_composite_ops.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2020, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+
+from coremltools.converters.mil import testing_reqs
+from coremltools.converters.mil.testing_reqs import *
+from coremltools.converters.mil.frontend.tensorflow.test.testing_utils import (
+    make_tf_graph,
+    run_compare_tf,
+)
+
+# Custom Op imports
+from coremltools.converters.mil.frontend.tensorflow.tf_op_registry import register_tf_op
+
+# _TF_OPS_REGISTRY is imported to ensure that overriding an existing TF op does
+# not break testing of the default op.
+# pytest imports all the tests, so an overriding op would otherwise invoke the
+# custom op where it is not expected.
+# In a real use case, importing the following is not recommended!
+from coremltools.converters.mil.frontend.tensorflow.tf_op_registry import ( + _TF_OPS_REGISTRY, +) +from coremltools.converters.mil.mil.ops.defs._op_reqs import * +from coremltools.converters.mil.mil import Builder as mb + + +class TestCompositeOp: + @pytest.fixture(scope="class") + def create_custom_selu(self): + default_selu = _TF_OPS_REGISTRY.get("Selu", None) + + @register_tf_op(tf_alias=[], override=True) + def Selu(context, node): + x = context[node.inputs[0]] + alpha = 1.6732631921768188 + lamda = 1.0507010221481323 + out_elu = mb.elu(x=x, alpha=alpha) + out = mb.mul(x=out_elu, y=lamda, name=node.name) + context.add(node.name, out) + + yield + + _TF_OPS_REGISTRY["Selu"] = default_selu + + @pytest.mark.parametrize( + "use_cpu_only, backend, rank", + itertools.product([True, False], backends, list(range(1, 5))), + ) + @pytest.mark.usefixtures("create_custom_selu") + def test_selu(self, use_cpu_only, backend, rank): + input_shape = np.random.randint(low=1, high=6, size=rank) + + @make_tf_graph([input_shape]) + def build_model(x): + return tf.keras.activations.selu(x) + + model, inputs, outputs = build_model + + input_values = [random_gen(input_shape, -10.0, 10.0)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py b/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py index f9d415116..3e204baae 100644 --- a/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py @@ -136,74 +136,70 @@ def test_tf( ), "Incorrect parameter value k" -# TODO: rdar://61241807 ([MIL] [Polish] Custom layer operator documentation) -# Following logging is to ensure testing of TopK implemented in tf converter -# default path is testing with appropriate conversion function -# Log default tf topk -default_tf_topk = _TF_OPS_REGISTRY.get("TopKV2", None) - - -# Override TopK op with override=True flag -@register_tf_op(tf_alias=["TopKV2"], override=True) -def CustomTopK(context, node): - x = context[node.inputs[0]] - k = context[node.inputs[1]] - sorted = node.attr.get("sorted", False) - x = mb.custom_topk(x=x, k=k.val, axis=-1, sorted=sorted, name=node.name) - context.add(node.name, x) - +class TestCustomTopK: + @pytest.fixture(scope="class") + def create_custom_TopK(self): + # Defining SSA TopK Op + @register_op(doc_str="Custom TopK Layer", is_custom_op=True) + class custom_topk(Operation): + input_spec = InputSpec( + x=TensorInputType(), + k=IntInputType(const=True, default=1), + axis=IntInputType(const=True, default=-1), + sorted=BoolInputType(const=True, default=False), + ) -# Custom TF TopK -custom_tf_topk = _TF_OPS_REGISTRY["TopKV2"] + bindings = { + "class_name": "TopK", + "input_order": ["x"], + "parameters": ["k", "axis", "sorted"], + "description": "Top K Custom layer", + } + def __init__(self, **kwargs): + super(custom_topk, self).__init__(**kwargs) -def _set_tf_op(op_type, _op_func): - _TF_OPS_REGISTRY[op_type] = _op_func + def type_inference(self): + x_type = self.x.dtype + x_shape = self.x.shape + k = self.k.val + axis = self.axis.val + if not is_symbolic(x_shape[axis]) and k > x_shape[axis]: + msg = "K={} is greater than size of the given axis={}" + raise ValueError(msg.format(k, axis)) -class TestCustomTopK: - # Defining SSA TopK Op - @register_op(doc_str="Custom TopK Layer", 
is_custom_op=True) - class custom_topk(Operation): - input_spec = InputSpec( - x=TensorInputType(), - k=IntInputType(const=True, default=1), - axis=IntInputType(const=True, default=-1), - sorted=BoolInputType(const=True, default=False), - ) + ret_shape = list(x_shape) + ret_shape[axis] = k + return types.tensor(x_type, ret_shape), types.tensor(types.int32, ret_shape) - bindings = { - "class_name": "TopK", - "input_order": ["x"], - "parameters": ["k", "axis", "sorted"], - "description": "Top K Custom layer", - } + # TODO: rdar://61241807 ([MIL] [Polish] Custom layer operator documentation) + # Following logging is to ensure testing of TopK implemented in tf converter + # default path is testing with appropriate conversion function + # Log default tf topk + default_tf_topk = _TF_OPS_REGISTRY.get("TopKV2", None) - def __init__(self, **kwargs): - super(TestCustomTopK.custom_topk, self).__init__(**kwargs) + # Override TopK op with override=True flag + @register_tf_op(tf_alias=["TopKV2"], override=True) + def CustomTopK(context, node): + x = context[node.inputs[0]] + k = context[node.inputs[1]] + sorted = node.attr.get("sorted", False) + x = mb.custom_topk(x=x, k=k.val, axis=-1, sorted=sorted, name=node.name) + context.add(node.name, x) - def type_inference(self): - x_type = self.x.dtype - x_shape = self.x.shape - k = self.k.val - axis = self.axis.val + yield - if not is_symbolic(x_shape[axis]) and k > x_shape[axis]: - msg = "K={} is greater than size of the given axis={}" - raise ValueError(msg.format(k, axis)) + _TF_OPS_REGISTRY["TopKV2"] = default_tf_topk - ret_shape = list(x_shape) - ret_shape[axis] = k - return types.tensor(x_type, ret_shape), types.tensor(types.int32, ret_shape) @pytest.mark.skipif(not testing_reqs._HAS_TF_1, reason=MSG_TF1_NOT_FOUND) @pytest.mark.parametrize( "use_cpu_only, backend, rank, k", itertools.product([True], backends, [rank for rank in range(1, 4)], [1, 2],), ) + @pytest.mark.usefixtures("create_custom_TopK") def test_tf(self, use_cpu_only, backend, rank, k): - # Set TopK to custom TF function - _set_tf_op("TopKV2", custom_tf_topk) shape = np.random.randint(low=3, high=6, size=rank) with tf.Graph().as_default() as graph: x = tf.placeholder(tf.float32, shape=shape) @@ -226,49 +222,4 @@ def test_tf(self, use_cpu_only, backend, rank, k): assert ( True == layers[-1].custom.parameters["sorted"].boolValue ), "Incorrect parameter value for Sorted" - # Set TopK to default conversion function - _set_tf_op("TopKV2", default_tf_topk) - - -default_selu = _TF_OPS_REGISTRY.get("Selu", None) - - -@register_tf_op(tf_alias=[], override=True) -def Selu(context, node): - x = context[node.inputs[0]] - alpha = 1.6732631921768188 - lamda = 1.0507010221481323 - out_elu = mb.elu(x=x, alpha=alpha) - out = mb.mul(x=out_elu, y=lamda, name=node.name) - context.add(node.name, out) - - -composite_selu = _TF_OPS_REGISTRY["Selu"] - -class TestCompositeOp: - @pytest.mark.parametrize( - "use_cpu_only, backend, rank", - itertools.product([True, False], backends, list(range(1, 5))), - ) - def test_selu(self, use_cpu_only, backend, rank): - _set_tf_op("Selu", composite_selu) - input_shape = np.random.randint(low=1, high=6, size=rank) - - @make_tf_graph([input_shape]) - def build_model(x): - return tf.keras.activations.selu(x) - - model, inputs, outputs = build_model - - input_values = [random_gen(input_shape, -10.0, 10.0)] - input_dict = dict(zip(inputs, input_values)) - run_compare_tf( - model, - input_dict, - outputs, - use_cpu_only=use_cpu_only, - frontend_only=False, - backend=backend, - ) - 
_set_tf_op("Selu", default_selu) diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py b/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py index 9c7bba648..228f2baa3 100644 --- a/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py @@ -9,8 +9,12 @@ make_tf_graph, run_compare_tf, layer_counts, + load_tf_pb, + freeze_g, ) import math +import tempfile +import shutil backends = testing_reqs.backends @@ -21,10 +25,9 @@ class TestDebugging: """ TF converter does not handling debugging nodes, they are expected to be deleted by graph pass before op conversions + in Grappler graph pass: debug_stripper. """ - @pytest.mark.xfail( - reason=" test_v2_ops.py::TestDebugging::test_assert CI failure", run=False) @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends), @@ -34,7 +37,7 @@ def test_assert(self, use_cpu_only, backend): @make_tf_graph([input_shape]) def build_model(x): - tf.debugging.Assert(tf.reduce_all(tf.greater_equal(x, 0)), [x]) + tf.debugging.Assert(True, [x]) return tf.nn.relu(x) model, inputs, outputs = build_model @@ -205,8 +208,6 @@ def test(self, use_cpu_only, backend, rank, num_inputs): if use_cpu_only is False and rank == 5 and num_inputs == 9: # Failure on this specific parameter set return - if backend == "mil_proto" and rank == 0: - return input_shape = np.random.randint(low=1, high=4, size=rank) input_shapes = [input_shape[:] for _ in range(num_inputs)] @@ -576,7 +577,7 @@ def build_model(x): class TestCond: @pytest.mark.parametrize( - "use_cpu_only, backend", itertools.product([True, False], ["nn_proto"],) + "use_cpu_only, backend", itertools.product([True, False], backends,) ) def test_cond_naive(self, use_cpu_only, backend): @make_tf_graph([(1,), (1,)]) @@ -757,6 +758,23 @@ def build_model(x,y): use_cpu_only=use_cpu_only, backend=backend) + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends) + ) + def test_while_loop_no_entry(self, use_cpu_only, backend): + @make_tf_graph([(1,)]) + def build_model(x): + c = lambda i: tf.greater(tf.math.reduce_mean(i), 5) + b = lambda i: i - 1 + return tf.while_loop(c, b, [x]) + + model, inputs, outputs = build_model + input_values = [np.array([5], dtype=np.float32)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend + ) + @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends) ) @@ -1393,8 +1411,6 @@ def build_model_static_weights(x): if not any([True if d > 1 else False for d in dilations]): test_dynamic_W() - -@pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") class TestConvTranspose: @pytest.mark.parametrize( ",".join( @@ -1524,6 +1540,7 @@ def build_model(x): [(1, 1, 1)], # Dilation > 1 not supported by TF ), ) + @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") def test_conv3d_transpose( self, use_cpu_only, backend, padding, data_format, DHWkDkHkW, strides, dilations ): @@ -1579,319 +1596,185 @@ def build_model(x): class TestElementWiseBinary: @pytest.mark.parametrize( - "use_cpu_only, backend, rank, mode", + "use_cpu_only, backend, rank, tf_op", itertools.product( [True, False], backends, - [rank for rank in range(0, 4)], + [0, 1, 2, 3, 4], [ - "add", - "floor_div", - "floor_mod", - "maximum", - "minimum", - "mod", - 
"mul", - "pow", - "real_div", - "sub", - "squared_difference", + tf.math.add, + tf.math.floordiv, + tf.math.floormod, + tf.math.maximum, + tf.math.minimum, + tf.math.mod, + tf.math.multiply, + tf.math.pow, + tf.math.truediv, + tf.math.subtract, + tf.math.squared_difference, ], ), ) - def test_binary(self, use_cpu_only, backend, rank, mode): - # TODO: rdar://problem/63030405. Rank 0 tensor for MIL - if rank == 0 and backend == "mil_proto": - return - x_shape = list(np.random.randint(low=2, high=4, size=rank)) - y_shape = x_shape[:] - for i in range(rank): - if np.random.randint(4) == 0: - y_shape[i] = 1 - if np.random.randint(2) == 0: + def test_binary_math(self, use_cpu_only, backend, rank, tf_op): + x_shape = y_shape = list(np.random.randint(low=2, high=4, size=rank)) + + # test broadcasting + case = np.random.choice([0, 1, 2, 3]) + # 0 -> broadcast with one of the inputs is a 0-D tensor (scalar) + # 1 -> broadcast with same rank, some of dimensions are size 1 + # 2 -> broadcast with different rank, extra dimension with size 1 + # 3 -> no broadcast, same type for both inputs + if case == 0: + y_shape = [] + elif case == 1: + y_shape = [1 if np.random.randint(2) == 0 else d for d in y_shape] + elif case == 2: y_shape = [1] + y_shape + # randomly swap x and y + if np.random.randint(2) == 0: + x_shape, y_shape = y_shape, x_shape + # lower precision input data for non-CPU tests dtype = np.float32 if use_cpu_only else np.float16 - if mode == "add": - res = tf.math.add - x_val = random_gen(x_shape, -1000, 1000, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, -1000, 1000, dtype=dtype).astype(np.float32) - elif mode == "floor_div": - res = tf.math.floordiv - x_val = random_gen(x_shape, 0, 1000, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, 1, 20, dtype=dtype).astype(np.float32) - elif mode == "floor_mod": - res = tf.math.floormod - x_val = random_gen(x_shape, 0, 100, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, 1, 20, dtype=dtype).astype(np.float32) - elif mode == "maximum": - res = tf.math.maximum - x_val = random_gen(x_shape, -10, 10, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, -10, 10, dtype=dtype).astype(np.float32) - elif mode == "minimum": - res = tf.math.minimum - x_val = random_gen(x_shape, -10, 10, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, -10, 10, dtype=dtype).astype(np.float32) - elif mode == "mod": - res = tf.math.mod - x_val = random_gen(x_shape, 0, 1000, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, 1, 20, dtype=dtype).astype(np.float32) - elif mode == "mul": - res = tf.math.multiply + if tf_op in {tf.math.add, tf.math.subtract, tf.math.multiply}: x_val = random_gen(x_shape, -100, 100, dtype=dtype).astype(np.float32) y_val = random_gen(y_shape, -100, 100, dtype=dtype).astype(np.float32) - elif mode == "pow": - res = tf.math.pow - x_val = np.random.randint(low=-5, high=5, size=x_shape).astype(np.float32) - y_val = np.random.randint(low=-5, high=5, size=y_shape).astype(np.float32) - elif mode == "real_div": - res = tf.math.truediv - x_val = random_gen(x_shape, 0, 1000, dtype=dtype).astype(np.float32) + elif tf_op in {tf.math.truediv, tf.math.floordiv, tf.math.floormod, tf.math.mod}: + x_val = random_gen(x_shape, -100, 100, dtype=dtype).astype(np.float32) y_val = random_gen(y_shape, 1, 20, dtype=dtype).astype(np.float32) - elif mode == "sub": - res = tf.math.subtract - x_val = random_gen(x_shape, -1000, 1000, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, -1000, 1000, 
dtype=dtype).astype(np.float32) - elif mode == "squared_difference": - if backend == "mil_proto": - return # TODO - res = tf.math.squared_difference - x_val = random_gen(x_shape, -5, 5, dtype=dtype).astype(np.float32) - y_val = random_gen(y_shape, -5, 5, dtype=dtype).astype(np.float32) + elif tf_op in {tf.math.maximum, tf.math.minimum}: + x_val = random_gen(x_shape, -10, 10, dtype=dtype).astype(np.float32) + y_val = random_gen(y_shape, -10, 10, dtype=dtype).astype(np.float32) + elif tf_op in {tf.math.pow, tf.math.squared_difference}: + x_val = random_gen(x_shape, -5, 5, dtype=np.int).astype(np.float32) + y_val = random_gen(y_shape, -5, 5, dtype=np.int).astype(np.float32) + else: + raise NotImplementedError("input values needs to be defined") @make_tf_graph([x_shape, y_shape]) def build_model(x, y): - return res(x, y) + return tf_op(x, y) model, inputs, outputs = build_model input_values = [x_val, y_val] - - input_dict = dict(zip(inputs, input_values)) - - run_compare_tf( - model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend - ) - - @pytest.mark.parametrize( - "use_cpu_only, backend, rank", - itertools.product([True, False], backends, [rank for rank in range(0, 4)]), - ) - def test_equal(self, use_cpu_only, backend, rank): - if rank == 0 and backend == "mil_proto": - return - x_shape = list(np.random.randint(low=2, high=4, size=rank)) - y_shape = x_shape[:] - for i in range(rank): - if np.random.randint(4) == 0: - y_shape[i] = 1 - if np.random.randint(2) == 0: - y_shape = [1] + y_shape - - # lower precision input data for non-CPU tests - dtype = np.float32 if use_cpu_only else np.float16 - - @make_tf_graph([x_shape, y_shape]) - def build_model(x, y): - return tf.equal(x, y) - - model, inputs, outputs = build_model - - input_values = [ - random_gen(x_shape, -5, 3, dtype=dtype).astype(np.float32), - random_gen(y_shape, -5, 3, dtype=dtype).astype(np.float32), - ] - input_dict = dict(zip(inputs, input_values)) - run_compare_tf( model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend ) @pytest.mark.parametrize( - "use_cpu_only, backend, rank", - itertools.product([True, False], backends, [rank for rank in range(0, 4)]), - ) - def test_greater(self, use_cpu_only, backend, rank): - if rank == 0 and backend == "mil_proto": - return - x_shape = list(np.random.randint(low=2, high=4, size=rank)) - y_shape = x_shape[:] - for i in range(rank): - if np.random.randint(4) == 0: - y_shape[i] = 1 - if np.random.randint(2) == 0: - y_shape = [1] + y_shape - - # lower precision input data for non-CPU tests - dtype = np.float32 if use_cpu_only else np.float16 - - @make_tf_graph([x_shape, y_shape]) - def build_model(x, y): - return tf.greater(x, y) - - model, inputs, outputs = build_model - - input_values = [ - random_gen(x_shape, -5, 3, dtype=dtype).astype(np.float32), - random_gen(y_shape, -5, 3, dtype=dtype).astype(np.float32), - ] - - input_dict = dict(zip(inputs, input_values)) - - run_compare_tf( - model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend - ) - - @pytest.mark.parametrize( - "use_cpu_only, backend, rank", - itertools.product([True, False], backends, [rank for rank in range(0, 4)]), + "use_cpu_only, backend, rank, tf_op", + itertools.product( + [True, False], + backends, + [0, 1, 2, 3, 4], + [ + tf.equal, + tf.not_equal, + tf.greater, + tf.greater_equal, + tf.less, + tf.less_equal, + ], + ), ) - def test_greater_equal(self, use_cpu_only, backend, rank): - if rank == 0 and backend == "mil_proto": - return - x_shape = list(np.random.randint(low=2, 
high=4, size=rank)) - y_shape = x_shape[:] - for i in range(rank): - if np.random.randint(4) == 0: - y_shape[i] = 1 - if np.random.randint(2) == 0: + def test_binary_compare(self, use_cpu_only, backend, rank, tf_op): + x_shape = y_shape = list(np.random.randint(low=2, high=4, size=rank)) + + # test broadcasting + case = np.random.choice([0, 1, 2, 3]) + # 0 -> broadcast with one of the inputs is a 0-D tensor (scalar) + # 1 -> broadcast with same rank, some of dimensions are size 1 + # 2 -> broadcast with different rank, extra dimension with size 1 + # 3 -> no broadcast, same type for both inputs + if case == 0: + y_shape = [] + elif case == 1: + y_shape = [1 if np.random.randint(2) == 0 else d for d in y_shape] + elif case == 2: y_shape = [1] + y_shape - # lower precision input data for non-CPU tests - dtype = np.float32 if use_cpu_only else np.float16 - - @make_tf_graph([x_shape, y_shape]) - def build_model(x, y): - return tf.greater_equal(x, y) - - model, inputs, outputs = build_model - - input_values = [ - random_gen(x_shape, -5, 3, dtype=dtype).astype(np.float32), - random_gen(y_shape, -5, 3, dtype=dtype).astype(np.float32), - ] - - input_dict = dict(zip(inputs, input_values)) - - run_compare_tf( - model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend - ) - - @pytest.mark.parametrize( - "use_cpu_only, backend, rank", - itertools.product([True, False], backends, [rank for rank in range(0, 4)]), - ) - def test_less(self, use_cpu_only, backend, rank): - if rank == 0 and backend == "mil_proto": - return - x_shape = list(np.random.randint(low=2, high=4, size=rank)) - y_shape = x_shape[:] - for i in range(rank): - if np.random.randint(4) == 0: - y_shape[i] = 1 + # randomly swap x and y if np.random.randint(2) == 0: - y_shape = [1] + y_shape + x_shape, y_shape = y_shape, x_shape # lower precision input data for non-CPU tests dtype = np.float32 if use_cpu_only else np.float16 @make_tf_graph([x_shape, y_shape]) def build_model(x, y): - return tf.less(x, y) + return tf_op(x, y) model, inputs, outputs = build_model - input_values = [ random_gen(x_shape, -5, 3, dtype=dtype).astype(np.float32), random_gen(y_shape, -5, 3, dtype=dtype).astype(np.float32), ] - input_dict = dict(zip(inputs, input_values)) - run_compare_tf( model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend ) @pytest.mark.parametrize( - "use_cpu_only, backend, rank", - itertools.product([True, False], backends, [rank for rank in range(0, 4)]), + "use_cpu_only, backend, rank, tf_op", + itertools.product( + [True, False], + backends, + [0, 1, 2, 3, 4], + [ + tf.math.logical_and, + tf.math.logical_or, + tf.math.logical_xor, + ], + ), ) - def test_less_equal(self, use_cpu_only, backend, rank): - if rank == 0 and backend == "mil_proto": - return - x_shape = list(np.random.randint(low=2, high=4, size=rank)) - y_shape = x_shape[:] - for i in range(rank): - if np.random.randint(4) == 0: - y_shape[i] = 1 - if np.random.randint(2) == 0: + def test_binary_logical(self, use_cpu_only, backend, rank, tf_op): + x_shape = y_shape = list(np.random.randint(low=2, high=4, size=rank)) + + # test broadcasting + case = np.random.choice([0, 1, 2, 3]) + # 0 -> broadcast with one of the inputs is a 0-D tensor (scalar) + # 1 -> broadcast with same rank, some of dimensions are size 1 + # 2 -> broadcast with different rank, extra dimension with size 1 + # 3 -> no broadcast, same type for both inputs + if case == 0: + y_shape = [] + elif case == 1: + y_shape = [1 if np.random.randint(2) == 0 else d for d in y_shape] + elif case 
== 2: y_shape = [1] + y_shape - # lower precision input data for non-CPU tests - dtype = np.float32 if use_cpu_only else np.float16 - - @make_tf_graph([x_shape, y_shape]) - def build_model(x, y): - return tf.less_equal(x, y) - - model, inputs, outputs = build_model - - input_values = [ - random_gen(x_shape, -5, 3, dtype=dtype).astype(np.float32), - random_gen(y_shape, -5, 3, dtype=dtype).astype(np.float32), - ] - - input_dict = dict(zip(inputs, input_values)) - - run_compare_tf( - model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend - ) - - @pytest.mark.parametrize( - "use_cpu_only, backend, rank", - itertools.product([True, False], backends, [rank for rank in range(0, 4)]), - ) - def test_not_equal(self, use_cpu_only, backend, rank): - if rank == 0 and backend == "mil_proto": - return - x_shape = list(np.random.randint(low=2, high=4, size=rank)) - y_shape = x_shape[:] - for i in range(rank): - if np.random.randint(4) == 0: - y_shape[i] = 1 + # randomly swap x and y if np.random.randint(2) == 0: - y_shape = [1] + y_shape - - # lower precision input data for non-CPU tests - dtype = np.float32 if use_cpu_only else np.float16 + x_shape, y_shape = y_shape, x_shape - @make_tf_graph([x_shape, y_shape]) + @make_tf_graph([x_shape + [tf.bool], y_shape + [tf.bool]]) def build_model(x, y): - return tf.not_equal(x, y) + return tf_op(x, y) model, inputs, outputs = build_model - input_values = [ - random_gen(x_shape, -5, 3, dtype=dtype).astype(np.float32), - random_gen(y_shape, -5, 3, dtype=dtype).astype(np.float32), + random_gen(x_shape, 0, 2, dtype=np.int).astype(np.bool), + random_gen(y_shape, 0, 2, dtype=np.int).astype(np.bool), ] - input_dict = dict(zip(inputs, input_values)) - run_compare_tf( model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend ) class TestElementWiseUnary: + _FP16_UNSUPPORTED = {'acos', 'asin', 'atan', 'atanh', 'cosh', 'sinh'} + @pytest.mark.parametrize( "use_cpu_only, backend, rank, mode", itertools.product( [True, False], backends, - [rank for rank in range(1, 6)], + [1, 2, 5], [ "abs", "acos", @@ -1922,41 +1805,45 @@ class TestElementWiseUnary: ), ) def test_unary(self, use_cpu_only, backend, rank, mode): + if not use_cpu_only and mode in self._FP16_UNSUPPORTED: + return + atol, rtol = 1e-4, 1e-5 input_shape = np.random.randint(low=2, high=4, size=rank) if use_cpu_only: dtype = np.float32 + tf_dtype = tf.float32 else: dtype = np.float16 + tf_dtype = tf.float16 - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=input_shape) - if mode == "abs": - res = tf.abs(x) + def cast_func(x): + return tf.cast(x, dtype=tf.int32) + + def clip_func(x): + return tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=5.0) + + def _get_test(test_mode): + if test_mode == "abs": + res = tf.abs val = random_gen(input_shape, rand_min=-1, rand_max=1) - elif mode == "acos": - res = tf.acos(x) + elif test_mode == "acos": + res = tf.acos val = random_gen(input_shape, rand_min=-1, rand_max=1) - elif mode == "asin": - res = tf.asin(x) + elif test_mode == "asin": + res = tf.asin val = random_gen(input_shape, rand_min=-1, rand_max=1) - elif mode == "atan": - res = tf.atan(x) + elif test_mode == "atan": + res = tf.atan val = random_gen(input_shape, rand_min=-100, rand_max=100) - elif mode == "atanh": - if backend == "mil_proto": - # TODO - return - res = tf.atanh(x) + elif test_mode == "atanh": + res = tf.atanh val = random_gen(input_shape, rand_min=-0.9, rand_max=0.9) - elif mode == "cast": - if backend == "mil_proto": - # TODO [MIL] Add cast 
operation in MIL backend and enable tests - return + elif test_mode == "cast": eps_from_int = 0.0 if not use_cpu_only: eps_from_int = 0.1 - res = tf.cast(x, dtype=tf.int32) + res = cast_func val = random_gen( input_shape, rand_min=-10, @@ -1964,8 +1851,8 @@ def test_unary(self, use_cpu_only, backend, rank, mode): eps_from_int=eps_from_int, dtype=dtype, ) - elif mode == "ceil": - res = tf.ceil(x) + elif test_mode == "ceil": + res = tf.math.ceil eps_from_int = 0.0 if not use_cpu_only: eps_from_int = 0.1 @@ -1976,32 +1863,31 @@ def test_unary(self, use_cpu_only, backend, rank, mode): eps_from_int=eps_from_int, dtype=dtype, ) - elif mode == "clip": - if backend == "mil_proto": - # TODO - return - res = tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=5.0) + elif test_mode == "clip": + if use_cpu_only is False: + return None, None # clip does not support float16 + res = clip_func val = random_gen(input_shape, rand_min=-5, rand_max=10) - elif mode == "cos": - res = tf.cos(x) + elif test_mode == "cos": + res = tf.cos rand_range = 1000 if not use_cpu_only: rand_range = 10 val = random_gen(input_shape, rand_min=-rand_range, rand_max=rand_range) - elif mode == "cosh": - res = tf.cosh(x) + elif test_mode == "cosh": + res = tf.cosh val = random_gen(input_shape, rand_min=-4, rand_max=4) - elif mode == "erf": - res = tf.math.erf(x) + elif test_mode == "erf": + res = tf.math.erf val = random_gen(input_shape, rand_min=1, rand_max=6) - elif mode == "exp": + elif test_mode == "exp": if not use_cpu_only: # We skip GPU here, since exp(1) already differs in backend. - return - res = tf.exp(x) + return None, None + res = tf.exp val = random_gen(input_shape, rand_min=-4, rand_max=20) - elif mode == "floor": - res = tf.floor(x) + elif test_mode == "floor": + res = tf.floor eps_from_int = 0.0 if not use_cpu_only: eps_from_int = 0.1 @@ -2012,63 +1898,75 @@ def test_unary(self, use_cpu_only, backend, rank, mode): eps_from_int=eps_from_int, dtype=dtype, ) - elif mode == "inverse": - if backend == "mil_proto": - return # TODO - res = tf.reciprocal(x) + elif test_mode == "inverse": + res = tf.math.reciprocal val = random_gen(input_shape, rand_min=0.1, rand_max=10) - elif mode == "log": - res = tf.log(x) + elif test_mode == "log": + res = tf.math.log val = random_gen(input_shape, rand_min=0.2, rand_max=1000) - elif mode == "negative": - if backend == "mil_proto": - return # TODO - res = tf.math.negative(x) + elif test_mode == "negative": + res = tf.math.negative val = random_gen(input_shape, rand_min=-100.0, rand_max=100.0) - elif mode == "round": - res = tf.round(x) + elif test_mode == "round": + res = tf.round val = random_gen( input_shape, rand_min=-1000, rand_max=1000, dtype=dtype ) - elif mode == "rsqrt": - res = tf.rsqrt(x) + elif test_mode == "rsqrt": + res = tf.math.rsqrt val = random_gen(input_shape, rand_min=0.5, rand_max=1000) - elif mode == "sign": - res = tf.sign(x) + elif test_mode == "sign": + res = tf.sign val = random_gen(input_shape, rand_min=-5, rand_max=5) - elif mode == "sin": - res = tf.sin(x) + elif test_mode == "sin": + res = tf.sin rand_range = 1000 if not use_cpu_only: rand_range = 10 val = random_gen(input_shape, rand_min=-rand_range, rand_max=rand_range) - elif mode == "sinh": - res = tf.sinh(x) + elif test_mode == "sinh": + res = tf.sinh val = random_gen(input_shape, rand_min=-10, rand_max=10) - elif mode == "sqrt": - res = tf.sqrt(x) + elif test_mode == "sqrt": + res = tf.sqrt val = random_gen(input_shape, rand_min=0.5, rand_max=1000) - elif mode == "square": - res = tf.math.square(x) 
+ elif test_mode == "square": + res = tf.math.square val = random_gen(input_shape, rand_min=-5, rand_max=5) - atol, rtol = 1e-2, 1e-3 - elif mode == "tan": - res = tf.tan(x) + elif test_mode == "tan": + res = tf.tan val = random_gen(input_shape, rand_min=-1000, rand_max=1000) - elif mode == "tanh": - res = tf.tanh(x) + elif test_mode == "tanh": + res = tf.tanh val = random_gen(input_shape, rand_min=-1000, rand_max=1000) - run_compare_tf( - graph, - {x: val}, - res, - use_cpu_only=use_cpu_only, - frontend_only=False, - backend=backend, - atol=atol, - rtol=rtol, - ) + return res, val + + func, input_val = _get_test(mode) + if func is None: + return + + input_type = list(input_shape) + [tf_dtype] + @make_tf_graph([input_type]) + def build_model(x): + return func(x) + + model, inputs, outputs = build_model + + input_dict = dict(zip(inputs, [input_val.astype(dtype)])) + + if mode == "inverse" or mode == "rsqrt": + atol, rtol = 1e-2, 1e-3 + + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + atol=atol, + rtol=rtol + ) class TestImageResizing: @@ -2094,21 +1992,26 @@ def test_resize_bilinear( ): if half_pixel_centers and align_corners: return - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=input_shape) - ref = tf.raw_ops.ResizeBilinear( - images=x, - size=target_shape, - half_pixel_centers=half_pixel_centers, - align_corners=align_corners, - ) - run_compare_tf( - graph, - {x: random_gen(input_shape, rand_min=-100, rand_max=100)}, - ref, - use_cpu_only=use_cpu_only, - backend=backend, - ) + + @make_tf_graph([input_shape]) + def build_model(x): + return tf.raw_ops.ResizeBilinear( + images=x, + size=target_shape, + half_pixel_centers=half_pixel_centers, + align_corners=align_corners, + ) + + model, inputs, outputs = build_model + input_values = [random_gen(input_shape, -100, 100)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) @pytest.mark.parametrize( "use_cpu_only, backend, input_shape, upsample_factor, data_format", @@ -2130,18 +2033,23 @@ def test_upsampling_2d( input_shape[3], input_shape[1], ) - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=input_shape) - ref = tf.keras.layers.UpSampling2D( - size=upsample_factor, data_format=data_format, interpolation="nearest" - )(x) - run_compare_tf( - graph, - {x: random_gen(input_shape, rand_min=-100, rand_max=100)}, - ref, - use_cpu_only=use_cpu_only, - backend=backend, - ) + + @make_tf_graph([input_shape]) + def build_model(x): + return tf.keras.layers.UpSampling2D( + size=upsample_factor, data_format=data_format, interpolation="nearest" + )(x) + + model, inputs, outputs = build_model + input_values = [random_gen(input_shape, -100, 100)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) @pytest.mark.parametrize( "use_cpu_only, backend, input_shape, num_of_crops, crop_size, method, dynamic", @@ -2165,49 +2073,54 @@ def test_crop_and_resize( method, dynamic, ): - input = np.random.randn(*input_shape) - boxes = np.random.uniform(size=(num_of_crops, 4)) + input = np.random.randn(*input_shape).astype(np.float32) + boxes = np.random.uniform(size=(num_of_crops, 4)).astype(np.float32) box_indices = np.random.randint( size=(num_of_crops,), low=0, high=input_shape[0] - ) + ).astype(np.int32) def test_static(): - with tf.Graph().as_default() as graph: 
- x = tf.placeholder(tf.float32, shape=input_shape) - output = tf.raw_ops.CropAndResize( + @make_tf_graph([input_shape]) + def build_model(x): + return tf.raw_ops.CropAndResize( image=x, boxes=boxes, box_ind=box_indices, crop_size=crop_size, method=method, ) - run_compare_tf( - graph, - {x: input}, - output, - use_cpu_only=use_cpu_only, - backend=backend, - ) + + model, inputs, outputs = build_model + input_values = [input] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) def test_dynamic(): - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=input_shape) - boxes_pl = tf.placeholder(tf.float32, shape=boxes.shape) - box_indices_pl = tf.placeholder(tf.int32, shape=box_indices.shape) - output = tf.raw_ops.CropAndResize( + @make_tf_graph([input_shape, boxes.shape, list(box_indices.shape) + [tf.int32]]) + def build_model(x, boxes_pl, box_indices_pl): + return tf.raw_ops.CropAndResize( image=x, boxes=boxes_pl, box_ind=box_indices_pl, crop_size=crop_size, method=method, ) - run_compare_tf( - graph, - {x: input, boxes_pl: boxes, box_indices_pl: box_indices}, - output, - use_cpu_only=use_cpu_only, - backend=backend, - ) + model, inputs, outputs = build_model + input_values = [input, boxes, box_indices] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) test_dynamic() if dynamic else test_static() @@ -2232,25 +2145,33 @@ def test_extract_patches( # but there seems to have a bug in crop_resize when using GPU and batch_size > 1. # We should test batch_size > 1 after the issue is fixed. # - input = np.random.rand(1, height, width, 128) + input = np.random.rand(1, height, width, 128).astype(np.float32) if padding == "VALID": size_h = min(sizes[0], height) size_w = min(sizes[1], width) else: size_h = sizes[0] size_w = sizes[1] - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=input.shape) - output = tf.extract_image_patches( + + @make_tf_graph([input.shape]) + def build_model(x): + return tf.compat.v1.image.extract_image_patches( images=x, ksizes=[1, size_h, size_w, 1], strides=[1, strides[0], strides[1], 1], rates=[1, 1, 1, 1], padding=padding, ) - run_compare_tf( - graph, {x: input}, output, use_cpu_only=use_cpu_only, backend=backend - ) + model, inputs, outputs = build_model + input_values = [input] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) class TestLinear: @@ -2359,8 +2280,8 @@ def build_model(x, m, v, o, s): outputs, use_cpu_only=use_cpu_only, backend=backend, - atol=1e-2, - rtol=1e-3, + atol=.2, + rtol=1e-4, ) @pytest.mark.parametrize( @@ -2444,8 +2365,8 @@ def build_model(x, m, v, b): outputs, use_cpu_only=use_cpu_only, backend=backend, - atol=1e-2, - rtol=1e-3, + atol=0.2, + rtol=1e-4, ) @@ -2492,7 +2413,7 @@ def build_model(x): rtol=1e-3, ) - @pytest.mark.skip(reason=" Specific failure on CI") +class TestL2Normalization: @pytest.mark.parametrize( "use_cpu_only, backend, rank, axes, epsilon", itertools.product( @@ -2522,10 +2443,11 @@ def build_model(x): outputs, use_cpu_only=use_cpu_only, backend=backend, - atol=1e-2, - rtol=1e-3, + atol=0.05, + rtol=1e-4, ) +class TestLocalResponseNormalization: @pytest.mark.parametrize( "use_cpu_only, backend, size, alpha, beta, k", itertools.product( @@ -2756,7 +2678,7 @@ class TestRandom: ), ) 
def test_random_binomial(self, use_cpu_only, backend, size, rank, constant): - if not constant and backend == "mil_proto": + if not constant and backend != "nn_proto": return # TODO: rdar://61948178 (MIL backend Random op does not support dynamic input shape) shape = np.random.randint(low=1, high=4, size=rank).astype(np.int32) @@ -2809,7 +2731,7 @@ def test_random_categorical(self, use_cpu_only, backend, size): ), ) def test_random_normal(self, use_cpu_only, backend, mean, rank, constant): - if not constant and backend == "mil_proto": + if not constant and backend != "nn_proto": return # TODO: rdar://61948178 (MIL backend Random op does not support dynamic input shape) shape = np.random.randint(low=1, high=4, size=rank).astype(np.int32) @@ -2843,7 +2765,7 @@ def test_random_normal(self, use_cpu_only, backend, mean, rank, constant): ), ) def test_keras_random_normal(self, use_cpu_only, backend, mean, rank, constant): - if not constant and backend == "mil_proto": + if not constant and backend != "nn_proto": return # TODO: rdar://61948178 (MIL backend Random op does not support dynamic input shape) shape = np.random.randint(low=1, high=4, size=rank).astype(np.int32) @@ -2881,7 +2803,7 @@ def test_keras_random_normal(self, use_cpu_only, backend, mean, rank, constant): ), ) def test_random_uniform(self, use_cpu_only, backend, low, high, rank, constant): - if not constant and backend == "mil_proto": + if not constant and backend != "nn_proto": return # TODO: rdar://61948178 (MIL backend Random op does not support dynamic input shape) shape = np.random.randint(low=1, high=4, size=rank).astype(np.int32) @@ -2918,7 +2840,7 @@ def test_random_uniform(self, use_cpu_only, backend, low, high, rank, constant): def test_keras_random_uniform( self, use_cpu_only, backend, low, high, rank, constant ): - if not constant and backend == "mil_proto": + if not constant and backend != "nn_proto": return # TODO: rdar://61948178 (MIL backend Random op does not support dynamic input shape) shape = np.random.randint(low=1, high=4, size=rank).astype(np.int32) with tf.Graph().as_default() as graph: @@ -2958,15 +2880,17 @@ class TestReduction: (2, (-1, 0)), (3, (1, -3)), (3, (-2,)), + (3, (-3, -2, -1)), (4, (0, 1, 2)), (4, (-2, -1, 0)), (4, (1, -2)), (5, (-3, -1)), + (5, (-2, -1)), + (5, (-3, -2, -1)), (5, (0, -1, 1, -2)), (3, None), (5, None), (3, 1), - (5, -1), ], [True, False], [ @@ -2996,58 +2920,75 @@ def parse_axes(axes): return axes def test_tf_argmax(): - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=shape) - ref = tf.math.argmax(x, axis=parse_axes(axes)) - run_compare_tf( - graph, - {x: random_gen(shape=shape, rand_min=-5.0, rand_max=5.0)}, - ref, - use_cpu_only=use_cpu_only, - backend=backend, - ) + @make_tf_graph([shape]) + def build_model(x): + return tf.math.argmax(x, axis=parse_axes(axes)) + + model, inputs, outputs = build_model + input_values = [random_gen(shape, rand_min=-5.0, rand_max=5.0)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) def test_tf_argmin(): - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=shape) - ref = tf.math.argmin(x, axis=parse_axes(axes)) - run_compare_tf( - graph, - {x: random_gen(shape=shape, rand_min=-5.0, rand_max=5.0)}, - ref, - use_cpu_only=use_cpu_only, - backend=backend, - ) + @make_tf_graph([shape]) + def build_model(x): + return tf.math.argmin(x, axis=parse_axes(axes)) + + model, inputs, outputs = 
build_model + input_values = [random_gen(shape, rand_min=-5.0, rand_max=5.0)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) def test_tf_reduction(): if isinstance(axes, list) and axes and len(axes) == rank and not keep_dims: return # TODO MIL: Add rank 0 and dim size 0 related tests for every op - if tf_op in {tf.reduce_any, tf.reduce_all, tf.reduce_logsumexp} and backend != "nn_proto": # Remove backend constraint, rdar://66610973 + if tf_op in {tf.reduce_any, tf.reduce_all, tf.reduce_logsumexp}: # Remove constraint, rdar://66610973 return - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=shape) - x_val = random_gen(shape=shape, rand_min=-5.0, rand_max=5.0) - if tf_op in {tf.reduce_all, tf.reduce_any}: - x = tf.placeholder(tf.bool, shape=shape) - x_val = np.random.randint(low=0, high=2, size=shape).astype( - np.float32 - ) - elif tf_op in {tf.math.reduce_euclidean_norm}: - x_val = random_gen(shape=shape, rand_min=0.0, rand_max=10.0) - elif tf_op in {tf.reduce_prod}: - x_val = random_gen(shape=shape, rand_min=1.0, rand_max=1.5) - elif tf_op in {tf.reduce_logsumexp}: - x_val = random_gen(shape=shape, rand_min=-5, rand_max=5) + input_type = list(shape) + x_val = random_gen(shape=shape, rand_min=-5.0, rand_max=5.0) + if tf_op in {tf.reduce_all, tf.reduce_any}: + input_type += [tf.bool] + x_val = np.random.randint(low=0, high=2, size=shape).astype( + np.float32 + ) + elif tf_op in {tf.math.reduce_euclidean_norm}: + x_val = random_gen(shape=shape, rand_min=0.0, rand_max=10.0) + elif tf_op in {tf.reduce_prod}: + x_val = random_gen(shape=shape, rand_min=1.0, rand_max=1.5) + elif tf_op in {tf.reduce_logsumexp}: + x_val = random_gen(shape=shape, rand_min=-5, rand_max=5) + + @make_tf_graph([input_type]) + def build_model(x): ref = tf_op(x, axis=axes, keepdims=keep_dims) - if tf_op == tf.reduce_any: ref = tf.cast(ref, tf.float32) + return ref - run_compare_tf( - graph, {x: x_val}, ref, use_cpu_only=use_cpu_only, backend=backend - ) + model, inputs, outputs = build_model + input_values = [random_gen(shape, rand_min=-5.0, rand_max=5.0)] + input_dict = dict(zip(inputs, [x_val])) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) if tf_op in {tf.math.argmax}: test_tf_argmax() @@ -3651,6 +3592,106 @@ def build_model(x): backend=backend) +class TestFakeQuant: + @pytest.mark.parametrize( + "num_bits, weight_boundaries, use_cpu_only, backend", + itertools.product( + [bits for bits in range(2, 9)], # TensorFlow does not support 1-bit quantization + [(-10, 0), (0, 10), (-0.01, 0.02), (-0.001, 0.003), (-101, 100)], + [True, False], + backends, + ), + ) + def test_fake_quant_weight_quantization_with_conv(self, num_bits, weight_boundaries, use_cpu_only, backend): + tf.reset_default_graph() + filter_width = 1 + filter_height = 1 + spatial_size = 2 + input_channels = 3 + output_channels = 1 + input_tensor = tf.placeholder(tf.float32, [1, spatial_size, spatial_size, input_channels], name='input') + output_tensor = tf.placeholder(tf.float32, [1, spatial_size, spatial_size, output_channels], name='output') + kernel_in = random_gen((filter_width, filter_height), weight_boundaries[0], weight_boundaries[1]) + init = tf.constant_initializer(kernel_in) + + def model(x): + with tf.compat.v1.variable_scope('quantized_model'): + x = tf.layers.conv2d(x, filters=3, kernel_size=1, strides=1, kernel_initializer=init) + return x + + with 
tf.compat.v1.variable_scope('quantize'): + output = model(x=input_tensor) + tf.contrib.quantize.experimental_create_training_graph(quant_delay=0, weight_bits=num_bits, + activation_bits=num_bits) + loss = tf.losses.mean_squared_error(labels=input_tensor, predictions=output) + saver = tf.train.Saver() + + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(update_ops): + optimizer = tf.train.AdamOptimizer().minimize(loss) + + checkpoint_dir = tempfile.mkdtemp() + # Run training pass to retrieve the correct min and max in FakeQuant op (to avoid using default values) and + # save dummy checkpoint. + with tf.Session() as sess: + tf.global_variables_initializer().run() + for iter in range(1): + image = np.random.rand(spatial_size, spatial_size, input_channels).astype(np.float32) * 255 + label = np.random.rand(spatial_size, spatial_size, output_channels).astype(np.float32) * 255 + training_loss, _ = sess.run([loss, optimizer], feed_dict={input_tensor: image[None, ...], + output_tensor: label[None, ...]}) + + saver.save(sess=sess, save_path=os.path.join(checkpoint_dir, 'quantization')) + + with tf.Graph().as_default() as g: + input_tensor = tf.placeholder(tf.float32, [1, spatial_size, spatial_size, input_channels], name='input') + with tf.variable_scope('quantize'): + output = model(x=input_tensor) + + # define eval graph, by quantizing the weights of the model with learned min/max values for each layer + tf.contrib.quantize.experimental_create_eval_graph(input_graph=g, weight_bits=num_bits, + activation_bits=num_bits) + with open('tf_graph.pb', 'wb') as f: + f.write(g.as_graph_def().SerializeToString()) + freeze_g(input_graph="tf_graph.pb", + input_saver="", + input_binary=True, + input_checkpoint=os.path.join(checkpoint_dir, 'quantization'), + output_node_names="quantize/quantized_model/conv2d/Conv2D", + restore_op_name="save/restore_all", + filename_tensor_name="save/Const:0", + output_graph="frozen_graph_quantized.pb", + clear_devices=True, + initializer_nodes="") + shutil.rmtree(checkpoint_dir) + + graph = load_tf_pb("frozen_graph_quantized.pb") + + tf.reset_default_graph() + graphdef = tf.GraphDef() + input_dict = {} + with open("frozen_graph_quantized.pb", "rb") as f: + graphdef.ParseFromString(f.read()) + with tf.Graph().as_default(), tf.Session(config=None) as sess: + tf.graph_util.import_graph_def(graphdef, name='') + input_dict[sess.graph.get_tensor_by_name('input:0')] = (np.random.rand(1, spatial_size, spatial_size, + input_channels).astype(np.float32)) + outputs = [] + outputs.append(sess.graph.get_tensor_by_name('quantize/quantized_model/conv2d/Conv2D:0')) + tf_outs = sess.run(outputs, feed_dict=input_dict) + + run_compare_tf( + graph, + input_dict, + ["quantize/quantized_model/conv2d/Conv2D"], + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + tf_outputs=tf_outs, + rtol=0.005, + ) + + class TestFill: @pytest.mark.parametrize( "use_cpu_only, backend, rank, value", @@ -3661,6 +3702,7 @@ class TestFill: def test_fill(self, use_cpu_only, backend, rank, value): def test_tf_static(): shape = np.random.randint(low=1, high=3, size=rank) + @make_tf_graph([shape]) def build_model(x): return tf.add( @@ -3817,8 +3859,6 @@ class TestPad: ) ) def test(self, use_cpu_only, backend, rank, mode, dynamic, trial): - if backend == "mil_proto" and dynamic: - return input_shape = np.random.randint(low=2, high=10, size=rank) min_input_dim_size = input_shape.min() padding_val = np.random.randint(low=0, high=min_input_dim_size, size=(rank, 2), 
dtype=np.int32) @@ -3870,10 +3910,8 @@ class TestPadV2: ) ) def test(self, use_cpu_only, backend, rank, constant_values, dynamic, trial): - if backend == "mil_proto" and dynamic: - return input_shape = np.random.randint(low=2, high=10, size=rank) - paddings = np.random.randint(low=2, high=5, size=2*rank) + paddings = np.random.randint(low=2, high=5, size=2*rank).astype(np.int32) padding_val = paddings.reshape(-1,2) if dynamic: padding_shape = padding_val.shape @@ -3919,42 +3957,52 @@ class TestRange: ), ) def test_range(self, use_cpu_only, backend, params): - start, end, step = params - with tf.Graph().as_default() as graph: - limit = tf.placeholder(tf.float32) - res = tf.range(start=start, limit=limit, delta=step) - run_compare_tf( - graph, - {limit: end}, - res, - use_cpu_only=use_cpu_only, - frontend_only=False, - backend=backend, - ) + start, end, step = np.array(params).astype(np.float32) - with tf.Graph().as_default() as graph: - delta = tf.placeholder(tf.float32) - res = tf.range(start=start, limit=end, delta=delta) - run_compare_tf( - graph, - {delta: step}, - res, - use_cpu_only=use_cpu_only, - frontend_only=False, - backend=backend, - ) + @make_tf_graph([[tf.float32]]) + def build_model(limit): + return tf.range(start=start, limit=limit, delta=step) - with tf.Graph().as_default() as graph: - begin = tf.placeholder(tf.float32) - res = tf.range(start=begin, limit=end, delta=step) - run_compare_tf( - graph, - {begin: start}, - res, - use_cpu_only=use_cpu_only, - frontend_only=False, - backend=backend, - ) + model, inputs, outputs = build_model + input_values = [end] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + @make_tf_graph([[tf.float32]]) + def build_model(delta): + return tf.range(start=start, limit=end, delta=delta) + + model, inputs, outputs = build_model + input_values = [step] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + @make_tf_graph([[tf.float32]]) + def build_model(begin): + return tf.range(start=begin, limit=end, delta=step) + + model, inputs, outputs = build_model + input_values = [start] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) class TestTile: @@ -3980,18 +4028,21 @@ class TestTile: def test_tile(self, use_cpu_only, backend, rank_and_reps): rank, reps = rank_and_reps x_shape = np.random.randint(low=2, high=4, size=rank) - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=x_shape) - res = tf.tile(x, multiples=reps) - run_compare_tf( - graph, - {x: np.random.rand(*x_shape)}, - res, - use_cpu_only=use_cpu_only, - frontend_only=False, - backend=backend, - ) + @make_tf_graph([x_shape]) + def build_model(x): + return tf.tile(x, multiples=reps) + + model, inputs, outputs = build_model + input_values = [random_gen(x_shape)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") class TestDynamicTile: @@ -4026,32 +4077,37 @@ class TestTopK: def test_top_k(self, use_cpu_only, backend, rank, k): # TensorFlow only supports last dimension (axis = -1). 
shape = np.random.randint(low=3, high=4, size=rank) - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=shape) + + @make_tf_graph([shape]) + def build_model(x): ref = tf.math.top_k(x, k=k, sorted=True) - ref = (ref[1], ref[0]) - run_compare_tf( - graph, - {x: random_gen(shape, rand_min=-100, rand_max=100)}, - ref, - use_cpu_only=use_cpu_only, - backend=backend, - ) + return (ref[1], ref[0]) + + model, inputs, outputs = build_model + input_values = [random_gen(shape, rand_min=-100, rand_max=100)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) class TestConcat: @pytest.mark.parametrize("use_cpu_only, backend, op_version, rank, num_inputs", itertools.product( [True, False], - ['nn_proto'], + backends, ['v1', 'v2'], - [1,2,3,4,5], + list(range(6)), list(range(1, 4)), )) def test_concat(self, use_cpu_only, backend, op_version, rank, num_inputs): import random - for axis in range(-rank,rank): - input_shape = np.random.randint(low=1, high=6, size=rank) + for axis in range(-rank, rank): + input_shape = np.random.randint(low=1, high=4, size=rank) input_shapes = [input_shape.copy() for _ in range(num_inputs)] concat_axis_value = np.random.randint(low=1, high=3, size=num_inputs) for i, v in enumerate(concat_axis_value): @@ -4059,10 +4115,10 @@ def test_concat(self, use_cpu_only, backend, op_version, rank, num_inputs): @make_tf_graph(input_shapes) def build_model(*inputs): - # add 5 zero size constants + # add 3 additional tensor contains dimension size of 0 zero_shape = input_shape.copy() zero_shape[axis] = 0 - const = [tf.constant([], shape=zero_shape) for _ in range(5)] + const = [tf.constant([], shape=zero_shape) for _ in range(3)] values = inputs + tuple(const) values = list(values) random.shuffle(values) @@ -4075,13 +4131,11 @@ def build_model(*inputs): return res model, inputs, outputs = build_model - - input_values = [np.random.rand(*shape).astype(np.float32) for shape in input_shapes] + input_values = [random_gen(shape) for shape in input_shapes] input_dict = dict(zip(inputs, input_values)) - run_compare_tf(model, input_dict, outputs, - use_cpu_only=use_cpu_only, - frontend_only=False, backend=backend) + use_cpu_only=use_cpu_only, + frontend_only=False, backend=backend) class TestSplit: @@ -4179,27 +4233,67 @@ def build_model(x): class TestStack: @pytest.mark.parametrize( - "use_cpu_only, backend", itertools.product([True, False], backends,) + "use_cpu_only, backend", itertools.product([True, False], backends, ) ) def test_stack(self, use_cpu_only, backend): input_shape1 = [3, 1, 1] input_shape2 = [3, 1, 1] - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=input_shape1) - y = tf.placeholder(tf.float32, shape=input_shape2) - res = [tf.stack((x, y), axis=0), tf.stack((x, y), axis=-1)] - inputs = { - x: np.random.rand(*input_shape1), - y: np.random.rand(*input_shape2), - } - run_compare_tf( - graph, - inputs, - res, - use_cpu_only=use_cpu_only, - frontend_only=False, - backend=backend, - ) + + @make_tf_graph([input_shape1, input_shape2]) + def build_model(x, y): + return [tf.stack((x, y), axis=0), tf.stack((y, x), axis=-1)] + + model, inputs, outputs = build_model + input_values = [random_gen(input_shape1), random_gen(input_shape2)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + +class TestUnstack: + @pytest.mark.parametrize( + 
"use_cpu_only, backend, shape", itertools.product([True, False], backends, [[3, 1], [4, 3]],) + ) + def test_unstack(self, use_cpu_only, backend, shape): + @make_tf_graph([shape]) + def build_model(x): + return tf.unstack(x, axis=1) + + model, inputs, outputs = build_model + input_values = [random_gen(shape)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + @pytest.mark.parametrize( + "use_cpu_only, backend, shape", itertools.product([True, False], backends, [[3, 1], [4, 3]]) + ) + + def test_unstack_and_stack(self, use_cpu_only, backend, shape): + @make_tf_graph([shape]) + def build_model(x): + x = tf.unstack(x, axis=1) + return tf.stack(x) + + model, inputs, outputs = build_model + input_values = [random_gen(shape)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) class TestPack: @@ -4243,28 +4337,27 @@ class TestArgSort: ) def test_argsort(self, use_cpu_only, backend, rank, axis, direction): shape = np.random.randint(low=1, high=4, size=rank) - with tf.Graph().as_default() as graph: - x = tf.placeholder(tf.float32, shape=shape) - ref = tf.argsort(x, axis=axis, direction=direction.upper()) - if use_cpu_only: - dtype = np.float32 - else: - dtype = np.float16 - run_compare_tf( - graph, - { - x: random_gen( - shape, - rand_min=-100, - rand_max=100, - allow_duplicate=False, - dtype=dtype, - ) - }, - ref, - use_cpu_only=use_cpu_only, - backend=backend, - ) + if use_cpu_only: + dtype = np.float32 + tf_dtype = tf.float32 + else: + dtype = np.float16 + tf_dtype = tf.float16 + + @make_tf_graph([list(shape) + [tf_dtype]]) + def build_model(x): + return tf.argsort(x, axis=axis, direction=direction.upper()) + + model, inputs, outputs = build_model + input_values = [random_gen(shape, rand_min=-100, rand_max=100, allow_duplicate=False, dtype=dtype)] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, + outputs, + use_cpu_only=use_cpu_only, + backend=backend + ) class TestDepthToSpace: @@ -4991,6 +5084,40 @@ def build_model(x): backend=backend, ) + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends) + ) + def test_tf_dynamic_elem_shape(self, use_cpu_only, backend): + # Suport dynamic elem_shape in mil proto + if backend == "mil_proto": + return + + # TF1: TensorArrayV3, TensorArrayWriteV3, TensorArrayScatterV3, + # TensorArraySizeV3, TensorArrayGatherV3 + # TF2: TensorListReserve, TensorListLength, TensorListSetItem, + # TensorListScatterIntoExistingList, TensorListStack, + # TensorListResize + elem_shape = (None, None) + + @make_tf_graph([elem_shape]) + def build_model(x): + ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True) + ta = ta.write(10, x) + ta = ta.write(9, x) + ta = ta.scatter([3], tf.expand_dims(x, 0)) + ta = ta.scatter([8], tf.expand_dims(x, 0)) + + return ta.stack() + + model, inputs, outputs = build_model + input_values = [random_gen((2,3))] + input_dict = dict(zip(inputs, input_values)) + run_compare_tf( + model, + input_dict, outputs, + use_cpu_only=use_cpu_only, + frontend_only=False, backend=backend) + @pytest.mark.skip( reason="[NNv2 TensorArray scatter returns wrong result](rdar://63345281)" ) diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_parse.py b/coremltools/converters/mil/frontend/tensorflow/test/test_parse.py index 4d21e2bbb..41adf5095 100644 --- 
diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_parse.py b/coremltools/converters/mil/frontend/tensorflow/test/test_parse.py
index 4d21e2bbb..41adf5095 100644
--- a/coremltools/converters/mil/frontend/tensorflow/test/test_parse.py
+++ b/coremltools/converters/mil/frontend/tensorflow/test/test_parse.py
@@ -118,7 +118,7 @@ def compare(expected, tf_type):
     compare(None, types.DataType.DT_QUINT16)
     compare(mil_types.uint16, types.DataType.DT_UINT16)
     compare(None, types.DataType.DT_COMPLEX128)
-    compare(None, types.DataType.DT_HALF)
+    compare(mil_types.fp16, types.DataType.DT_HALF)
     compare(None, types.DataType.DT_RESOURCE)
     compare(None, types.DataType.DT_VARIANT)
     compare(mil_types.uint32, types.DataType.DT_UINT32)
diff --git a/coremltools/converters/mil/frontend/tensorflow/test/testing_utils.py b/coremltools/converters/mil/frontend/tensorflow/test/testing_utils.py
index bfe63acbd..5c216be8c 100644
--- a/coremltools/converters/mil/frontend/tensorflow/test/testing_utils.py
+++ b/coremltools/converters/mil/frontend/tensorflow/test/testing_utils.py
@@ -27,7 +27,7 @@ def make_tf_graph(input_types):
 
     Parameters
     ----------
-    input_types: list of tuple
+    input_types: list of tuple or list of list
        List of input types. E.g. [(3, 224, 224, tf.int32)] represents one
        input with shape (3, 224, 224) and expected data type tf.int32. The
        dtype is optional; if it's missing, tf.float32 will be used.
@@ -177,6 +177,7 @@ def run_compare_tf(
     rtol=1e-05,
     validate_shapes_only=False,
     freeze_graph=False,
+    tf_outputs=None,
 ):
     """
     Utility function to convert and compare a given TensorFlow 1.x model.
@@ -203,6 +204,8 @@
         The relative tolerance parameter.
     validate_shapes_only: bool
         If true, skip element-wise value comparison.
+    tf_outputs: float or list[float]
+        If provided, use these as the TensorFlow reference predictions.
     """
     proto, input_key_values, output_names, output_nodes = tf_graph_to_proto(
         graph, feed_dict, output_nodes, frontend, backend
     )
@@ -247,9 +250,10 @@
             graph, feed_dict, output_nodes, frontend, backend
         )
     else:
-        with tf.Session(graph=graph) as sess:
-            sess.run(tf.global_variables_initializer())
-            tf_outputs = sess.run(output_nodes, feed_dict=feed_dict)
+        if tf_outputs is None:
+            with tf.Session(graph=graph) as sess:
+                sess.run(tf.global_variables_initializer())
+                tf_outputs = sess.run(output_nodes, feed_dict=feed_dict)
 
     expected_outputs = {name: val for name, val in zip(output_names, tf_outputs)}
     for k,v in input_key_values.items():
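[Editor's sketch] The new tf_outputs escape hatch exists for tests such as TestFakeQuant above, which must run their own TF session on a frozen graph. A hedged sketch of the call pattern; the toy graph and names are hypothetical, the run_compare_tf signature is the one introduced in this patch:

# When a test has already produced its own reference predictions, pass them
# in via tf_outputs so run_compare_tf skips its internal tf.Session run.
import numpy as np
import tensorflow as tf

with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, shape=(1, 4), name="input")
    out = tf.nn.relu(x, name="out")

feed = {x: np.random.rand(1, 4).astype(np.float32)}
with tf.Session(graph=graph) as sess:
    precomputed = sess.run([out], feed_dict=feed)  # reference computed here

run_compare_tf(
    graph,
    feed,
    [out],
    use_cpu_only=True,
    backend="nn_proto",
    tf_outputs=precomputed,  # bypasses the internal session run
)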
"LSTMBlockCell"] + inclusions = ["Split", "SplitV", "LSTMBlockCell", "TopK", "TopKV2", "Unpack"] gto_make_op_cache = {} for name in list(gddict.keys()): new_node = ParsedTFNode() diff --git a/coremltools/converters/mil/frontend/tensorflow/tf_graph_pass/quantization_pass.py b/coremltools/converters/mil/frontend/tensorflow/tf_graph_pass/quantization_pass.py new file mode 100644 index 000000000..c26fd8edb --- /dev/null +++ b/coremltools/converters/mil/frontend/tensorflow/tf_graph_pass/quantization_pass.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2020, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from __future__ import print_function as _ +from __future__ import division as _ +from __future__ import absolute_import as _ +from ..basic_graph_ops import delete_node +import logging +import sys + +def delete_fakequant_node_and_repair_graph(g, node): + inputs = node.inputs + # Delete const inputs of the fakequant op + for i in inputs: + if g[i].op == 'Const': + delete_node(g, i) + else: + non_const_input = i + outputs = node.outputs + # Append FakeQuant Op's outputs to its input node's outputs + g[non_const_input].outputs = [i for i in g[non_const_input].outputs if i != node.name] + g[non_const_input].outputs.extend(outputs) + # Modify the FakeQuant op's outputs to set FakeQuant op's parent node as the new input. + for i in outputs: + for j in range(len(g[i].inputs)): + if g[i].inputs[j] == node.name: + g[i].inputs[j] = non_const_input + delete_node(g, node) + +def quantization_pass_impl(fn): + all_quantization_ops = [i for i in fn.graph.values() if "FakeQuant" in i.op] + for node in all_quantization_ops: + is_const_input = True + for input in node.inputs: + if fn.graph[input].op != 'Const': + is_const_input = False + if not is_const_input and ('weights_quant' not in input): + # If activation quantization - + # Delete the FakeQuant op and its const inputs, + # Append FakeQuant Op's outputs to its input node's outputs, + # Modify the FakeQuant op's outputs to reflect the 'new' input node. + delete_fakequant_node_and_repair_graph(fn.graph, node) + else: + # If weight quantization - + # Add attributes of the FakeQuant op to its output's attr dict + for output in node.outputs: + output_node = fn.graph[output] + output_node.attr['quantize'] = True + output_node.attr['num_bits'] = node.attr['num_bits'] + output_node.attr['narrow_range'] = node.attr['narrow_range'] + output_node.attr['quantize_min'] = fn.graph[node.inputs[1]].value.val + output_node.attr['quantize_max'] = fn.graph[node.inputs[2]].value.val + +def quantization_pass(tfssa): + """ + Delete activation quantization ops and repair TF graph: + If the FakeQuant op is not connected to constant inputs (which means that the op performs activation + quantization) then delete that FakeQuant op and repair the graph. + Edit weight quantization ops: + If the FakeQuant op is connected to constant inputs then add its attributes to its output op so that parameters + min, max, narrow_range, num_bits are available (in addition to weights) to downstream ops for denoting and + supporting weight quantization. 
+ """ + for v in tfssa.functions.values(): + quantization_pass_impl(v) + print("pass completed") diff --git a/coremltools/converters/mil/frontend/tensorflow2/load.py b/coremltools/converters/mil/frontend/tensorflow2/load.py index 36fd5d31d..95e37fd54 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/load.py +++ b/coremltools/converters/mil/frontend/tensorflow2/load.py @@ -12,26 +12,10 @@ import logging as _logging import os.path as _os_path -from six import string_types as _string_types -from tqdm import tqdm as _tqdm import tensorflow as _tf - -from tensorflow.python.framework import dtypes as _dtypes -from tensorflow.python.framework.convert_to_constants import ( - convert_variables_to_constants_v2 as _convert_variables_to_constants_v2, -) -from tensorflow.python.framework.function_def_to_graph import ( - function_def_to_graph as _function_def_to_graph, -) -from tensorflow.python.keras.saving import saving_utils as _saving_utils - -from tensorflow.lite.python.util import ( - run_graph_optimizations as _run_graph_optimizations, -) -from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config - -from .converter import TF2Converter from coremltools.converters.mil.frontend.tensorflow.basic_graph_ops import fill_outputs +from coremltools.converters.mil.frontend.tensorflow.load import TFLoader +from coremltools.converters.mil.frontend.tensorflow.parsed_tf_node import ParsedTFNode from coremltools.converters.mil.frontend.tensorflow.tf_graph_pass import ( constant_propagation, remove_variable_nodes, @@ -40,16 +24,30 @@ delete_disconnected_nodes, fuse_dilation_conv, ) +from coremltools.converters.mil.frontend.tensorflow.tfssa import ( + NetworkEnsemble, + SSAFunction, +) from coremltools.converters.mil.frontend.tensorflow2.tf_graph_pass import ( flatten_sub_graph_namespaces, rewrite_control_flow_functions, ) -from coremltools.converters.mil.frontend.tensorflow.tfssa import ( - NetworkEnsemble, - SSAFunction, +from six import string_types as _string_types +from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config +from tensorflow.lite.python.util import ( + run_graph_optimizations as _run_graph_optimizations, ) -from coremltools.converters.mil.frontend.tensorflow.parsed_tf_node import ParsedTFNode -from coremltools.converters.mil.frontend.tensorflow.load import TFLoader +from tensorflow.python.framework import dtypes as _dtypes +from tensorflow.python.framework.convert_to_constants import ( + convert_variables_to_constants_v2 as _convert_variables_to_constants_v2, +) +from tensorflow.python.framework.function_def_to_graph import ( + function_def_to_graph as _function_def_to_graph, +) +from tensorflow.python.keras.saving import saving_utils as _saving_utils +from tqdm import tqdm as _tqdm + +from .converter import TF2Converter class TF2Loader(TFLoader): @@ -261,23 +259,22 @@ def _dict_from_graph_def(graph, fn_name="main", sg_input_shapes=None): for name, sg in graph._functions.items(): sg_def = sg.definition - input_shapes = sg_input_shapes[name] - input_shapes = input_shapes[-len(sg_def.signature.input_arg) :] - fn_graph = _function_def_to_graph(sg_def, input_shapes=input_shapes) - - graph_dict.update( - TF2Loader._dict_from_graph_def(fn_graph, name, sg_input_shapes)[0] - ) - graph_inputs.update({name: [t.name.split(":")[0] for t in fn_graph.inputs]}) - graph_outputs.update( - {name: [t.name.split(":")[0] for t in fn_graph.outputs]} - ) - - # ret is a mapping from the output arg names from `signature` to the - # outputs from `node_def` 
that should be returned by the function. - sg_def_ret = sg_def.ret - sg_def_ret["identity_0"] = sg_def_ret.pop("identity") - graph_ret.update({name: sg_def_ret}) + if name in sg_input_shapes: + input_shapes = sg_input_shapes[name] + input_shapes = input_shapes[-len(sg_def.signature.input_arg):] + fn_graph = _function_def_to_graph(sg_def, input_shapes=input_shapes) + + graph_dict.update( + TF2Loader._dict_from_graph_def(fn_graph, name, sg_input_shapes)[0] + ) + graph_inputs.update({name: [t.name.split(":")[0] for t in fn_graph.inputs]}) + graph_outputs.update( + {name: [t.name.split(":")[0] for t in fn_graph.outputs]} + ) + + # ret is a mapping from the output arg names from `signature` to the + # outputs from `node_def` that should be returned by the function. + graph_ret.update({name: sg_def.ret}) return graph_dict, graph_inputs, graph_outputs, graph_ret diff --git a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_composite_ops.py b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_composite_ops.py new file mode 100644 index 000000000..c9535a747 --- /dev/null +++ b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_composite_ops.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from coremltools.converters.mil import testing_reqs +from coremltools.converters.mil.frontend.tensorflow.test import ( + testing_utils as tf_testing_utils, +) +from coremltools.converters.mil.frontend.tensorflow2.test.testing_utils import ( + make_tf2_graph as make_tf_graph, + run_compare_tf2 as run_compare_tf, +) +from coremltools.converters.mil.testing_reqs import * + +tf = pytest.importorskip("tensorflow", minversion="2.1.0") + +backends = testing_reqs.backends + +# ----------------------------------------------------------------------------- +# Overwrite utilities to enable different conversion / compare method +tf_testing_utils.frontend = "TensorFlow2" +tf_testing_utils.make_tf_graph = make_tf_graph +tf_testing_utils.run_compare_tf = run_compare_tf + +# ----------------------------------------------------------------------------- +# Import TF 2.x-compatible TF 1.x test cases +from coremltools.converters.mil.frontend.tensorflow.test.test_composite_ops import ( + TestCompositeOp, +) diff --git a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py index cf44f25a2..f533ab224 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py @@ -25,9 +25,6 @@ # ----------------------------------------------------------------------------- # Import TF 2.x-compatible TF 1.x test cases -from coremltools.converters.mil.frontend.tensorflow.test.test_custom_ops import ( - TestCompositeOp, -) from coremltools.converters.mil.frontend.tensorflow.test.test_ops import ( TestActivationElu, TestActivationLeakyReLU, @@ -39,52 +36,67 @@ TestActivationSoftPlus, TestActivationSoftSign, TestAddN, - TestBroadcastTo, + TestArgSort, TestBatchNormalization, TestBatchToSpaceND, + TestBroadcastTo, TestCast, - TestCond, + TestClipByValue, TestConcat, # Redirects to ConcatV2 in TF2 + TestCond, TestConv, - TestConv3d, TestConvTranspose, + TestConv3d, TestCumSum, TestDebugging, - TestDepthwiseConv, TestDepthToSpace, + TestDepthwiseConv, 
TestElementWiseBinary, + TestElementWiseUnary, TestExpandDims, TestFill, TestGather, - TestIsFinite, TestIdentity, + TestImageResizing, + TestIsFinite, + TestL2Normalization, TestLinear, + TestLocalResponseNormalization, + TestLogSoftMax, TestMatrixBandPart, + TestMatrixDiag, TestNonMaximumSuppression, TestNormalization, TestOneHot, TestPad, + TestPadV2, TestPack, TestPool1d, TestPool2d, TestPool3d, - TestSelect, - TestSeparableConv, - TestShape, - TestSpaceToBatchND, - TestSqueeze, - TestTensorArray, - TestWhileLoop, + TestRange, + TestReduction, TestReshape, TestReverse, TestReverseSequence, TestScatter, TestSelect, + TestSeparableConv, + TestShape, + TestSize, TestSliceByIndex, TestSliceBySize, + TestSpaceToBatchND, TestSpaceToDepth, TestSplit, + TestSqueeze, + TestStack, + TestTensorArray, + TestTile, + TestTopK, TestTranspose, + TestUnstack, + TestWhileLoop, TestZerosLike, ) diff --git a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py index 487d7be8d..cc14ed144 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py +++ b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py @@ -429,6 +429,49 @@ def test_depth_wise_conv( backend=backend, ) + @pytest.mark.parametrize( + ",".join( + [ + "use_cpu_only", + "backend", + "padding", + ] + ), + itertools.product( + [True, False], + backends, + ["same", "valid"], + ), + ) + @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") + def test_conv2d_padding_dynamic_input( + self, + use_cpu_only, + backend, + padding, + ): + from tensorflow.keras import Input + from tensorflow.keras.models import Model + from tensorflow.keras.layers import Conv2D, GlobalMaxPooling2D + + # Test same padding + input_layer = Input(batch_size=1, shape=(None, None, 1)) + layer = Conv2D( + filters=16, + kernel_size=(3, 3), + padding=padding, + activation="relu" + )(input_layer) + output_layer = GlobalMaxPooling2D()(layer) + model = Model(inputs=[input_layer], outputs=[output_layer]) + run_compare_tf_keras( + model, + [random_gen((1, 80, 40 ,1), rand_min=-10, rand_max=10)], + use_cpu_only=use_cpu_only, + backend=backend, + ) + + @pytest.mark.parametrize( ",".join( [ @@ -501,8 +544,86 @@ def test_separable_conv( backend=backend, ) - @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") +class TestConvTranspose: + @pytest.mark.parametrize( + ",".join( + [ + "use_cpu_only", + "backend", + "op", + "padding", + "data_format", + "spatial_dim_and_ks", + "output_padding", + "strides", + "dilations", + "batch_size", + ] + ), + itertools.product( + [True, False], + backends, + [tf.keras.layers.Conv2DTranspose, tf.keras.layers.Conv3DTranspose], + ["same", "valid"], + ["channels_last"], + [(7, 11, 12, 1, 2, 2), (9, 5, 7, 3, 3, 3)], + [(1, 1, 1)], + [(2, 2, 2), (2, 3, 3)], + [(1, 1, 1)], # Dilation > 1 not supported by TF + [1, 3], + ), + ) + @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") + def test_conv_transpose( + self, + use_cpu_only, + backend, + op, + padding, + data_format, + spatial_dim_and_ks, + output_padding, + strides, + dilations, + batch_size, + ): + s1, s2, s3, k1, k2, k3 = spatial_dim_and_ks + c_in, c_out = 2, 3 + input_shape = None + kernel_size = None + if op == tf.keras.layers.Conv2DTranspose: + input_shape = (batch_size, s2, s3, c_in) + kernel_size = (k2, k3) + 
strides = (strides[1], strides[2]) + dilations = dilations[1:] + output_padding = (output_padding[1], output_padding[2]) + elif op == tf.keras.layers.Conv3DTranspose: + input_shape = (batch_size, s1, s2, s3, c_in) + kernel_size = (k1, k2, k3) + + model = tf.keras.Sequential( + [ + op( + batch_input_shape=input_shape, + filters=c_out, + kernel_size=kernel_size, + strides=strides, + padding=padding.upper(), + output_padding=output_padding, + data_format=data_format, + dilation_rate=dilations, + ) + ] + ) + + run_compare_tf_keras( + model, + [random_gen(input_shape, rand_min=-10, rand_max=10)], + use_cpu_only=use_cpu_only, + backend=backend, + ) + class TestConvTranspose: @pytest.mark.parametrize( ",".join( @@ -870,8 +991,8 @@ def test_instance_normalization( [random_gen(shape, rand_min=-1, rand_max=1)], use_cpu_only=use_cpu_only, backend=backend, - atol=1e-3, - rtol=1e-4, + atol=1e-2, + rtol=1e-3, ) diff --git a/coremltools/converters/mil/frontend/tensorflow2/test/testing_utils.py b/coremltools/converters/mil/frontend/tensorflow2/test/testing_utils.py index a9eb8da37..6f0cdf300 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/test/testing_utils.py +++ b/coremltools/converters/mil/frontend/tensorflow2/test/testing_utils.py @@ -23,7 +23,7 @@ def make_tf2_graph(input_types): Parameters ---------- - input_types: list of tuple + input_types: list of tuple or list of list List of input types. E.g. [(3, 224, 224, tf.int32)] represent 1 input, with shape (3, 224, 224), and the expected data type is tf.int32. The dtype is optional, in case it's missing, tf.float32 will be used. @@ -34,21 +34,19 @@ def make_tf2_graph(input_types): """ def wrapper(ops): - class TensorFlowModule(tf.Module): - input_signature = [] - for input_type in input_types: - if len(input_type) > 0 and isinstance(input_type[-1], dtypes.DType): - shape, dtype = input_type[:-1], input_type[-1] - else: - shape, dtype = input_type, tf.float32 - input_signature.append(tf.TensorSpec(shape=shape, dtype=dtype)) - - @tf.function(input_signature=input_signature) - def __call__(self, *args): - return ops(*args) - - module = TensorFlowModule() - concrete_func = module.__call__.get_concrete_function() + input_signature = [] + for input_type in input_types: + if len(input_type) > 0 and isinstance(input_type[-1], dtypes.DType): + shape, dtype = input_type[:-1], input_type[-1] + else: + shape, dtype = input_type, tf.float32 + input_signature.append(tf.TensorSpec(shape=shape, dtype=dtype)) + + @tf.function(input_signature=input_signature) + def tf2_model(*args): + return ops(*args) + + concrete_func = tf2_model.get_concrete_function() inputs = get_tf_node_names( [t.name for t in concrete_func.inputs if t.dtype != dtypes.resource], mode="input", @@ -188,7 +186,7 @@ def run_compare_tf_keras( if frontend_only: return - # get tf.keras model output as reference and run comparision + # get tf.keras model output as reference and run comparison ref = [model(input_values).numpy()] expected_outputs = {n: v for n, v in zip(outputs, ref)} input_key_values = {n: v for n, v in zip(inputs, input_values)} diff --git a/coremltools/converters/mil/frontend/tensorflow2/tf_graph_pass/rewrite_control_flow_functions.py b/coremltools/converters/mil/frontend/tensorflow2/tf_graph_pass/rewrite_control_flow_functions.py index 304d64dec..7eb17ce73 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/tf_graph_pass/rewrite_control_flow_functions.py +++ b/coremltools/converters/mil/frontend/tensorflow2/tf_graph_pass/rewrite_control_flow_functions.py @@ -5,12 
+5,12 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -from __future__ import print_function as _ -from __future__ import division as _ from __future__ import absolute_import as _ +from __future__ import division as _ +from __future__ import print_function as _ import logging -from coremltools.converters.mil.frontend.tensorflow.parsed_tf_node import ParsedTFNode + from coremltools.converters.mil.frontend.tensorflow.basic_graph_ops import ( disconnect_edge, connect_edge, @@ -19,6 +19,7 @@ replace_dest, connect_edge_at_index, ) +from coremltools.converters.mil.frontend.tensorflow.parsed_tf_node import ParsedTFNode def _rename_node_in_fn(node, new_name, fn): @@ -252,8 +253,17 @@ def _rewrite_cond_functions(tf_ssa, fn): c_input = [n for n in o_original.input if str(n).startswith(cond_name)][ 0 ] - c_index = c_input.split(":")[-1] if ":" in c_input else 0 - mapped_name = then_fn.ret["identity_{}".format(c_index)].split(":")[0] + if ":" in c_input: + identity_postfix = "identity_{}".format(c_input.split(":")[-1]) + else: # access identity "0" + identity_postfix = "identity" + + identity_keys = [t for t in then_fn.ret.keys() if t.endswith(identity_postfix)] + if len(identity_keys) != 1: + raise NotImplementedError("Branch not found.") + + mapped_name = then_fn.ret[identity_keys[0]].split(":")[0] + if mapped_name in then_fn.outputs: idx = then_fn.outputs.index(mapped_name) else: # in else_fn.outputs @@ -491,7 +501,16 @@ def _rewrite_while_loop_functions(tf_ssa, fn): n for n in o_original.input if str(n).startswith(while_name) ][0] while_index = while_input.split(":")[-1] - mapped_name = body_fn.ret["identity_{}".format(while_index)].split(":")[0] + if while_index != 0: + identity_postfix = "identity_{}".format(while_index) + else: # access identity "0" + identity_postfix = "identity" + + identity_keys = [t for t in body_fn.ret.keys() if t.endswith(identity_postfix)] + if len(identity_keys) != 1: + raise NotImplementedError("Branch not found.") + + mapped_name = body_fn.ret[identity_keys[0]].split(":")[0] idx = body_fn.outputs.index(mapped_name) loop_output = _insert_get_tuple( diff --git a/coremltools/converters/mil/frontend/torch/converter.py b/coremltools/converters/mil/frontend/torch/converter.py index 2253d7945..8486df20f 100644 --- a/coremltools/converters/mil/frontend/torch/converter.py +++ b/coremltools/converters/mil/frontend/torch/converter.py @@ -275,16 +275,19 @@ def _expand_and_optimize_ir(torchscript): # Replaces a couple specific ops patterns (add, sub, mul, div, chunk). if version_lt(_torch, '1.6.0'): _torch._C._jit_pass_canonicalize_ops(graph) + _torch._C._jit_pass_lint(graph) + + # From PyTorch code: This pass catches all of the small, easy to catch + # peephole optimizations you might be interested in doing. + # Eliminate no-op 'expand' nodes + # Simplify x.t().t() to x + # pass disabled for v1.6.0 and onwards, wrongly captures the shape of dummy inputs during tracing. + _torch._C._jit_pass_peephole(graph, addmm_fusion_enabled=False) else: # v1.6.0 pass renamed _torch._C._jit_pass_canonicalize_graph_fuser_ops(graph) _torch._C._jit_pass_lint(graph) - # From PyTorch code: This pass catches all of the small, easy to catch - # peephole optimizations you might be interested in doing. 
- # Eliminate no-op 'expand' nodes - # Simplify x.t().t() to x - _torch._C._jit_pass_peephole(graph, addmm_fusion_enabled=False) - _torch._C._jit_pass_lint(graph) + # From PyTorch docs: Renumber the graph so that all structurally # equivalent graphs have same numbers. graph = _torch._C._jit_pass_canonicalize(graph) @@ -300,7 +303,7 @@ def _expand_and_optimize_ir(torchscript): _torch._C._jit_pass_lint(graph) input_and_param_names = [val.debugName() for val in graph.inputs()] - param_names = input_and_param_names[len(input_and_param_names) - len(params) :] + param_names = input_and_param_names[len(input_and_param_names) - len(params):] params_dict = dict(zip(param_names, params)) return graph, params_dict diff --git a/coremltools/converters/mil/frontend/torch/load.py b/coremltools/converters/mil/frontend/torch/load.py index 8bd41c372..383e287d5 100644 --- a/coremltools/converters/mil/frontend/torch/load.py +++ b/coremltools/converters/mil/frontend/torch/load.py @@ -89,14 +89,14 @@ def _convert_to_inputtype(inputs): def _torchscript_from_model(model_spec): - if isinstance(model_spec, _string_types) and model_spec.endswith(".pt"): + if isinstance(model_spec, _string_types) and (model_spec.endswith(".pt") or model_spec.endswith(".pth")): filename = _os_path.abspath(model_spec) return _torch.jit.load(filename) elif isinstance(model_spec, _torch.jit.ScriptModule): return model_spec else: raise TypeError( - "@model must either be a PyTorch .pt file or a TorchScript object, received: {}".format( + "@model must either be a PyTorch .pt or .pth file or a TorchScript object, received: {}".format( type(model_spec) ) ) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 438b1e28a..bbbe4d00f 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -103,6 +103,26 @@ def convert_block(context, block, inputs): 11: torch.bool, } +NUMPY_DTYPE_TO_TORCH_NUM = { + _np.uint8: 0, + _np.int8: 1, + _np.int16: 2, + _np.int32: 3, + _np.int64: 4, + _np.float16: 5, + _np.float32: 6, + _np.float64: 7, + _np.bool: 11, +} + +NUM_TO_DTYPE_STRING = { + 3: "int32", + 4: "int64", + 6: "fp32", + 7: "fp64", + 11: "bool", +} + def decide_immediate_or_file(val): if ( @@ -115,17 +135,22 @@ def decide_immediate_or_file(val): def _get_inputs(context, node, expected=None): - """Look up a node's inputs in @context and return them as a list. If - @expected is not None, also verifies the number of inputs matches the - value of @expcted. + """ + Look up a node's inputs in @context and return them as a list. If + @expected is not None, also verifies the number of inputs matches the + value of @expected. 
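+
+    For example (illustrative): expected=[2, 3] accepts a node carrying either
+    two or three inputs and raises ValueError otherwise; handlers such as
+    ``sub`` below pass a list to cope with optional trailing arguments that
+    vary across PyTorch versions.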
""" inputs = [context[name] for name in node.inputs] - if expected is not None and len(inputs) != expected: - raise ValueError( - "node {} ({}) got {} input(s), expected {}".format( - node.name, node.kind, len(inputs), expected + if expected is not None: + expected = [expected] if not isinstance(expected, (list, tuple)) else expected + + if len(inputs) not in expected: + raise ValueError( + "node {} ({}) got {} input(s), expected {}".format( + node.name, node.kind, len(inputs), expected + ) ) - ) + return inputs @@ -407,8 +432,8 @@ def _convolution(context, node): if transposed is True: # Transposed convolution - # PyTorch weight ordering [Cin, Cout, H, W] - # MIL expects [Cout, Cin, H, W] + # PyTorch weight ordering [Cin, Cout, *D] + # MIL expects [Cout, Cin, *D] perm = _np.arange(len(weight.shape)) perm[[0, 1]] = perm[[1, 0]] weight_transpose = mb.transpose( @@ -433,7 +458,7 @@ def _convolution(context, node): # TODO: rdar://65588783 ([PyTorch] Define and error out on unsupported configuration for output_padding) # error out here with unsupported configuration along with output padding if sum(pad) == 0 and any(output_padding): - raise ValueError("ConvTransponse configuration of padding=0 and output_padding > 0 not supported!") + raise ValueError("ConvTranspose configuration of padding=0 and output_padding > 0 not supported!") post_crop = pad.copy() pad *= 0 for i in range(0, len(pad)): @@ -443,7 +468,7 @@ def _convolution(context, node): pre_pad[i] = output_padding[i] - post_crop[i] kwargs["pad"] = pre_pad if any(pre_pad): - # Constant pad requires pad to be of lenght 2*input_rank + # Constant pad requires pad to be of length 2*input_rank pre_pad = [0] * 2 * (len(x.shape) - 2) + pre_pad x = mb.pad(x=x, pad=pre_pad) kwargs["x"] = x @@ -645,7 +670,7 @@ def div(context, node): context.add(res) -@register_torch_op +@register_torch_op(torch_alias=["floordiv"]) def floor_divide(context, node): inputs = _get_inputs(context, node, expected=2) div_res = mb.floor_div(x=inputs[0], y=inputs[1]) @@ -678,7 +703,7 @@ def pow_(context, node): @register_torch_op(torch_alias=["rsub"]) def sub(context, node): - inputs = _get_inputs(context, node, expected=3) + inputs = _get_inputs(context, node, expected=[2, 3]) assert len(node.outputs) == 1 if node.kind == "rsub": @@ -688,19 +713,21 @@ def sub(context, node): else: x = inputs[0] y = inputs[1] - alpha = inputs[2].val - # TODO (sberardi): 3rd param to aten::sub is a scale factor, need to handle that. - # out=input-alpha x other - # rdar://60175736 - if alpha != 1: - raise ValueError("SUB does not support scale factor param") + if len(inputs) > 2: + alpha = inputs[2].val + + # TODO (sberardi): 3rd param to aten::sub is a scale factor, need to handle that. + # out=input-alpha x other + # rdar://60175736 + if alpha != 1: + raise ValueError("SUB does not support scale factor param") res = mb.sub(x=x, y=y, name=node.name) context.add(res) -@register_torch_op +@register_torch_op(torch_alias=["sum"]) def mean(context, node): inputs = _get_inputs(context, node) @@ -725,10 +752,10 @@ def mean(context, node): # Last input to mean is an optional output tensor. We always expect this to # be None or absent. 
assert len(inputs) <= 3 or inputs[3] is None - res = mb.reduce_mean(**kwargs) + func = mb.reduce_sum if node.kind == "sum" else mb.reduce_mean + res = func(**kwargs) context.add(res) - @register_torch_op def squeeze(context, node): inputs = _get_inputs(context, node) @@ -749,13 +776,14 @@ def unsqueeze(context, node): @register_torch_op def size(context, node): - inputs = _get_inputs(context, node, expected=2) + inputs = _get_inputs(context, node, expected=[1, 2]) # Get the shape of the tensor. - shape_node = mb.shape(x=inputs[0], name=node.name + "_shape") + size_node = mb.shape(x=inputs[0], name=node.name + "_shape") # Get the size of the tensor along the input dimension. - dim = inputs[1].val - size_node = _list_select(shape_node, dim) + if len(node.inputs) == 2: + dim = inputs[1].val + size_node = _list_select(size_node, dim) context.add(size_node, node.name) @@ -769,6 +797,10 @@ def view(context, node): length = mb.list_length(ls=shape) indices = mb.range_1d(start=0, end=length, step=1) shape = mb.list_gather(ls=shape, indices=indices) + + if isinstance(shape, list) and all([isinstance(dim, Var) and len(dim.shape) == 0 for dim in shape]) and any([dim.val is None for dim in shape]): + shape = mb.concat(values=shape, axis=0) + view = mb.reshape(x=x, shape=shape, name=node.name) context.add(view) @@ -958,20 +990,8 @@ def hardtanh(context, node): @register_torch_op def cat(context, node): inputs = _get_inputs(context, node) - - values = inputs[0] - if len(values) == 1: - # Only one item to "concatenate", so treat it as a no-OP. Otherwise, - # NN concatND layer will complain it only has one input. - context.add(values[0], node.name) - return - - if len(inputs) < 2: - axis = 0 - else: - axis = inputs[1] - - concat = mb.concat(values=values, axis=axis, name=node.name) + axis = 0 if len(inputs) == 1 else inputs[1] + concat = mb.concat(values=inputs[0], axis=axis, name=node.name) context.add(concat) @@ -1022,6 +1042,9 @@ def _cast(context, node, dtype, dtype_name): res = mb.const(val=dtype(x.val), name=node.name) else: res = x + elif x.shape == (1,): + x = mb.squeeze(x=x, name=node.name + "_item") + res = mb.cast(x=x, dtype=dtype_name, name=node.name) else: if len(x.shape) > 0: # TODO: There's no MIL op to extract a value from a symbolic tensor, @@ -1033,14 +1056,7 @@ def _cast(context, node, dtype, dtype_name): @register_torch_op(torch_alias=["bool"]) def _bool(context, node): - inputs = _get_inputs(context, node, expected=1) - - x = inputs[0] - # TODO: this is a hack and should be replaced once MIL supports cast to - # bool. - if x.val is not None and not isinstance(x.val, bool): - x = mb.const(val=bool(x.val), name=node.name) - context.add(x, node.name) + _cast(context, node, bool, "bool") @register_torch_op(torch_alias=["int"]) @@ -1048,26 +1064,6 @@ def _int(context, node): _cast(context, node, int, "int32") -def _get_axes_from_normalized_shape(original_shape, normalized_shape): - """Convert the `normalized_shape` argument of torch.nn.LayerNorm to the - backend argument `axes` in order to reuse the same backend signature - """ - if not isinstance(normalized_shape, list): - normalized_shape = list(normalized_shape) - - nb_reduced_axes = len(normalized_shape) - nb_total_axes = len(original_shape) - shape_to_reduce = original_shape[-nb_reduced_axes:] - - if not list(shape_to_reduce) == normalized_shape: - raise ValueError( - "normalized_shape ({}) is incompatible with input tensor shape ({}) for layer_norm op. 
" - "normalized_shape must match the last len(normalized_shape) entries in the input tensor shape".format( - normalized_shape, original_shape) - ) - return list(range(nb_total_axes-nb_reduced_axes,nb_total_axes)) - - @register_torch_op def layer_norm(context, node): inputs = _get_inputs(context, node, expected=6) @@ -1077,11 +1073,10 @@ def layer_norm(context, node): bias = inputs[3] eps = inputs[4] # cudnn_enable = inputs[5] unused - axes = _get_axes_from_normalized_shape(_input.shape, normalized_shape.val) layer_norm = mb.layer_norm( x=_input, - axes=axes, + axes=list(range(-len(normalized_shape.val),0)), gamma=weight, beta=bias, epsilon=eps, @@ -1638,7 +1633,7 @@ def select(context, node): @register_torch_op def ones(context, node): - inputs = _get_inputs(context, node, expected=6) + inputs = _get_inputs(context, node, expected=[5, 6]) size = inputs[0] # dtype = NUM_TO_TORCH_DTYPE[inputs[1].val] unused # layout = inputs[2] unused @@ -1756,8 +1751,21 @@ def _slice(context, node): inputs = _get_inputs(context, node, expected=5) x = inputs[0] dim = inputs[1].val - start = inputs[2].val if inputs[2].val is not None else 0 - end = inputs[3].val if inputs[3] is not None else None + + if inputs[2] and inputs[2].val is not None: + start = inputs[2].val + elif isinstance(inputs[2], Var): + start = inputs[2] + else: + start = 0 + + if inputs[3] and inputs[3].val is not None: + end = inputs[3].val + elif isinstance(inputs[3], Var): + end = inputs[3] + else: + end = None + step = inputs[4].val if start == 0 and end is None and step == 1: @@ -1773,6 +1781,12 @@ def _slice(context, node): end_array[dim] = end end_mask[dim] = False + if isinstance(start, Var): + begin_array = mb.concat(values=begin_array, axis=0) + + if isinstance(end, Var): + end_array = mb.concat(values=end_array, axis=0) + kwargs = { "x": x, "begin": begin_array, @@ -1825,7 +1839,14 @@ def split(context, node): def to(context, node): # @non_blocking and @copy are unused inputs = _get_inputs(context, node) - if len(inputs) == 6: + + if len(inputs) == 8: + _input = inputs[0] + dtype = inputs[1].val + elif len(inputs) == 7: + _input = inputs[0] + dtype = inputs[1].val + elif len(inputs) == 6: _input = inputs[0] device = inputs[1] dtype = inputs[2].val @@ -1834,10 +1855,10 @@ def to(context, node): # memory_format = inputs[5] # usually None elif len(inputs) == 5: _input = inputs[0] - device = inputs[1] - dtype = inputs[2].val - # non_blocking = inputs[3] - # copy = inputs[4] + dtype = NUMPY_DTYPE_TO_TORCH_NUM[inputs[1].val.dtype.type] if isinstance(inputs[1].val, _np.ndarray) else inputs[1].val + # non_blocking = inputs[2] + # copy = inputs[3] + # memory_format = inputs[4] elif len(inputs) == 4: _input = inputs[0] dtype = inputs[1].val @@ -1856,14 +1877,15 @@ def to(context, node): ) torch_dtype = NUM_TO_TORCH_DTYPE[dtype] - if isinstance(_input, Var): + if isinstance(_input, Var) and _input.val is not None: _input = _input.val - - # numpy -> torch -> torch cast -> numpy - # This path is needed to use the mapping of passed in dtypes to torch dtypes. - casted_input = torch.tensor(_input).type(torch_dtype).numpy() - const = mb.const(mode="immediate_value", val=casted_input, name=node.name) - context.add(const) + # numpy -> torch -> torch cast -> numpy + # This path is needed to use the mapping of passed in dtypes to torch dtypes. 
+ casted_input = torch.tensor(_input).type(torch_dtype).numpy() + res = mb.const(mode="immediate_value", val=casted_input, name=node.name) + else: + res = mb.cast(x=_input, dtype=NUM_TO_DTYPE_STRING[dtype], name=node.name) + context.add(res) @register_torch_op @@ -1874,14 +1896,18 @@ def erf(context, node): context.add(erf) -@register_torch_op +@register_torch_op(torch_alias=["scalarimplicit"]) def implicittensortonum(context, node): inputs = _get_inputs(context, node, expected=1) _input = inputs[0] - assert _input.shape == (1,) - # shape: (1,) -> () - squeeze = mb.squeeze(x=_input, name=node.name) - context.add(squeeze) + + if _input.shape == (): #already a scalar + context.add(_input, node.name) + else: + assert _input.shape == (1,) + # shape: (1,) -> () + squeeze = mb.squeeze(x=_input, name=node.name) + context.add(squeeze) @register_torch_op @@ -1912,7 +1938,9 @@ def _expand(context, name, tensor, shape): @register_torch_op def expand(context, node): - inputs = _get_inputs(context, node, expected=2) + # PyTorch 1.6+ has 3 inputs while older version has 2 + inputs = _get_inputs(context, node, expected=[2, 3]) + x = inputs[0] shape = inputs[1].val @@ -1921,7 +1949,8 @@ def expand(context, node): @register_torch_op def expand_as(context, node): - inputs = _get_inputs(context, node, expected=2) + # PyTorch 1.6+ has 3 inputs while older version has 2 + inputs = _get_inputs(context, node, expected=[2, 3]) x = inputs[0] other = inputs[1] @@ -2148,14 +2177,12 @@ def _abs(context, node): inputs = _get_inputs(context, node, expected=1) context.add(mb.abs(x=inputs[0], name=node.name)) - @register_torch_op def repeat(context, node): x = context[node.inputs[0]] reps = context[node.inputs[1]] context.add(mb.tile(x=x, reps=reps, name=node.name)) - @register_torch_op def acos(context, node): inputs = _get_inputs(context, node, expected=1) @@ -2259,7 +2286,8 @@ def sqrt(context, node): @register_torch_op def square(context, node): inputs = _get_inputs(context, node, expected=1) - context.add(mb.square(x=inputs[0], name=node.name)) + # mb.square is not supported in some backend + context.add(mb.mul(x=inputs[0], y=inputs[0], name=node.name)) @register_torch_op def tan(context, node): @@ -2310,40 +2338,16 @@ def is_floating_point(context, node): is_float = types.is_float(inputs[0].dtype) context.add(mb.const(val=is_float, name=node.name)) -@register_torch_op(torch_alias=['sum']) -def _sum(context, node): - inputs = _get_inputs(context, node) - kwargs = {"x": inputs[0], "name": node.name} - - # function declarations to handle: torch.sum(input, dtype=None) and torch.sum(input, dim, keepdim=False, dtype=None) - # the 2nd arguments dtype and dim allow int values causing ambiguity. To ensure reasonable outputs - # dtype is restricted to None and dim must be a tuple in the pytorch definition - if len(inputs) >= 2: - if inputs[1] is not None: - if isinstance(inputs[1].val, _np.ndarray): - kwargs["axes"] = inputs[1] - - # optional: @keep_dims - if len(inputs) >= 3: - keep_dims = inputs[2] - kwargs["keep_dims"] = keep_dims - - if len(inputs) >= 4: - if inputs[3] is not None: - raise Exception("dtype input to sum should be None but the input is {}".format(inputs[3].val)) - else: - raise Exception("Unsupported input argument to sum. 
Allowed second input arguments are dtype=None or " - "dim=") - - res = mb.reduce_sum(**kwargs) - context.add(res) +@register_torch_op +def where(context, node): + inputs = _get_inputs(context, node, expected=3) + context.add(mb.select(cond=inputs[0], a=inputs[1], b=inputs[2], name=node.name)) @register_torch_op def neg(context, node): inputs = _get_inputs(context, node, expected=1) context.add(mb.mul(x=inputs[0], y=-1, name=node.name)) -@register_torch_op def topk(context, node): inputs = _get_inputs(context, node) kwargs = {"name": node.name, "x": inputs[0], "k": inputs[1]} @@ -2376,4 +2380,3 @@ def topk(context, node): indices_name = node.outputs[1] context.add(res[0], torch_name=values_name) context.add(res[1], torch_name=indices_name) - diff --git a/coremltools/converters/mil/frontend/torch/test/test_api.py b/coremltools/converters/mil/frontend/torch/test/test_api.py index 8b49fdcc2..31b629b45 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_api.py +++ b/coremltools/converters/mil/frontend/torch/test/test_api.py @@ -1,5 +1,6 @@ import pytest import coremltools as ct +import os from coremltools._deps import ( _HAS_TORCH, @@ -23,3 +24,35 @@ def test_no_inputs(): with pytest.raises(ValueError) as e: mlmodel = ct.convert(traced_model) e.match(r'Expected argument for pytorch "inputs" not provided') + + @staticmethod + def test_pth_extension(tmpdir): + # test for issue: https://github.com/apple/coremltools/issues/917 + import torch + + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.linear = torch.nn.Linear(10, 20) + + def forward(self, x): + return self.linear(x) + + model = TestModule() + model.eval() + example_input = torch.rand(1, 10) + traced_model = torch.jit.trace(model, example_input) + model_path = os.path.join(str(tmpdir), "torch_model.pth") + traced_model.save(model_path) + + ct.convert( + model_path, + source='pytorch', + inputs=[ + ct.TensorType( + shape=example_input.shape, + ) + ], + ) + + diff --git a/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py b/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py index e9ee7bcde..f9c091697 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py +++ b/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py @@ -130,7 +130,7 @@ def _test_elementwise_binary( ssa = self._construct_test_graph( context, op, eb_node, output_name, constants=constants ) - np.testing.assert_allclose(expected_result, ssa.val, atol=1e-7) + np.testing.assert_allclose(expected_result, ssa.val, atol=1e-6) def _test_cast(self, context, test_val, op_kind, op_func, python_type): constants, input_list, output_name = self._gen_constants(1, [test_val]) @@ -1093,13 +1093,12 @@ def test_int(self, context, test_val): @pytest.mark.parametrize("input_shape", [(1, 3, 15, 15), (1, 1, 1, 1)]) def test_layer_norm(self, context, input_shape): graph_inputs = {"input": mb.placeholder(input_shape, dtype=types.float)} - channels = input_shape[1] constants, input_list, output_name = self._gen_constants( 5, [ input_shape, # normalized shape - torch.rand(channels), # weight - torch.rand(channels), # running bias + torch.rand(*input_shape), # weight + torch.rand(*input_shape), # running bias 1e-6, 1, # cudnn enabled ], @@ -1451,41 +1450,6 @@ def test_split(self, context, split_sizes, dim, make_explicit): for ex_res, ssa_res in zip(expected_result, ssa): np.testing.assert_allclose(ex_res.numpy(), ssa_res.val, atol=1e-6) - @pytest.mark.parametrize( 
- "num_args, dtype", itertools.product([4, 5, 6], [0, 1, 2, 3, 4, 5, 6, 7, 11]) - ) - def test_to(self, context, num_args, dtype): - test_input = torch.rand(1, 2, 3) - # These args should be unused - copy = True - non_blocking = True - device = 1337 - - constants_list = [non_blocking, copy] - if num_args == 4: - constants_list = [dtype] + constants_list - elif num_args == 5: - constants_list = [device, dtype] + constants_list - else: - constants_list = [device, dtype, copy] + constants_list - constants_list = [test_input] + constants_list - constants, input_list, output_name = self._gen_constants( - len(constants_list), constants_list - ) - to_node = InternalTorchIRNode( - kind="to", inputs=input_list, outputs=[output_name] - ) - ssa = self._construct_test_graph( - context, ops.to, to_node, output_name, constants=constants, - ) - if num_args == 3: - expected_result = test_input.numpy() - else: - expected_result = test_input.to( - dtype=ops.NUM_TO_TORCH_DTYPE[dtype] - ).numpy() - assert np.allclose(expected_result, ssa.val) - def test_floor(self, context): test_input = torch.rand(1, 2, 3) * 10 constants, input_list, output_name = self._gen_constants(1, test_input) @@ -1778,7 +1742,10 @@ def test_sort(self, context, input_size, dim, descending): @pytest.mark.parametrize( "input_shape, dim, keepdim", - itertools.product([(3, 20, 20), (1, 50, 50)], [[0], [1], [2], [0, 2]], [True, False]), + itertools.product( + [(3, 20, 20), (1, 50, 50)], + [[0], [1], [2], [0, 2]], + [True, False]), ) def test_sum(self, context, input_shape, dim, keepdim): test_input = torch.rand(*input_shape) @@ -1790,7 +1757,7 @@ def test_sum(self, context, input_shape, dim, keepdim): kind="sum", inputs=input_list, outputs=[output_name] ) ssa = self._construct_test_graph( - context, ops._sum, sum_node, output_name, constants=constants + context, ops.mean, sum_node, output_name, constants=constants ) expected_result = torch.sum(test_input, dim, keepdim) assert np.allclose(expected_result, ssa.val) @@ -1803,7 +1770,7 @@ def test_sum_no_dims(self, context): kind="sum", inputs=input_list, outputs=[output_name] ) ssa = self._construct_test_graph( - context, ops._sum, sum_node, output_name, constants=constants + context, ops.mean, sum_node, output_name, constants=constants ) expected_result = torch.sum(test_input) assert np.allclose(expected_result, ssa.val) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index e853a6ef7..6dc1599c8 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -10,6 +10,7 @@ from coremltools.converters.mil import testing_reqs from coremltools.converters.mil.testing_reqs import * from .testing_utils import * +from coremltools import TensorType, ImageType, RangeDim backends = testing_reqs.backends @@ -46,6 +47,22 @@ def test_batchnorm(self, num_features, eps, backend): model = nn.BatchNorm2d(num_features, eps) run_compare_torch((6, num_features, 5, 5), model, backend=backend) + @pytest.mark.parametrize("backend", backends) + def test_batchnorm_1d(self, backend): + class CRNNBase(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size=3, use_bn=True): + super(CRNNBase, self).__init__() + self.conv = nn.Conv1d(ch_in, ch_out, kernel_size=kernel_size) + self.norm = nn.BatchNorm1d(ch_out) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + model = CRNNBase(ch_in=6, ch_out=16) + 
run_compare_torch((1, 6, 15), model, backend=backend) + + class TestInstanceNorm: @pytest.mark.parametrize( "num_features, eps, backend", @@ -107,15 +124,16 @@ class TestConvTranspose: ), ) def test_convolution_transpose1d( - self, - width, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - backend, + self, + width, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + backend, + groups=1, ): model = nn.ConvTranspose1d( in_channels=in_channels, @@ -124,9 +142,11 @@ def test_convolution_transpose1d( stride=stride, padding=padding, dilation=dilation, + groups=groups ) run_compare_torch((1, in_channels, width), model, backend=backend) + @pytest.mark.parametrize( "height, width, in_channels, out_channels, kernel_size, stride, padding, dilation, backend", itertools.product( @@ -156,37 +176,6 @@ def test_convolution_transpose2d( ) run_compare_torch((1, in_channels, height, width), model, backend=backend) - @pytest.mark.parametrize( - "depth, height, width, in_channels, out_channels, kernel_size, stride, padding, dilation, backend", - itertools.product( - [3, 4], [5, 6], [5, 7], [1, 3], [1, 3], [1, 3], [2, 3], [0, 1], [1, 3], backends - ), - ) - @pytest.mark.skip(reason="old macOS version on the CI machine does not have fixes for convolution transposed 3D. " - "Please, see details in https://github.com/apple/coremltools/pull/942") - def test_convolution_transpose3d( - self, - depth, - height, - width, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - backend, - ): - model = nn.ConvTranspose3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ) - run_compare_torch((1, in_channels, depth, height, width), model, backend=backend) - # TODO: rdar://65588783 ([PyTorch] Define and error out on unsupported configuration for output_padding) # TODO: rdar://65550420 (Add Image Resizing (crop, upsample, resize_bilinear) layers to the MIL backend) @pytest.mark.parametrize( @@ -250,6 +239,56 @@ def test_convolution_transpose2d_output_padding( ) run_compare_torch((1, in_channels, height, width), model, backend=backend) + @pytest.mark.parametrize( + "depth, height, width, in_channels, out_channels, kernel_size, stride, padding, dilation, backend", + itertools.product( + [3, 4], [5, 6], [5, 7], [1, 3], [1, 3], [1, 3], [2, 3], [0, 1], [1, 3], backends + ), + ) + @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") + def test_convolution_transpose3d( + self, + depth, + height, + width, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + backend, + ): + model = nn.ConvTranspose3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ) + run_compare_torch((1, in_channels, depth, height, width), model, backend=backend) + + +class TestCond: + @pytest.mark.parametrize("backend", backends) + def test_cond(self, backend): + in_features = 1 + out_features = 2 + class TestNet(nn.Module): + def forward(self, x): + if torch.squeeze(x) < 10.: + return x*10. + else: + return x*2. 
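+
+        # torch.jit.script is used below (rather than torch.jit.trace) because
+        # tracing records only the branch taken for a given example input,
+        # while scripting preserves the data-dependent control flow.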
+ + model = TestNet().eval() + torch_model = torch.jit.script(model) + + run_compare_torch(torch.tensor([1.]), torch_model, + input_as_shape=False, backend=backend) + run_compare_torch(torch.tensor([11.]), torch_model, + input_as_shape=False, backend=backend) class TestLoop: @pytest.mark.parametrize("backend", backends) @@ -615,7 +654,7 @@ def _pytorch_hidden_to_coreml(self, x): # Concat on Hidden Size axis x = torch.cat((f, b), dim=2) # NOTE: - # We are ommiting a squeeze because the conversion + # We are omitting a squeeze because the conversion # function for the mil op lstm unsqueezes the num_layers # dimension return x @@ -702,7 +741,7 @@ def test_lstm_xexception( ) # Workaround for GitHub Issue #824 -# i.e. the return h_n/c_n for a converted BLSTM are mangled. +# i.e. the return h_n/c_n for a converted BLSTM are mangled. # Therefore, just look at output 'y' (for now) which is correct. class StripCellAndHidden(nn.Module): def __init__(self,flagReturnTuple_): @@ -712,7 +751,7 @@ def __init__(self,flagReturnTuple_): def forward(self,x): # Pass tuple, not tensor, to avoid issue in coremltools/converters/mil/frontend/torch/test/testing_utils.py on "if not expected_results:" # Pass tensor when we need input for LSTM #2 as part of nn.Sequential() - return tuple(x[0]) if self.flagReturnTuple else x[0] + return tuple(x[0]) if self.flagReturnTuple else x[0] # Check GitHub Issue #810, assume num_layers == 2 and bidirectional == True class TestStackedBLSTM: @@ -763,7 +802,7 @@ def test_lstm( else: _input = torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, input_size) - # Do not use h_0/c_0 input and do not check h_n/c_n output, GitHub Issue #824 + # Do not use h_0/c_0 input and do not check h_n/c_n output, GitHub Issue #824 expected_results = model(_input) run_compare_torch(_input, model, expected_results, input_as_shape=False, backend=backend) @@ -836,6 +875,42 @@ def test_pixel_shuffle(self, batch_size, CHW, r, backend): run_compare_torch(input_shape, model, backend=backend) +class TestExpand: + @pytest.mark.parametrize( + "backend, shapes", + itertools.product( + backends, + [[(2, 1), (2, 2)], [(3, 1), (-1, 4)], [(1, 3, 4, 4), (3, 3, 4, 4)]] + ), + ) + def test_expand(self, backend, shapes): + input_shape, output_shape = shapes + + class TestModel(torch.nn.Module): + def forward(self, x): + return x.expand(*output_shape) + + model = TestModel() + + run_compare_torch(input_shape, model, backend=backend) + + @pytest.mark.parametrize( + "backend, input_shapes", + itertools.product( + backends, + [[(2, 1), (2, 2)], [(3, 1), (3, 4)], [(1, 3, 4, 4), (3, 3, 4, 4)]] + ), + ) + def test_expand_as(self, backend, input_shapes): + class TestModel(torch.nn.Module): + def forward(self, x, y): + return x.expand_as(y) + + model = TestModel() + + run_compare_torch(input_shapes, model, backend=backend) + + class TestExpandDims: @pytest.mark.parametrize( "backend, rank_and_axis", @@ -1108,7 +1183,6 @@ def test_softsign(self, backend, rank): input_shape, model, backend=backend, ) - class TestElementWiseUnary: @pytest.mark.parametrize( "backend, rank, op_string", @@ -1258,6 +1332,63 @@ def test(self, use_cpu_only, backend, rank, dims): run_compare_torch(input_shape, model, backend=backend) +class TestTo: + @pytest.mark.parametrize( + "backend", backends, + ) + def test_cast_bug(self, backend): + class TestModel(torch.nn.Module): + def forward(self, spans, embedding): + spans = spans.float().relu().int() + + max1, _ = torch.max(spans, dim=1, keepdim=False) + max1, _ = torch.max(max1, dim=1, keepdim=False) + max2, _ = 
torch.max(embedding, dim=1, keepdim=False) + max2, _ = torch.max(max2, dim=1, keepdim=False) + sigmoided_scores = max1 + max2 + return sigmoided_scores + + model = TestModel() + run_compare_torch([(1, 21, 2), (1, 6, 384)], model, backend=backend)# [spans.shape, embedding.shape] + +class TestSlice: + @pytest.mark.skipif(_python_version() < (3, 6), reason="requires python 3.6") + @pytest.mark.parametrize( + "backend", backends, + ) + def test_dynamic_slice(self, backend): + class DynamicSlicer(torch.nn.Module): + def __init__(self): + super(DynamicSlicer, self).__init__() + + def forward(self, x, context_length): + return x[context_length:, :, :] + + class Model(torch.nn.Module): + + def __init__(self): + super(Model, self).__init__() + self.tokens_embedding = torch.nn.Embedding(10, 10, 0) + self.context_embedding = torch.nn.Embedding(10, 10, 0) + self.dynamic_slicer = DynamicSlicer() + + def forward(self, tokens, context, context_length): + tokens_embeddings = self.tokens_embedding(tokens) + context_embeddings = self.context_embedding(context) + embeddings = torch.cat((context_embeddings, tokens_embeddings), dim=0) + embeddings = self.dynamic_slicer(embeddings, context_length) + + return embeddings + + model = Model() + batch_size = 5 + inputs = [ TensorType(name="tokens", shape=(10, batch_size), dtype=np.int64), + TensorType(name="context", shape=(3, batch_size), dtype=np.int64), + TensorType(name="context_length", shape=(), dtype=np.int32), + ] + run_compare_torch(inputs, model, rand_range=(0, 8), backend=backend, use_scripting=False) + + class TestRepeat: @pytest.mark.parametrize( "use_cpu_only, backend, rank", diff --git a/coremltools/converters/mil/frontend/torch/test/testing_utils.py b/coremltools/converters/mil/frontend/torch/test/testing_utils.py index 003d3160f..ef339ad65 100644 --- a/coremltools/converters/mil/frontend/torch/test/testing_utils.py +++ b/coremltools/converters/mil/frontend/torch/test/testing_utils.py @@ -7,10 +7,12 @@ import torch import torch.nn as nn from six import string_types as _string_types -from coremltools import TensorType +from coremltools import TensorType, RangeDim +from ..converter import torch_to_mil_types from coremltools.converters.mil.testing_reqs import _converter from coremltools.models import MLModel from coremltools._deps import _IS_MACOS +from coremltools.converters.mil.mil.types.type_mapping import nptype_from_builtin class ModuleWrapper(nn.Module): """ @@ -49,8 +51,13 @@ def convert_to_coreml_inputs(input_description, inputs): """ flattened_inputs = _flatten(inputs) coreml_inputs = { - str(x): inp.numpy() for x, inp in zip(input_description, flattened_inputs) + str(x): inp.numpy().astype(np.float32) for x, inp in zip(input_description, flattened_inputs) } + + for k, v in coreml_inputs.items(): + if isinstance(v, np.ndarray) and v.ndim == 0: + coreml_inputs[k] = np.expand_dims(v, axis=-1) + return coreml_inputs @@ -60,8 +67,10 @@ def _convert_to_inputtype(inputs): return [_convert_to_inputtype(x) for x in inputs] elif isinstance(inputs, tuple): return tuple([_convert_to_inputtype(x) for x in inputs]) + elif isinstance(inputs, TensorType): + return inputs elif isinstance(inputs, torch.Tensor): - return TensorType(shape=inputs.shape) + return TensorType(shape=inputs.shape, dtype=torch_to_mil_types[inputs.dtype]) else: raise ValueError( "Unable to parse type {} into InputType.".format(type(inputs)) @@ -72,12 +81,28 @@ def _convert_to_inputtype(inputs): return MLModel(proto, useCPUOnly=True) -def generate_input_data(input_size, rand_range = (0.0, 
1.0)): +def generate_input_data(input_size, rand_range=(0, 1)): r1, r2 = rand_range + + def random_data(spec): + if isinstance(spec, TensorType): + spec_shape = spec.shape.shape + dtype = nptype_from_builtin(spec.dtype) + else: + spec_shape = spec + dtype = np.float32 + + static_shape = tuple([np.random.randint(dim.lower_bound, dim.upper_bound if dim.upper_bound > 0 else 10) + if isinstance(dim, RangeDim) else dim for dim in spec_shape]) + + data = np.random.rand(*static_shape) if static_shape != () else np.random.rand() + data = (r1 - r2) * data + r2 + return torch.from_numpy(np.array(data).astype(dtype)) + if isinstance(input_size, list): - return [(r1 - r2) * torch.rand(_size) + r2 for _size in input_size] + return [random_data(size) for size in input_size] else: - return (r1 - r2) * torch.rand(input_size) + r2 + return random_data(input_size) def trace_model(model, input_data): @@ -90,7 +115,7 @@ def trace_model(model, input_data): def run_compare_torch( input_data, model, expected_results=None, places=5, input_as_shape=True, backend="nn_proto", - rand_range = (0.0, 1.0) + rand_range=(0.0, 1.0), use_scripting=False, ): """ Traces a model and runs a numerical test. @@ -101,9 +126,9 @@ def run_compare_torch( model.eval() if input_as_shape: input_data = generate_input_data(input_data, rand_range) - model_spec = trace_model(model, input_data) + model_spec = torch.jit.script(model) if use_scripting else trace_model(model, input_data) convert_and_compare( - input_data, model_spec, expected_results=expected_results, atol=10.0 ** -places, backend=backend + input_data, model_spec, expected_results=expected_results, atol=10.0 ** -places, backend=backend, ) @@ -116,8 +141,12 @@ def flatten_and_detach_torch_results(torch_results): def convert_and_compare(input_data, model_spec, expected_results=None, atol=1e-5, backend="nn_proto"): """ - If expected results is not set, it will by default - be set to the flattened output of the torch model. + If expected results is not set, it will by default + be set to the flattened output of the torch model. + + Inputs: + + - input_data: torch.tensor or list[torch.tensor] """ if isinstance(model_spec, _string_types): torch_model = torch.jit.load(model_spec) diff --git a/coremltools/converters/mil/frontend/torch/torch_op_registry.py b/coremltools/converters/mil/frontend/torch/torch_op_registry.py index 461616e18..8ecc04d57 100644 --- a/coremltools/converters/mil/frontend/torch/torch_op_registry.py +++ b/coremltools/converters/mil/frontend/torch/torch_op_registry.py @@ -28,12 +28,12 @@ def register_torch_op(_func=None, torch_alias=None, override=False): def func_wrapper(func): f_name = func.__name__ if not override and f_name in _TORCH_OPS_REGISTRY: - raise ValueError("Torch Op {} already registered.".format(f_name)) + raise ValueError("Torch op {} already registered.".format(f_name)) _TORCH_OPS_REGISTRY[f_name] = func if torch_alias is not None: for name in torch_alias: if not override and name in _TORCH_OPS_REGISTRY: - msg = "Torch Op alias {} already registered." + msg = "Torch op alias {} already registered." 
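+                    # (Pass override=True to register_torch_op to deliberately
+                    # replace an existing handler instead of raising.)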
                    raise ValueError(msg.format(name))
             _TORCH_OPS_REGISTRY[name] = func
         return func
diff --git a/coremltools/converters/mil/mil/block.py b/coremltools/converters/mil/mil/block.py
index 3c8271296..c1b983f1b 100644
--- a/coremltools/converters/mil/mil/block.py
+++ b/coremltools/converters/mil/mil/block.py
@@ -510,6 +510,12 @@ def _replace_var(
         if end_id != -1 and old_var.op not in op_list:
             return num_ops_affected
 
+        if old_var in self._block_inputs:
+            idx = self._block_inputs.index(old_var)
+            self._block_inputs = list(self._block_inputs)
+            self._block_inputs[idx] = new_var
+            self._block_inputs = tuple(self._block_inputs)
+
         # If old_var is block's output, replace as well.
         if old_var in self._outputs:
             idx = self._outputs.index(old_var)
diff --git a/coremltools/converters/mil/mil/builder.py b/coremltools/converters/mil/mil/builder.py
index dce1bd850..ca86c4b6d 100644
--- a/coremltools/converters/mil/mil/builder.py
+++ b/coremltools/converters/mil/mil/builder.py
@@ -11,8 +11,9 @@
 import numpy as np
 
 from coremltools.converters.mil.mil.types.symbolic import any_symbolic
-
-from . import curr_block, Program, Function, Placeholder, is_internal_input
+from .program import Program, Placeholder
+from .block import curr_block, Function
+from .operation import is_internal_input
 from .input_type import (
     _InputType,
     InternalStringInputType,
diff --git a/coremltools/converters/mil/mil/operation.py b/coremltools/converters/mil/mil/operation.py
index 8c93718d4..b726e2285 100644
--- a/coremltools/converters/mil/mil/operation.py
+++ b/coremltools/converters/mil/mil/operation.py
@@ -117,6 +117,17 @@ def is_internal_input(arg_name):
     return arg_name[0] == "_"
 
 
+class mil_list(object):
+    '''
+    A wrapper around a Python list.
+    '''
+
+    def __init__(self, ls=None):
+        self.ls = ls if ls is not None else []
+        if not isinstance(self.ls, list):
+            raise TypeError("'ls' must be a list in the 'mil_list' class")
+
+
 class Operation(object):
     """
     Represents Operation in MIL.
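
A minimal usage sketch of mil_list (illustrative; the unwrapping logic lands in
the operation.py and control_flow.py hunks below):

    from coremltools.converters.mil.mil.operation import mil_list

    wrapped = mil_list([2, 3, 5])   # wraps a python list; anything else raises TypeError
    # A const op whose value is a mil_list infers a list type with
    # init_length=3 and dynamic_length=False, and _auto_val unwraps
    # wrapped.ls as the output value.
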
@@ -203,6 +214,7 @@ def type_value_inference(self, overwrite_output=False): elem_type=sym_type.T[0], init_length=sym_type.T[1], dynamic_length=sym_type.T[2], + sym_val=sym_val if (sym_val is not None and isinstance(sym_val.val, list)) else None, op=self, op_output_idx=i, ) @@ -284,7 +296,10 @@ def _auto_val(self, output_types): auto_val = [] for t, v in zip(output_types, vals): builtin_val = t() - builtin_val.val = v + if isinstance(v, mil_list): + builtin_val.val = v.ls + else: + builtin_val.val = v auto_val.append(builtin_val) return auto_val diff --git a/coremltools/converters/mil/mil/ops/defs/_utils.py b/coremltools/converters/mil/mil/ops/defs/_utils.py index 7856b0582..5d4d558bf 100644 --- a/coremltools/converters/mil/mil/ops/defs/_utils.py +++ b/coremltools/converters/mil/mil/ops/defs/_utils.py @@ -6,8 +6,8 @@ import math import coremltools.converters -import sympy as sm +from coremltools.converters.mil.mil import get_new_symbol from coremltools.converters.mil.mil.types.symbolic import is_symbolic from ._op_reqs import * @@ -54,7 +54,7 @@ def broadcast_shapes(shape_x, shape_y): ) ret_shapes.append(shape_x[i]) elif x_unknown or y_unknown: - ret_shapes.append(sm.functions.Max(shape_x[i], shape_y[i])) + ret_shapes.append(get_new_symbol()) else: assert shape_x[i] == shape_y[i] ret_shapes.append(shape_x[i]) @@ -158,6 +158,7 @@ def aggregated_pad( effective_ks = effective_kernel(kernel_shape, dilations) return [ int(max(0, s * math.ceil(float(i) / float(s)) - i + k - s)) + if not is_symbolic(i) else get_new_symbol() for i, k, s in zip(input_shape, effective_ks, strides) ] if pad_type == "valid": @@ -222,6 +223,7 @@ def spatial_dimensions_out_shape( len(custom_pad), ) ) + pad = aggregated_pad( pad_type=pad_type, kernel_shape=kernel_shape, @@ -231,7 +233,8 @@ def spatial_dimensions_out_shape( custom_pad=custom_pad, ) effective_ks = effective_kernel(kernel_shape, dilations) - return [ + out_shape = [ (input_shape[r] + pad[r] - effective_ks[r]) // strides[r] + 1 for r in range(num_spatial_dims) ] + return [dim if not is_symbolic(dim) else get_new_symbol() for dim in out_shape] diff --git a/coremltools/converters/mil/mil/ops/defs/control_flow.py b/coremltools/converters/mil/mil/ops/defs/control_flow.py index ecabbd1ba..7cf93cbc1 100644 --- a/coremltools/converters/mil/mil/ops/defs/control_flow.py +++ b/coremltools/converters/mil/mil/ops/defs/control_flow.py @@ -15,7 +15,7 @@ from coremltools.converters.mil.mil import get_new_symbol from ._op_reqs import * import logging - +from coremltools.converters.mil.mil import mil_list @register_op(doc_str="") class cond(Operation): @@ -29,13 +29,13 @@ class cond(Operation): * 0-D tensor (scalar) predicate to switch between true and false branches. _true_fn: function (Required) - * A Python function that executes if ``cond`` evaluates to ``True``. - * It should take no input, and return one or more values whose type becomes + * A Python function that executes if ``pred`` evaluates to ``True``. + * It must take zero input (i.e, no input), and return one or more values whose type becomes the operation's return type. _false_fn: function (Required) - * A Python function that executes if ``cond`` evaluates to ``False``. - * It should take no input, and have return types that match those of the + * A Python function that executes if ``pred`` evaluates to ``False``. + * It must take zero input (i.e. no input), and have return types that match those of the ``if`` branch. Returns @@ -156,7 +156,19 @@ def _get_type_val(self, value): # We use float32 by default. 
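+            # (float64 numpy constants are narrowed to float32 here, matching
+            # the fp32 tensors that MIL ops declare)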
value = value.astype(np.float32) - if not isinstance(value, (np.generic, np.ndarray, six.string_types, bool)): + elif isinstance(value, mil_list): + # if val that was passed in is of type mil_list, which is just a wrapper on top of python list + # then construct the list type + list_value = value.ls + if len(list_value) == 0: + raise ValueError("'mil_list' points to an empty list") + builtin_elem_type, _ = self._get_type_val(list_value[0]) + from coremltools.converters.mil.mil.types.type_list import list as types_list + builtin_type = types_list(builtin_elem_type, init_length=len(list_value), dynamic_length=False) + return builtin_type, value + + + if not isinstance(value, (np.generic, np.ndarray, six.string_types, bool, mil_list)): raise ValueError("Unknown value for constant: {}".format(value)) _, builtin_type = numpy_val_to_builtin_val(value) @@ -209,8 +221,7 @@ class select(Operation): from ``cond, a, b``. * If ``a, b`` are ``None``, the return shape is 2-D, where the first dimension ``n`` is the number of matching indices in ``cond``, and ``len(D1)`` is the - ``cond`` rank. - + rank of ``cond``. Attributes ---------- T: fp32 @@ -247,7 +258,7 @@ def value_inference(self): class while_loop(Operation): """ Perform the body repeatedly while the condition ``cond`` is true. - + Parameters ---------- _cond: function (Required) @@ -296,32 +307,39 @@ def _check_equal_value(val1, val2): return val1 == val2 @staticmethod - def clean_up_child_ops(block): + def _clean_up_child_ops(block): for op in list(block.operations): for b in op.blocks: - while_loop.clean_up_child_ops(b) + while_loop._clean_up_child_ops(b) inputs = op.get_flattened_inputs() for in_var in inputs: in_var.remove_child_op(op) - def build_block(self, block_inputs): - block_name = self.name + '_block' + def _build_block(self, block_inputs): + # Cond block: + block_name = self.name + '_cond_block' with Block(block_inputs=block_inputs, outer_op=self, - name=block_name) as block: - # Body func - body_func = self._body.val - exit_vars = body_func(*block.inputs) + name=block_name) as cond_block: - # Cond func: cond_func = self._cond.val - cond_var = cond_func(*block.inputs) + cond_var = cond_func(*cond_block.inputs) cond_vars = cond_var if isinstance(cond_var, list) else [cond_var] + cond_block.set_outputs(cond_vars) - # Concatenate the outputs - block.set_outputs(cond_vars + list(exit_vars)) - return block, exit_vars + # Body block + block_name = self.name + '_body_block' + with Block(block_inputs=block_inputs, outer_op=self, + name=block_name) as body_block: + body_func = self._body.val + exit_vars = body_func(*body_block.inputs) + exit_vars = list(exit_vars) if isinstance(exit_vars, (list, tuple)) \ + else [exit_vars] + body_block.set_outputs(exit_vars) + #self.blocks.append(body_block) + + return cond_block, body_block, exit_vars def build_nested_blocks(self): # self.loop_vars is python tuple of Vars. @@ -334,6 +352,29 @@ def build_nested_blocks(self): # We assume that sym_val is unchanging across the block iterate. If it # changes, we rebuild the block and rerun type and value inference. + # Design notes on two blocks (cond and body): + # + # - Observe that two blocks can always be represented as a single + # block that contains both cond and body logic, which would return + # [loop_cond] + loop_carries. `loop_cond` is a bool. + # + # - Observe that single block implies a do-while logic, + # in which the first iterate is always executed. 
It's possible to add
+    #   a cond input to while_loop to modify do-while behavior:
+    #
+    #     %first_cond = cond_logic(...)
+    #     while_loop(cond=%first_cond, loop_vars=(...))
+    #
+    #   and we enter the first iterate only if cond is True. But this would
+    #   require the caller to execute the cond logic outside of while_loop first
+    #   (which also needs to be duplicated within the loop),
+    #   resulting in duplicated code / ops.
+    #
+    # - Thus, a single block is unnatural for the natural execution order,
+    #   in which we execute the cond block first to get the loop_cond. Only
+    #   if `loop_cond` is True do we execute the body block. This is the
+    #   semantics of tf.while_loop.
+
         block_inputs = tuple(copy.copy(v) for v in self.loop_vars)
         for v in block_inputs:
             v._op = None
@@ -343,7 +384,7 @@ def build_nested_blocks(self):
             v._sym_val = v._sym_val
             v.consuming_blocks = list()

-        block, exit_vars = self.build_block(block_inputs)
+        cond_block, body_block, exit_vars = self._build_block(block_inputs)

         # Verify exit_vars has the same types as loop_vars
         block_input_type_change = False
@@ -368,10 +409,11 @@ def build_nested_blocks(self):
         if block_input_type_change:
             # Since we are going to build the block again, we first need to remove ops
             # in the block from vars's _child_ops.
-            while_loop.clean_up_child_ops(block)
+            while_loop._clean_up_child_ops(cond_block)
+            while_loop._clean_up_child_ops(body_block)

             # Rebuild our block to invoke type inference.
-            block, exit_vars = self.build_block(block_inputs)
+            cond_block, body_block, exit_vars = self._build_block(block_inputs)
             for i, (v_in, v_out) in enumerate(zip(block_inputs, exit_vars)):
                 if not is_subtype(v_out.sym_type, v_in.sym_type):
                     msg = 'Block output {}: {} is not a subtype of ' +\
@@ -383,7 +425,8 @@
                         'block input {}: {} after value changes'
                     raise ValueError(msg.format(v_out.name, v_out.sym_val,
                         v_in.name, v_in.sym_val))
-        self.blocks.append(block)
+        self.blocks.append(cond_block)
+        self.blocks.append(body_block)

     @staticmethod
     def get_compat_shape(type1, type2):
@@ -415,7 +458,7 @@ def get_compat_shape(type1, type2):

     def type_inference(self):
         # Skip the conditional var (it lives in the cond block, not the body block)
-        return tuple(v.sym_type for v in self.blocks[0].outputs[1:])
+        return tuple(v.sym_type for v in self.blocks[1].outputs)


 @register_op(doc_str="")
@@ -452,8 +495,8 @@ class make_list(Operation):

     input_spec = InputSpec(
         init_length=IntInputType(optional=True, default=1),
-        dynamic_length=BoolInputType(optional=True, default=True),
-        elem_shape=TensorInputType(const=True),
+        dynamic_length=BoolInputType(const=True, optional=True, default=True),
+        elem_shape=IntTensorInputType(),
         dtype=StringInputType(const=True, optional=True, default="fp32"),
     )

diff --git a/coremltools/converters/mil/mil/ops/defs/conv.py b/coremltools/converters/mil/mil/ops/defs/conv.py
index 4109099be..e4f3ed4ff 100644
--- a/coremltools/converters/mil/mil/ops/defs/conv.py
+++ b/coremltools/converters/mil/mil/ops/defs/conv.py
@@ -140,6 +140,49 @@ def type_inference(self):
         return types.tensor(self.x.dtype, tuple(retshape))


+@register_op(doc_str="")
+class conv_quantized(conv):
+    """
+    Note: This op is experimental and may change in the future.
+    Performs convolution over the input while the weights are kept quantized:
+    ``W_float = W_quantized * scale + bias``
+
+    Parameters
+    ----------
+    In addition to the convolutional layer parameters, the following additional
+    parameters are required.
+
+    quantization_type: const str (Required)
+        * One of ``linear`` or ``lut``.
+
+    nbits: const tensor<[], i32> (Optional.
Default to 8) + * Denotes the bit-width of the quantization. ``1 <= nbits <= 8`` + + quant_scale: tensor<*?, T> (Required) + * Denotes the scale of quantization. + + quant_bias: tensor<*?, T> (Required) + * Denotes the bias that is used to quantize/dequantize. + + Returns + ------- + tensor<[n, C_out, *d_out], T> + * Output activation has the same rank and spatial dimension as the input (i.e., ``len(d_out) == len(d_in)``) + + Attributes + ---------- + T: fp32 + """ + + input_spec = InputSpec( + quantization_type=StringInputType(const=True, optional=False, default=None), + nbits=IntInputType(const=True, optional=False, default=8), + quant_scale=ScalarOrTensorInputType(const=True, optional=False, default=None), + quant_bias=ScalarOrTensorInputType(const=True, optional=False, default=None)) + conv.input_spec + + def __init__(self, **kwargs): + super(conv_quantized, self).__init__(**kwargs) + + @register_op(doc_str="") class conv_transpose(Operation): """ diff --git a/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py b/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py index 416f172a9..55d99bd50 100644 --- a/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py +++ b/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py @@ -70,7 +70,7 @@ def _cast_check_value_inferene(self, a, b): """ -Elementwise Binary Op Implmentation(s) +Elementwise Binary Op Implementation(s) """ diff --git a/coremltools/converters/mil/mil/ops/defs/elementwise_unary.py b/coremltools/converters/mil/mil/ops/defs/elementwise_unary.py index ca6660adb..743f3c936 100644 --- a/coremltools/converters/mil/mil/ops/defs/elementwise_unary.py +++ b/coremltools/converters/mil/mil/ops/defs/elementwise_unary.py @@ -22,7 +22,7 @@ def type_inference(self): """ -Elementwise unary op implmentation(s) +Elementwise unary op implementation(s) """ @@ -390,13 +390,17 @@ def value_inference(self): @register_op(doc_str="") -class inverse(elementwise_unary): +class inverse(Operation): """ Returns the reciprocal value of the input ``x``, element-wise. Parameters ---------- x: tensor<[*d], T> (Required) + epsilon: const f32 (Optional, default=1e-4) + * this is a small constant that is added to the input, before taking its inverse, + for stability. + * y = 1 / (x + epsilon) Returns ------- @@ -408,22 +412,33 @@ class inverse(elementwise_unary): T: fp32 """ + input_spec = InputSpec( + x=ScalarOrTensorInputType(), + epsilon=FloatInputType(const=True, default=1e-4), + ) + def __init__(self, **kwargs): super(inverse, self).__init__(**kwargs) + def type_inference(self): + return self.x.sym_type + @precondition(allow=VALUE) def value_inference(self): - return np.reciprocal(self.x.val) + return np.reciprocal(self.x.val + self.epsilon.val) @register_op(doc_str="") -class log(elementwise_unary): +class log(Operation): """ Returns the natural logarithm value of the input ``x``, element-wise. Parameters ---------- x: tensor<[*d], T> (Required) + epsilon: const f32 (Optional, default=1e-45) + * this is a small constant that is added to the input, before taking log. 
+ * y = log(x + epsilon) Returns ------- @@ -435,12 +450,20 @@ class log(elementwise_unary): T: fp32 """ + input_spec = InputSpec( + x=ScalarOrTensorInputType(), + epsilon=FloatInputType(const=True, default=1e-45), + ) + def __init__(self, **kwargs): super(log, self).__init__(**kwargs) + def type_inference(self): + return self.x.sym_type + @precondition(allow=VALUE) def value_inference(self): - return np.log(self.x.val) + return np.log(self.x.val + self.epsilon.val) @register_op(doc_str="") @@ -498,13 +521,17 @@ def value_inference(self): @register_op(doc_str="") -class rsqrt(elementwise_unary): +class rsqrt(Operation): """ Returns the reciprocal value of the square root of the input ``x``, element-wise. Parameters ---------- x: tensor<[*d], T> (Required) + epsilon: const f32 (Optional, default=1e-12) + * this is a small constant that is added to the input, before applying the rsqrt function, + for stability. + * y = 1 / sqrt(x + epsilon) Returns ------- @@ -516,12 +543,20 @@ class rsqrt(elementwise_unary): T: fp32 """ + input_spec = InputSpec( + x=ScalarOrTensorInputType(), + epsilon=FloatInputType(const=True, default=1e-12), + ) + def __init__(self, **kwargs): super(rsqrt, self).__init__(**kwargs) + def type_inference(self): + return self.x.sym_type + @precondition(allow=VALUE) def value_inference(self): - return 1.0 / np.sqrt(self.x.val) + return 1.0 / np.sqrt(self.x.val + self.epsilon.val) @register_op(doc_str="") diff --git a/coremltools/converters/mil/mil/ops/defs/image_resizing.py b/coremltools/converters/mil/mil/ops/defs/image_resizing.py index 02cd989ce..f2831d9d0 100644 --- a/coremltools/converters/mil/mil/ops/defs/image_resizing.py +++ b/coremltools/converters/mil/mil/ops/defs/image_resizing.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) 2020, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be @@ -20,19 +21,18 @@ class upsample_nearest_neighbor(Operation): * Scale factor for the height dimension (``axis=-2``). upscale_factor_width: const (Optional, default=1) * Scale factor for the width dimension (``axis=-1``). - + Returns ------- tensor<[*D, H2, W2],T> * Tensor with same type as the input. * ``H2`` = ``H1`` * ``upscale_factor_height``. * ``W2`` = ``W1`` * ``upscale_factor_width``. - + Attributes ---------- T: fp32 """ - input_spec = InputSpec( x=TensorInputType(), upscale_factor_height=IntInputType(const=True, default=1), @@ -122,13 +122,13 @@ class upsample_bilinear(Operation): * Tensor with same type as the input. * ``H2`` = floor(``H1`` * ``scale_factor_height``). * ``W2`` = floor(``W1`` * ``scale_factor_width``). - + Attributes ---------- T: fp32 T2 : fp32 or int32 """ - + input_spec = InputSpec( x=TensorInputType(), scale_factor_height=IntOrFloatInputType(const=True, default=1), @@ -229,6 +229,7 @@ class resize_bilinear(Operation): ``tf.raw_ops.ResizeBilinear(align_corners=True, half_pixel_centers=False)``. 
+ Returns ------- tensor<[*D, H2, W2],T> @@ -240,7 +241,7 @@ class resize_bilinear(Operation): ---------- T: fp32 """ - + input_spec = InputSpec( x=TensorInputType(), target_size_height=IntInputType(const=True, default=1), @@ -360,7 +361,7 @@ class crop_resize(Operation): ---------- T: fp32 """ - + input_spec = InputSpec( x=TensorInputType(), roi=TensorInputType(), @@ -437,7 +438,7 @@ class crop(Operation): ---------- T: fp32 """ - + input_spec = InputSpec( x=TensorInputType(), crop_height=IntTensorInputType(const=True), diff --git a/coremltools/converters/mil/mil/ops/defs/normalization.py b/coremltools/converters/mil/mil/ops/defs/normalization.py index 94d1a434f..b77bc8a79 100644 --- a/coremltools/converters/mil/mil/ops/defs/normalization.py +++ b/coremltools/converters/mil/mil/ops/defs/normalization.py @@ -4,6 +4,9 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause from ._op_reqs import * +from coremltools.converters.mil.mil.types.symbolic import ( + any_symbolic, +) @register_op(doc_str="") class batch_norm(Operation): @@ -39,7 +42,7 @@ class batch_norm(Operation): ------- tensor<[n,C,*D], T> * Output tensor has the same shape and type as the input ``x``. - + Attributes ---------- T: fp32 @@ -58,7 +61,8 @@ def __init__(self, **kwargs): super(batch_norm, self).__init__(**kwargs) def type_inference(self): - return self.x.sym_type + x_shape = self.x.shape + return types.tensor(types.fp32, tuple(x_shape)) @register_op(doc_str="") @@ -98,13 +102,15 @@ def __init__(self, **kwargs): super(instance_norm, self).__init__(**kwargs) def type_inference(self): - return self.x.sym_type + x_shape = self.x.shape + return types.tensor(types.fp32, tuple(x_shape)) @register_op(doc_str="") class l2_norm(Operation): """ - Apply L2 normalization to the n-dimensional input tensor on given ``axes``: + Apply L2 normalization to the n-dimensional input tensor. That is, divide the input + tensor by the square root of the sum of squares of all elements of the input. .. math:: x_i \\leftarrow \\dfrac{x_i}{\\sqrt{\\sum{x_i^2} + \\epsilon}} @@ -112,19 +118,19 @@ class l2_norm(Operation): Parameters ---------- - x: tensor<[n,C,*D], T> (Required) - * Input tensor, ``3 <= rank(x) <= 4``. - * ``*D`` refers to the spatial dimensions, ``1 <= rank(*D) <= 2``. + x: tensor<[*D,C,H,W], T> (Required) + * Input tensor, ``rank(x) >= 3``. + * ``*D`` refers to the spatial dimensions, ``rank(*D) >= 0``. * ``n`` is the batch dimension. - axes: const tensor<[K], i32> (Required) - * Dimensions to perform normalizations. + * For ranks greater than 3, the leading dimensions, starting from ``0`` to ``-4`` (inclusive), + are all treated as batch. epsilon: const fp32 (Optional) * Small constant to avoid division by ``0``. - * Optional, defaults to ``1e-12``. - + * Optional, defaults to ``1e-6``. + Returns ------- - tensor<[n,C,*D], T> + tensor<[*D,C,H,W], T> * Same type and shape as the input tensor ``x``. Attributes @@ -134,15 +140,15 @@ class l2_norm(Operation): input_spec = InputSpec( x=TensorInputType(), - axes=IntTensorInputType(), - epsilon=FloatInputType(const=True, default=1e-12), + epsilon=FloatInputType(const=True, default=1e-6), ) def __init__(self, **kwargs): super(l2_norm, self).__init__(**kwargs) def type_inference(self): - return self.x.sym_type + x_shape = self.x.shape + return types.tensor(types.fp32, tuple(x_shape)) @register_op(doc_str="") @@ -151,7 +157,7 @@ class layer_norm(Operation): Apply layer normalization to the n-dimensional input tensor: .. 
math::
-       out = gamma * (input - mean) / sqrt(variance + epsilon) + beta
+       out = gamma * (input - E[x]) / sqrt(Var[x] + epsilon) + beta


     Parameters
@@ -162,10 +168,12 @@ class layer_norm(Operation):
         * Dimensions to perform layer normalization.
         * Default is ``None`` (all dimensions).
     gamma: const tensor<[K], T> (Optional)
-        * Same shape as normalized_shape.
+        * If provided, the shape must be ``x.shape[axes]``.
+        * For instance, with input ``x`` of shape ``(3,4,5,6)`` and ``axes = [2,3]``,
+          gamma must have shape ``(5,6)``.
         * Default is all ones.
     beta: const tensor<[K], T> (Optional)
-        * Same shape as normalized_shape.
+        * Same shape as gamma.
         * Default is all zeros.
     epsilon: const fp32 (Optional)
         * Small constant to avoid division by ``0``.
@@ -174,7 +182,11 @@ class layer_norm(Operation):
     Returns
     -------
     tensor<*?, T>:
-        * Tensor with same shape and type as the input tensor ``x``.
+        * Tensor with same shape and type as the input tensor ``x``.
+
+    Attributes
+    ----------
+    T: fp32
     """

     input_spec = InputSpec(
@@ -188,15 +200,45 @@ class layer_norm(Operation):
     def __init__(self, **kwargs):
         super(layer_norm, self).__init__(**kwargs)

+    @staticmethod
+    def _is_compatible_shape(shapea, shapeb):
+        if not len(shapea) == len(shapeb):
+            return False
+        for a, b in zip(shapea, shapeb):
+            if any_symbolic([a, b]):
+                continue
+            if a != b:
+                return False
+        return True
+
     def type_inference(self):
-        return self.x.sym_type
+        rank = self.x.rank
+
+        # check that the axes are valid
+        positive_axes = [axis + rank if axis < 0 else axis for axis in self.axes.val]
+        if not all([axis >= 0 and axis < rank for axis in positive_axes]):
+            raise ValueError("axes must be in the range [-x.rank, x.rank-1].")
+
+        # check the shapes of gamma and beta
+        normalized_shape = [self.x.shape[i] for i in range(rank) if i in positive_axes]
+        if self.gamma is not None and not layer_norm._is_compatible_shape(list(self.gamma.shape), normalized_shape):
+            raise ValueError("Expected shape {} for gamma, but got shape {} instead".format(normalized_shape, self.gamma.shape))
+
+        if self.beta is not None and not layer_norm._is_compatible_shape(list(self.beta.shape), normalized_shape):
+            raise ValueError("Expected shape {} for beta, but got shape {} instead".format(normalized_shape, self.beta.shape))
+
+        x_shape = self.x.shape
+        return types.tensor(types.fp32, tuple(x_shape))
+
     @precondition(allow=VALUE)
     def value_inference(self):
         def np_layer_norm(x, axes, gamma, beta, epsilon=1e-5):
-            normalized_shape = x.shape[-len(axes) :]
-            gamma = np.ones(shape=normalized_shape) if gamma is None else gamma
-            beta = np.zeros(shape=normalized_shape) if beta is None else beta
+            rank = len(x.shape)
+            axes = [axis + rank if axis < 0 else axis for axis in axes]
+            normalized_shape = [x.shape[i] if i in axes else 1 for i in range(rank)]
+            gamma = np.ones(shape=normalized_shape) if gamma is None else np.reshape(gamma, normalized_shape)
+            beta = np.zeros(shape=normalized_shape) if beta is None else np.reshape(beta, normalized_shape)
             num = x - np.mean(x, axis=tuple(axes), keepdims=True)
             dem = np.sqrt(
                 np.sum(np.square(num), axis=tuple(axes), keepdims=True)
@@ -230,10 +272,10 @@ class local_response_norm(Operation):
         * Amount of neighboring channels to normalize.
     alpha: const fp32 (Optional)
         * Scale factor.
-        * Default is ``1.0``.
+        * Default is ``1e-4``.
     beta: const fp32 (Optional)
         * An exponent.
-        * Default is ``0.5``.
+        * Default is ``0.75``.
     k: const fp32 (Optional)
         * Additive factor.
         * Default is ``1.0``.
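
The revised layer_norm contract above can be checked with a minimal NumPy sketch (the array names, sizes, and epsilon below are illustrative only, not part of the patch). With x of shape (3,4,5,6) and axes=[2,3], gamma and beta must have shape (5,6), and the op computes out = gamma * (x - E[x]) / sqrt(Var[x] + epsilon) + beta:

    import numpy as np

    x = np.random.rand(3, 4, 5, 6).astype(np.float32)
    axes = (2, 3)
    gamma = np.ones((5, 6), dtype=np.float32)  # shape equals x.shape[axes], per the docstring
    beta = np.zeros((5, 6), dtype=np.float32)  # same shape as gamma
    eps = 1e-4

    mean = x.mean(axis=axes, keepdims=True)    # E[x] over the normalized axes
    var = x.var(axis=axes, keepdims=True)      # Var[x] over the normalized axes
    out = gamma * (x - mean) / np.sqrt(var + eps) + beta
    assert out.shape == x.shape                # same shape as the input
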
@@ -260,4 +302,5 @@ def __init__(self, **kwargs):
         super(local_response_norm, self).__init__(**kwargs)

     def type_inference(self):
-        return self.x.sym_type
+        x_shape = self.x.shape
+        return types.tensor(types.fp32, tuple(x_shape))
diff --git a/coremltools/converters/mil/mil/ops/defs/reduction.py b/coremltools/converters/mil/mil/ops/defs/reduction.py
index 8bbbdbcef..8ab50f089 100644
--- a/coremltools/converters/mil/mil/ops/defs/reduction.py
+++ b/coremltools/converters/mil/mil/ops/defs/reduction.py
@@ -62,8 +62,7 @@ class ReductionAxis(Operation):
     def __init__(self, **kwargs):
         super(ReductionAxis, self).__init__(**kwargs)

-    def type_inference(self):
-        x_type = self.x.dtype
+    def _find_reduced_shape(self):
         x_shape = self.x.shape
         axis = self.axis.val
@@ -73,12 +72,20 @@ def type_inference(self):
             reduced_shape[axis] = 1
         else:
             reduced_shape.pop(axis)
+        return reduced_shape

+    def type_inference(self):
+        x_type = self.x.dtype
+        reduced_shape = self._find_reduced_shape()
         return types.tensor(x_type, tuple(reduced_shape))

     @precondition(allow=VALUE)
     def value_inference(self):
-        return self.get_operator()(self.x.val, axis=self.axis.val)
+        tmp = self.get_operator()(self.x.val, axis=self.axis.val)
+        reduced_shape = self._find_reduced_shape()
+        if self.keep_dims.val:
+            tmp = np.reshape(tmp, reduced_shape)
+        return tmp

     def get_operator(self):
         raise NotImplementedError()
diff --git a/coremltools/converters/mil/mil/ops/defs/tensor_operation.py b/coremltools/converters/mil/mil/ops/defs/tensor_operation.py
index 2140f498c..a4a9a34cb 100644
--- a/coremltools/converters/mil/mil/ops/defs/tensor_operation.py
+++ b/coremltools/converters/mil/mil/ops/defs/tensor_operation.py
@@ -410,6 +410,9 @@ def type_inference(self):
         pad = self.pad
         if len(pad.shape) != 1:
             raise ValueError("Pad should be a 1D tensor!")
+        if self.mode and self.mode.val not in {'constant', 'reflect', 'replicate'}:
+            raise ValueError("Pad mode should be one of {'constant', 'reflect', 'replicate'}")
+
         if pad.val is None:
             for i in range(self.pad.shape[0]//2):
                 ret_shape[-self.pad.shape[0]//2+i] = get_new_symbol()
@@ -818,12 +821,12 @@ class concat(Operation):
     Attributes
     ----------
-    T: fp32
+    T: fp32, int32
     """

     input_spec = InputSpec(values=TupleInputType(),
                            axis=IntInputType(const=True),
-                           interleave=BoolInputType(const=True, default=False))
+                           interleave=BoolInputType(const=True, optional=True, default=False))

     def __init__(self, **kwargs):
         super(concat, self).__init__(**kwargs)
@@ -1112,22 +1115,4 @@ def type_inference(self):

     @precondition(allow=VALUE | SYMBOL)
     def value_inference(self):
-        return self.x.sym_val
-
-
-@register_op(doc_str="")
-class isfinite(Operation):
-    input_spec = InputSpec(x=ScalarOrTensorInputType(),)
-    """
-    Should deprecate this op.
-    """
-
-    def __init__(self, **kwargs):
-        super(isfinite, self).__init__(**kwargs)
-
-    def type_inference(self):
-        return types.tensor(types.bool, list(self.x.shape))
-
-    @precondition(allow=VALUE)
-    def value_inference(self):
-        return np.isfinite(self.x.val)
+        return self.x.sym_val
\ No newline at end of file
diff --git a/coremltools/converters/mil/mil/ops/tests/test_const.py b/coremltools/converters/mil/mil/ops/tests/test_const.py
new file mode 100644
index 000000000..fc8264418
--- /dev/null
+++ b/coremltools/converters/mil/mil/ops/tests/test_const.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2020, Apple Inc. All rights reserved.
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from scipy import special +import scipy +from coremltools.converters.mil import testing_reqs +from coremltools.converters.mil.testing_reqs import * + +from .testing_utils import run_compare_builder + +backends = testing_reqs.backends + + +class TestConst: + @pytest.mark.parametrize( + "use_cpu_only, backend, dtype", itertools.product( + [True], + backends, + [np.float32, np.int32] + ) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend, dtype): + t = np.random.randint(0, 100, (100, 2)).astype(np.float32) + constant = np.random.randint(0, 100, (100, 2)).astype(dtype) + input_placeholders = { + "x": mb.placeholder(shape=t.shape), + } + input_values = {"x": t} + + def build(x): + y = mb.const(val=constant, mode="file_value") + x = mb.cast(x=x, dtype='int32') + z = mb.add(x=x, y=y) + return mb.cast(x=z, dtype='fp32') + + expected_output_types = (100, 2, types.fp32) + expected_outputs = t + constant.astype(np.float32) + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) \ No newline at end of file diff --git a/coremltools/converters/mil/mil/ops/tests/test_control_flow.py b/coremltools/converters/mil/mil/ops/tests/test_control_flow.py index 922360b22..fc05f3405 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_control_flow.py +++ b/coremltools/converters/mil/mil/ops/tests/test_control_flow.py @@ -144,7 +144,7 @@ def false_fn(): class TestWhileLoop: @pytest.mark.parametrize( - "use_cpu_only, backend", itertools.product([True, False], backends,) + "use_cpu_only, backend", itertools.product([True,False], backends,) ) def test_builder_to_backend_smoke(self, use_cpu_only, backend): def body(a, b): @@ -186,6 +186,116 @@ def build(a, b): backend=backend, ) + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_power(self, use_cpu_only, backend): + + input_placeholders = { + "a": mb.placeholder(shape=(1,)), + "b": mb.placeholder(shape=(1,)), + } + + def build(a, b): + # Compute a^b + def body(res, bx): + return mb.mul(x=res, y=a), mb.add(x=bx, y=np.float32(1)) + + def cond(res, bx): + return mb.less(x=bx, y=b) + + res, ignored = mb.while_loop(_cond=cond, _body=body, + loop_vars=([1.], [0.])) + return res + + input_values = { + "a": np.array([2], dtype=np.float32), + "b": np.array([4], dtype=np.float32), + } + + expected_output_types = [ + (1, types.fp32), + ] + + expected_outputs = [ + np.array([16], dtype=np.float32), + ] + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) + + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_nested(self, use_cpu_only, backend): + if backend == 'nn_proto': + pytest.xfail("nn_proto backend add const has issue") + + input_placeholders = { + "x": mb.placeholder(shape=(1,)), + "y": mb.placeholder(shape=(1,)), + } + + def build(x, y): + # i, j = x, y + # while i < j: + # while 2*i < i+2: + # i += 1 + # i += 2 + # return i, j + + # Create const outside of while loop for testing purpose + two = mb.const(val=[2.], name='const_two') + one = mb.const(val=[1.], 
name='const_one') + + def cond2(i): + return mb.less(x=mb.mul(x=two, y=i), y=mb.add(x=i, y=two)) + + def body2(i): + return mb.add(x=i, y=one) + + def cond1(i, j): + return mb.less(x=i, y=j) + + def body1(i, j): + new_i = mb.while_loop(_cond=cond2, _body=body2, + loop_vars=(i,)) + return mb.add(x=new_i, y=two), j + + return mb.while_loop(_cond=cond1, _body=body1, + loop_vars=(x, y)) + + input_values = { + "x": np.array([0], dtype=np.float32), + "y": np.array([10], dtype=np.float32), + } + + expected_output_types = [ + (1, types.fp32), + (1, types.fp32), + ] + + expected_outputs = [ + np.array([10], dtype=np.float32), + np.array([10], dtype=np.float32), + ] + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) class TestList: @pytest.mark.parametrize( @@ -247,6 +357,7 @@ def build(a, b): "use_cpu_only, backend", itertools.product([True, False], backends,) ) def test_builder_to_backend_while(self, use_cpu_only, backend): + # The while_loop appends [1, 2]*i to `ls` for each iteration # i = 0, ... num_iters-1. def body(i, num_iters, ls, update): diff --git a/coremltools/converters/mil/mil/ops/tests/test_conv.py b/coremltools/converters/mil/mil/ops/tests/test_conv.py index 3bb71c42a..431980055 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_conv.py +++ b/coremltools/converters/mil/mil/ops/tests/test_conv.py @@ -11,7 +11,6 @@ backends = testing_reqs.backends - @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)") class TestConvTranspose: diff --git a/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py b/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py index d82dcb9ba..b80532181 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py +++ b/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py @@ -85,9 +85,6 @@ def test_builder_to_backend_smoke(self, use_cpu_only, backend, mode): ) build = lambda x: mb.atan(x=x) elif mode == "atanh": - if backend == "mil_proto": - # TODO - return val = np.array([[-0.8, -0.5, 0], [0.4, 0.5, 0.8]], dtype=np.float32) expected_outputs = np.array( [[-1.09861229, -0.54930614, 0.0], [0.42364893, 0.54930614, 1.09861229]], @@ -166,8 +163,6 @@ def test_builder_to_backend_smoke(self, use_cpu_only, backend, mode): build = lambda x: mb.floor(x=x) elif mode == "inverse": - if backend == "mil_proto": # TODO - return val = np.array([[-1, 2, -3], [4, -5, 6]], dtype=np.float32) expected_outputs = np.array( [[-1.0, 0.5, -0.33333334], [0.25, -0.2, 0.16666667]], dtype=np.float32 @@ -522,3 +517,99 @@ def test_builder_threshold_eval(self): expected_outputs = np.array([[1.0, 2, 1.0], [4.5, 1.0, 6.7]], dtype=np.float32) assert is_close(expected_outputs, v.val) + + @pytest.mark.parametrize( + "use_cpu_only, backend, epsilon", + itertools.product( + [True, False], + backends, + [1e-3, 1e-1, 1.0], + ), + ) + def test_builder_to_backend_stress_inverse( + self, use_cpu_only, backend, epsilon + ): + x = np.array([[1, -2, 3], [4, -5, 6]], dtype=np.float32) + numpy_pred = 1 / (x + epsilon) + + input_placeholder_dict = {"x": mb.placeholder(shape=x.shape)} + input_value_dict = {"x": x} + + def build(x): + return mb.inverse(x=x, epsilon=epsilon) + + expected_output_type = x.shape + (types.fp32,) + run_compare_builder( + build, + input_placeholder_dict, + input_value_dict, + expected_output_type, + numpy_pred, + use_cpu_only=use_cpu_only, + 
frontend_only=False, + backend=backend, + ) + + @pytest.mark.parametrize( + "use_cpu_only, backend, epsilon", + itertools.product( + [True, False], + backends, + [1e-3, 1e-1, 1.0], + ), + ) + def test_builder_to_backend_stress_rsqrt( + self, use_cpu_only, backend, epsilon + ): + x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + numpy_pred = 1.0 / np.sqrt(x + epsilon) + + input_placeholder_dict = {"x": mb.placeholder(shape=x.shape)} + input_value_dict = {"x": x} + + def build(x): + return mb.rsqrt(x=x, epsilon=epsilon) + + expected_output_type = x.shape + (types.fp32,) + run_compare_builder( + build, + input_placeholder_dict, + input_value_dict, + expected_output_type, + numpy_pred, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) + + @pytest.mark.parametrize( + "use_cpu_only, backend, epsilon", + itertools.product( + [True, False], + backends, + [1e-3, 1e-1, 1.0], + ), + ) + def test_builder_to_backend_stress_log( + self, use_cpu_only, backend, epsilon + ): + x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + numpy_pred = np.log(x + epsilon) + + input_placeholder_dict = {"x": mb.placeholder(shape=x.shape)} + input_value_dict = {"x": x} + + def build(x): + return mb.log(x=x, epsilon=epsilon) + + expected_output_type = x.shape + (types.fp32,) + run_compare_builder( + build, + input_placeholder_dict, + input_value_dict, + expected_output_type, + numpy_pred, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) diff --git a/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py b/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py index 0c6fff214..57c56e4f3 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py +++ b/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py @@ -13,7 +13,10 @@ class TestResizeBilinear: - def test_builder_to_backend_smoke(self, use_cpu_only=True, backend="nn_proto"): + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend): x = np.array([0, 1], dtype=np.float32).reshape(1, 1, 2) input_placeholder_dict = {"x": mb.placeholder(shape=x.shape)} input_value_dict = {"x": x} @@ -86,9 +89,11 @@ def build_mode_3(x): ) -@pytest.mark.skip("Broken for mil backend rdar://problem/66964398") class TestUpsampleBilinear: - def test_builder_to_backend_smoke(self, use_cpu_only=True, backend="nn_proto"): + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend): x = np.array([0, 1], dtype=np.float32).reshape(1, 1, 2) input_placeholder_dict = {"x": mb.placeholder(shape=x.shape)} input_value_dict = {"x": x} @@ -135,7 +140,9 @@ def build_upsample_fractional(x): backend=backend, ) + # TODO: enable GPU test: rdar://problem/60309338 + @pytest.mark.skip("Broken for mil backend rdar://problem/66964398") @pytest.mark.skipif(not testing_reqs._HAS_TORCH, reason="PyTorch not installed.") @pytest.mark.parametrize( "use_cpu_only, backend, input_shape, scale_factor, align_corners", @@ -188,7 +195,10 @@ def build_upsample(x): class TestUpsampleNearestNeighbor: - def test_builder_to_backend_smoke(self, use_cpu_only=True, backend="nn_proto"): + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend): x = np.array([1.5, 2.5, 3.5], dtype=np.float32).reshape(1, 1, 1, 3) 
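        # The values in this smoke test can be reasoned about directly:
        # nearest-neighbor upsampling simply repeats samples along each scaled
        # axis. Assuming, for illustration, a width upscale factor of 2:
        #     np.repeat(np.array([1.5, 2.5, 3.5], dtype=np.float32), 2)
        #     # -> [1.5, 1.5, 2.5, 2.5, 3.5, 3.5]
        # i.e. the expected output is the input with each element duplicated.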
input_placeholder_dict = {"x": mb.placeholder(shape=x.shape)} input_value_dict = {"x": x} diff --git a/coremltools/converters/mil/mil/ops/tests/test_linear.py b/coremltools/converters/mil/mil/ops/tests/test_linear.py index 762edbb1f..bd6d4b797 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_linear.py +++ b/coremltools/converters/mil/mil/ops/tests/test_linear.py @@ -55,10 +55,11 @@ def test_builder_eval(self): itertools.product([True, False], backends, [2, 4, 8]), ) def test_builder_to_backend_stress(self, use_cpu_only, backend, dim): - shape = np.array([dim, dim]) + out_channels, in_channels = dim, dim + 4 + shape = np.array([out_channels, in_channels]) x_val = np.random.rand(*shape) weight_val = np.random.rand(*shape).astype(np.float32) - bias_val = np.random.rand(dim).astype(np.float32) + bias_val = np.random.rand(out_channels).astype(np.float32) input_placeholders = { "x": mb.placeholder(shape=x_val.shape), } diff --git a/coremltools/converters/mil/mil/ops/tests/test_normalization.py b/coremltools/converters/mil/mil/ops/tests/test_normalization.py index 4013fe77d..733dd8f39 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_normalization.py +++ b/coremltools/converters/mil/mil/ops/tests/test_normalization.py @@ -5,8 +5,9 @@ from coremltools.converters.mil import testing_reqs from coremltools.converters.mil.testing_reqs import * +from coremltools.converters.mil.mil import Program, Function, get_new_symbol -from .testing_utils import run_compare_builder +from .testing_utils import UNK_SYM, run_compare_builder backends = testing_reqs.backends @@ -243,7 +244,7 @@ def test_builder_to_backend_smoke(self, use_cpu_only, backend): input_values = {"x": x_val} def build(x): - return [mb.l2_norm(x=x, axes=[-1], epsilon=1e-10)] + return [mb.l2_norm(x=x, epsilon=1e-10)] expected_output_types = [(1, 3, 2, types.fp32)] expected_outputs = [ @@ -269,8 +270,70 @@ def build(x): backend=backend, ) + @pytest.mark.parametrize( + "use_cpu_only, backend, rank", itertools.product([True, False], backends, [3, 4, 5]) + ) + def test_builder_to_backend_stress(self, use_cpu_only, backend, rank): + shape = np.random.randint(low=2, high=6, size=rank) + x_val = random_gen(shape=shape, rand_min=-10.0, rand_max=10.0) + input_placeholders = {"x": mb.placeholder(shape=shape)} + input_values = {"x": x_val} + + def build(x): + return [mb.l2_norm(x=x, epsilon=1e-12)] + + # compute for the answer + batch_dims = rank - 3 + if batch_dims == 0: + norm = la.norm(x_val) + output = x_val/norm + else: + batch_dim_prod = np.prod(shape[:batch_dims]) + reshape_x_val = np.reshape(x_val,(batch_dim_prod,-1)) + norm = la.norm(reshape_x_val, axis=1, keepdims=True) + output = reshape_x_val/norm + output = np.reshape(output, shape) + + expected_output_types = [list(output.shape) + [types.fp32]] + expected_outputs = [ + output + ] + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + class TestNormalizationLayerNorm: + + @staticmethod + def _keras_layer_norm( x, axes, epsilon): + layer = tf.keras.layers.LayerNormalization(axis=axes, epsilon=epsilon) + data = tf.constant(x, dtype=tf.float32) + output = layer(data) + return output.numpy() + + @staticmethod + def _np_layer_norm(x, axes, gamma=None, beta=None, epsilon=1e-5): + rank = len(x.shape) + axes = [axis + rank if axis < 0 else axis for axis in axes] + normalized_shape = [x.shape[i] if i in axes else 1 for i in range(rank)] + gamma = 
np.ones(shape=normalized_shape) if gamma is None else np.reshape(gamma, normalized_shape) + beta = np.zeros(shape=normalized_shape) if beta is None else np.reshape(beta, normalized_shape) + num = x - np.mean(x, axis=tuple(axes), keepdims=True) + dem = np.sqrt( + np.sum(np.square(num), axis=tuple(axes), keepdims=True) + / np.prod(normalized_shape) + + epsilon + ) + return num / dem * gamma + beta + @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends,) ) @@ -292,9 +355,9 @@ def build(x): np.array( [ [ - [0.9999969, -0.9999969], - [0.99999839, -0.99999839], - [0.99995005, -0.99995005], + [ 0.9999969, -0.9999969 ], + [ 0.99999833, -0.99999833], + [ 0.99995005, -0.99995005], ] ], dtype=np.float32, @@ -302,9 +365,9 @@ def build(x): np.array( [ [ - [0.8268512, -1.0630943], - [1.771824, -0.8268511], - [-0.11812156, -0.590608], + [ 0.82687193, -1.06312108], + [ 1.77186835, -0.82687193], + [-0.11812456, -0.59062278], ] ], dtype=np.float32, @@ -321,26 +384,142 @@ def build(x): backend=backend, ) - @ssa_fn - def test_builder_eval(self): - def np_layer_norm(x, axes, gamma, beta, epsilon=1e-5): - normalized_shape = x.shape[-len(axes) :] - gamma = np.ones(shape=normalized_shape) if gamma is None else gamma - beta = np.zeros(shape=normalized_shape) if beta is None else beta - num = x - np.mean(x, axis=tuple(axes), keepdims=True) - dem = np.sqrt( - np.sum(np.square(num), axis=tuple(axes), keepdims=True) - / np.prod(normalized_shape) - + epsilon - ) - return num / dem * gamma + beta + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_smoke_with_dynamic_shape(self, use_cpu_only, backend): + x_val = np.array([[[1.0, -7.0], [5.0, -6.0], [-3.0, -5.0]]], dtype=np.float32) + shape = (get_new_symbol(), get_new_symbol(), 2) + input_placeholders = {"x": mb.placeholder(shape=shape)} + input_values = {"x": x_val} + + def build(x): + return [ + mb.layer_norm(x=x, axes=[2], epsilon=1e-4), + ] + + expected_output_types = [(UNK_SYM, UNK_SYM, 2, types.fp32)] + expected_outputs = [ + np.array( + [ + [ + [ 0.9999969, -0.9999969 ], + [ 0.99999833, -0.99999833], + [ 0.99995005, -0.99995005], + ] + ], + dtype=np.float32, + ), + ] + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + @pytest.mark.parametrize( + "use_cpu_only, backend, rank_and_axes, epsilon, provides_gamma_beta", + itertools.product([True, False], backends, + [[3,[0,2]], [3,[-2]], [4,[0,1,3]], [5,[0,4]], [5,[-5,-4,-3,-2,-1]] + ], + [0.0001, 0.01], + [True, False]), + ) + def test_builder_to_backend_stress_numpy(self, use_cpu_only, backend, rank_and_axes, epsilon, provides_gamma_beta): + rank, axes = rank_and_axes + shape = np.random.randint(low=2, high=6, size=rank) + x_val = random_gen(shape=shape, rand_min=-100.0, rand_max=100.0) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + input_values = {"x": x_val} + + gamma, beta = None, None + + if provides_gamma_beta: + positive_axes = [axis+rank if axis <0 else axis for axis in axes] + normalized_shape = [shape[i] for i in range(rank) if i in positive_axes] + gamma = random_gen(shape=normalized_shape, rand_min=-100, rand_max=100) + beta = random_gen(shape=normalized_shape, rand_min=-100, rand_max=100) + + def build(x): + return [ + mb.layer_norm(x=x, axes=axes, epsilon=epsilon, gamma=gamma, beta=beta) + ] + + output = 
TestNormalizationLayerNorm._np_layer_norm(x=x_val, axes=axes, epsilon=epsilon, gamma=gamma, beta=beta) + expected_output_types = [tuple(output.shape) + (types.fp32,)] + expected_outputs = [ + output + ] + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + @pytest.mark.skipif(not testing_reqs._HAS_TF_2, reason="Tensorflow not found.") + @pytest.mark.parametrize( + "use_cpu_only, backend, rank_and_axes, epsilon", + itertools.product([True, False], backends, + [[3,[0,2]], [3,[-2]], [4,[0,1,3]], [5,[0,4]], [5,[-5,-4,-3,-2,-1]] + ], + [0.0001, 0.01]), + ) + def test_builder_to_backend_stress_keras(self, use_cpu_only, backend, rank_and_axes, epsilon): + rank, axes = rank_and_axes + shape = np.random.randint(low=2, high=6, size=rank) + x_val = random_gen(shape=shape, rand_min=-100.0, rand_max=100.0) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + input_values = {"x": x_val} + + def build(x): + return [ + mb.layer_norm(x=x, axes=axes, epsilon=epsilon) + ] + + output = TestNormalizationLayerNorm._keras_layer_norm(x=x_val, axes=axes, epsilon=epsilon) + expected_output_types = [tuple(output.shape) + (types.fp32,)] + expected_outputs = [ + output + ] - x_val = random_gen(shape=(1, 3, 4, 4), rand_min=-100.0, rand_max=100.0) - g = random_gen(shape=(4, 4), rand_min=1.0, rand_max=2.0) - b = random_gen(shape=(4, 4), rand_min=0.0, rand_max=1.0) - res = mb.layer_norm(x=x_val, axes=[-2, -1], gamma=g, beta=b) - ref = np_layer_norm(x=x_val, axes=[-2, -1], gamma=g, beta=b) - assert is_close(ref, res.val) + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + @pytest.mark.parametrize("rank_and_axes, epsilon", + itertools.product( + [[3,[0,2]], [3,[-2,-1]], [4,[0,1,2,3]], [5,[0,2,-1]], [5,[-5,-4,-3,-2,-1]]], + [0.0001, 0.01], + ), + ) + def test_builder_eval_stress(self, rank_and_axes, epsilon): + rank, axes = rank_and_axes + shape = np.random.randint(low=2, high=6, size=rank) + x_val = random_gen(shape=shape, rand_min=-100.0, rand_max=100.0) + positive_axes = [axis+rank if axis <0 else axis for axis in axes] + normalized_shape = [shape[i] for i in range(rank) if i in positive_axes] + gamma_val = random_gen(shape=normalized_shape, rand_min=-100, rand_max=100) + beta_val = random_gen(shape=normalized_shape, rand_min=-100, rand_max=100) + with Function({}) as ssa_func: + res = mb.layer_norm(x=x_val, axes=axes, epsilon=epsilon, gamma=gamma_val, beta=beta_val) + ref = TestNormalizationLayerNorm._np_layer_norm(x=x_val, axes=axes, epsilon=epsilon, gamma=gamma_val, beta=beta_val) + assert is_close(ref, res.val) class TestNormalizationLocalResponseNorm: diff --git a/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py b/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py index 7039a47bc..9d549a275 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py +++ b/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py @@ -388,6 +388,45 @@ def build(x, indices): backend=backend, ) + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_embedding_builder_to_backend_smoke(self, use_cpu_only, backend): + x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + indices = np.array([1, 0], dtype=np.int32) + input_placeholders = { + "indices": 
mb.placeholder(shape=indices.shape, dtype=types.int32), + } + + input_values = {"indices": indices} + + def build(indices): + return [ + mb.gather(x=x, indices=indices, axis=0), + mb.gather(x=x, indices=indices, axis=-2), + ] + + expected_output_types = [ + (2, 3, types.fp32), + (2, 3, types.fp32), + ] + + expected_outputs = [ + np.array([[4, 5, 6], [1, 2, 3]], dtype=np.float32), + np.array([[4, 5, 6], [1, 2, 3]], dtype=np.float32), + ] + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) + @ssa_fn def test_builder_eval(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) diff --git a/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py b/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py index 2a7295334..d2f86b3ac 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py +++ b/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py @@ -1217,7 +1217,7 @@ def build(x): expected_outputs = [ref] expected_output_types = [ - tuple(ref.shape) + (types.fp32,), + tuple(list(ref.shape) + [types.fp32]), ] run_compare_builder( @@ -1449,41 +1449,3 @@ def test_builder_eval(self): x_val = random_gen(shape=(1, 3, 2, 2), rand_min=-100, rand_max=100) res = mb.argsort(x=x_val, axis=-3) assert is_close(np.argsort(x_val, axis=-3), res.val) - - -class TestIsFinite: - @pytest.mark.parametrize( - "use_cpu_only, backend", itertools.product([True, False], backends,) - ) - def test_builder_to_backend_smoke(self, use_cpu_only, backend): - val = np.array([[np.inf, -np.inf, 0], [-np.inf, 5, 6]], dtype=np.float32) - input_placeholders = {"x": mb.placeholder(shape=val.shape)} - input_values = {"x": val} - - def build(x): - return [mb.isfinite(x=x)] - - expected_output_types = [(2, 3, types.bool)] - expected_outputs = [ - np.array([[False, False, True], [False, True, True]], dtype=np.bool) - ] - - run_compare_builder( - build, - input_placeholders, - input_values, - expected_output_types, - expected_outputs, - use_cpu_only=use_cpu_only, - backend=backend, - ) - - @ssa_fn - def test_builder_eval(self): - shape = (3, 3, 3, 3) - x_val = random_gen(shape=shape, rand_min=-1, rand_max=1) - random_map = np.random.choice([np.inf, -np.inf, 0], size=shape) - x_val[np.where(random_map == np.inf)] = np.inf - x_val[np.where(random_map == -np.inf)] = -np.inf - res = mb.isfinite(x=x_val) - assert is_close(np.isfinite(x_val), res.val) diff --git a/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py b/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py index b2d222e8b..113f89a91 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py +++ b/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py @@ -918,7 +918,6 @@ def build(x, y): backend=backend, ) - @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends, ) ) @@ -954,6 +953,105 @@ def build(x): backend=backend, ) + @pytest.mark.parametrize( + "use_cpu_only, backend, rank, n_inputs, negative_index", + itertools.product( + [True, False], + backends, + [1, 2, 3, 4, 5], + [2, 3], + [False, True], + ) + ) + @pytest.mark.skip(reason="rdar://65198011 (Re-enable Conv3dTranspose, concat interleave and DynamicTile unit tests)") + def test_builder_to_backend_stress_interleave(self, use_cpu_only, backend, + rank, n_inputs, negative_index): + + def np_concat_interleave(arrays, 
axis): + step = len(arrays) + in_shape = arrays[0].shape + out_shape = list(in_shape) + if axis < 0: + axis += len(in_shape) + out_shape[axis] = step * in_shape[axis] + concat_tensor = np.empty(tuple(out_shape), dtype=np.float32) + for i in range(step): + if rank == 5: + if axis == 4: + concat_tensor[:, :, :, :, i::step] = arrays[i] + if axis == 3: + concat_tensor[:, :, :, i::step, :] = arrays[i] + if axis == 2: + concat_tensor[:, :, i::step, :, :] = arrays[i] + if axis == 1: + concat_tensor[:, i::step, :, :, :] = arrays[i] + if axis == 0: + concat_tensor[i::step, :, :, :, :] = arrays[i] + if rank == 4: + if axis == 3: + concat_tensor[:, :, :, i::step] = arrays[i] + if axis == 2: + concat_tensor[:, :, i::step, :] = arrays[i] + if axis == 1: + concat_tensor[:, i::step, :, :] = arrays[i] + if axis == 0: + concat_tensor[i::step, :, :, :] = arrays[i] + if rank == 3: + if axis == 2: + concat_tensor[:, :, i::step] = arrays[i] + if axis == 1: + concat_tensor[:, i::step, :] = arrays[i] + if axis == 0: + concat_tensor[i::step, :, :] = arrays[i] + if rank == 2: + if axis == 1: + concat_tensor[:, i::step] = arrays[i] + if axis == 0: + concat_tensor[i::step, :] = arrays[i] + if rank == 1: + concat_tensor[i::step] = arrays[i] + return concat_tensor + + input_shape = [4, 2, 3, 6, 5] + for axis in range(rank): + if negative_index: + axis = axis - rank + shape = tuple(input_shape[:rank]) + t1 = np.random.normal(size=shape).astype(np.float32) + t2 = np.random.normal(size=shape).astype(np.float32) + all_input_arrs = [t1, t2] + input_placeholders = { + "x": mb.placeholder(shape=t1.shape), + "y": mb.placeholder(shape=t2.shape), + } + input_values = {"x": t1, "y": t2} + if n_inputs == 3: + t3 = np.random.normal(size=shape).astype(np.float32) + input_placeholders["z"] = mb.placeholder(shape=t3.shape) + input_values["z"] = t3 + all_input_arrs.append(t3) + + def build_2_inputs(x, y): + return (mb.concat(values=(x, y), axis=axis, interleave=True),) + + def build_3_inputs(x, y, z): + return (mb.concat(values=(x, y, z), axis=axis, interleave=True),) + + np_out = np_concat_interleave(all_input_arrs, axis) + expected_output_types = [np_out.shape + (types.fp32,)] + expected_outputs = [np_out] + + run_compare_builder( + build_3_inputs if n_inputs == 3 else build_2_inputs, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + frontend_only=False, + backend=backend, + ) + @ssa_fn def test_builder_eval(self): values = [ diff --git a/coremltools/converters/mil/mil/ops/tests/testing_utils.py b/coremltools/converters/mil/mil/ops/tests/testing_utils.py index c08646e61..e9502b6b5 100644 --- a/coremltools/converters/mil/mil/ops/tests/testing_utils.py +++ b/coremltools/converters/mil/mil/ops/tests/testing_utils.py @@ -25,6 +25,7 @@ def run_compare_builder( backend="nn_proto", atol=1e-04, rtol=1e-05, + inputs=None, ): """ Inputs: @@ -36,6 +37,9 @@ def run_compare_builder( dict as MLModel doesn't support function with no input. + - input_values: str -> np.array or PIL.Image. Keys must match those in + input_placeholders. + - expected_output_types: list[(shape, builtin_type)] or (shape, builtin_type). None skips type inference validation. @@ -43,6 +47,8 @@ def run_compare_builder( frontend_only == False - frontend_only: True to test up to proto generation. 
+ + - inputs: type of inputs (either None (defaults to tensor) or [ct.ImageType]) """ if not isinstance(expected_output_types, list): expected_output_types = [expected_output_types] @@ -100,7 +106,7 @@ def run_compare_builder( if output_shape != expected_shape: raise ValueError(msg) - proto = converter._convert(prog, convert_from="mil", convert_to=backend) + proto = converter._convert(prog, convert_from="mil", convert_to=backend, inputs=inputs) if frontend_only: return diff --git a/coremltools/converters/mil/mil/passes/common_pass.py b/coremltools/converters/mil/mil/passes/common_pass.py index a5bcdb98b..9259f2072 100644 --- a/coremltools/converters/mil/mil/passes/common_pass.py +++ b/coremltools/converters/mil/mil/passes/common_pass.py @@ -20,6 +20,7 @@ def common_pass(prog): 'common::noop_elimination', "common::fuse_matmul_weight_bias", "common::fuse_gelu_tanh_approximation", + "common::pad_conv_connect", 'common::image_input_preprocess', "common::reduce_transposes", "common::fuse_bias_conv", diff --git a/coremltools/converters/mil/mil/passes/loop_invariant_elimination.py b/coremltools/converters/mil/mil/passes/loop_invariant_elimination.py index ed4a3120a..e74c05929 100644 --- a/coremltools/converters/mil/mil/passes/loop_invariant_elimination.py +++ b/coremltools/converters/mil/mil/passes/loop_invariant_elimination.py @@ -17,10 +17,10 @@ def detect_loop_invariants(while_op): - block = while_op.blocks[0] + block = while_op.blocks[1] # body block loop_invariant_ids = [] # list of index in op.loop_vars, block.inputs for i, vx_in in enumerate(block.inputs): - vx_out = block.outputs[i + 1] # first output is cond var. + vx_out = block.outputs[i] # first output is cond var. return_input_as_output = vx_in == vx_out # this block output is a var from outside of the block output_from_outside_of_block = ( @@ -77,7 +77,6 @@ def loop_invariant_elimination_block(block): for op in list(block.operations): if op.op_type != "while_loop": continue - block = op.blocks[0] loop_invariant_ids = detect_loop_invariants(op) loop_variant_vars = [] @@ -85,12 +84,15 @@ def loop_invariant_elimination_block(block): # replace uses of loop_invariants with its source from outside of the # while_loop op. 
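        # For illustration (the names here are ours, not from the pass): given
        #     while_loop(_cond=cond_fn, _body=lambda a, b: (mb.add(x=a, y=b), b),
        #                loop_vars=(a0, b0))
        # the second loop var is returned unchanged (vx_in == vx_out above), so
        # it is loop-invariant: its uses inside both blocks can be rewired to
        # the outer var b0, and it can be dropped from loop_vars entirely.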
for i in loop_invariant_ids: - block.replace_uses_of_var_after_op( - anchor_op=None, old_var=block.inputs[i], new_var=op.loop_vars[i] - ) + for block in op.blocks: + block.replace_uses_of_var_after_op( + anchor_op=None, old_var=block.inputs[i], + new_var=op.loop_vars[i] + ) # replace block inputs - block.remove_inputs([block.inputs[i] for i in loop_invariant_ids]) + for block in op.blocks: + block.remove_inputs([block.inputs[i] for i in loop_invariant_ids]) # remove invariants from while_loop loop_vars for i in loop_invariant_ids: @@ -108,13 +110,12 @@ def loop_invariant_elimination_block(block): ) op._input_vars["loop_vars"] = op.loop_vars - # remove invariants from while_loop outputs - # block.outputs[0] is cond var - block.set_outputs( - [block.outputs[0]] - + [ + # remove invariants from while_loop body_block outputs + body_block = op.blocks[1] + body_block.set_outputs( + [ v - for i, v in enumerate(block.outputs[1:]) + for i, v in enumerate(body_block.outputs) if i not in loop_invariant_ids ] ) diff --git a/coremltools/converters/mil/mil/passes/noop_elimination.py b/coremltools/converters/mil/mil/passes/noop_elimination.py index 8ac55c117..c7382d3fc 100644 --- a/coremltools/converters/mil/mil/passes/noop_elimination.py +++ b/coremltools/converters/mil/mil/passes/noop_elimination.py @@ -7,61 +7,130 @@ from coremltools.converters.mil.mil import Builder as mb import numpy as np -def remove_reshape(reshape_op, block): - input_var = reshape_op.x - input_op = input_var.op - - reshape_op.enclosing_block.replace_uses_of_var_after_op(anchor_op=input_op, - old_var=reshape_op.outputs[0], new_var=input_var) +def _remove_elementwise_binary(op, block, x, y): + # We remove the ops that has op.x == x or op.y == y + if x is not None and op.x.val is not None and np.all(op.x.val == x): + input_var = op.y + input_op = input_var.op + elif y is not None and op.y.val is not None and np.all(op.y.val == y): + input_var = op.x + input_op = input_var.op + else: + return False + + input_shape = input_var.sym_type + output_shape = op.outputs[0].sym_type + + # We might be using elementwise as broadcasting + if input_shape != output_shape: + return False + + op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=input_op, old_var=op.outputs[0], new_var=input_var + ) + block.remove_ops([op]) - # Remove all the ops at once - block.remove_ops([reshape_op]) return True -def remove_split(split_op, block): - input_var = split_op.x +def remove_elementwise(op, block): + + if op.op_type in {"add"}: + return _remove_elementwise_binary(op, block, 0, 0) + elif op.op_type in {"mul"}: + return _remove_elementwise_binary(op, block, 1, 1) + elif op.op_type in {"floor_div", "pow", "real_div"}: + return _remove_elementwise_binary(op, block, None, 1) + elif op.op_type in {"sub"}: + return _remove_elementwise_binary(op, block, None, 0) + else: + return False + + +def remove_same_shape(op, block): + input_shape = op.x.sym_type + output_shape = op.outputs[0].sym_type + + if input_shape != output_shape: + return False + + input_var = op.x input_op = input_var.op - split_op.enclosing_block.replace_uses_of_var_after_op(anchor_op=input_op, - old_var=split_op.outputs[0], new_var=input_var) + op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=input_op, old_var=op.outputs[0], new_var=input_var + ) # Remove all the ops at once - block.remove_ops([split_op]) + block.remove_ops([op]) return True -def remove_slice(slice_op, block): - input_var = slice_op.x + +def remove_linear(op, block): + if op.alpha.val != 1 or op.beta.val != 0: 
+ return False + + input_var = op.x input_op = input_var.op - slice_op.enclosing_block.replace_uses_of_var_after_op(anchor_op=input_op, - old_var=slice_op.outputs[0], new_var=input_var) + op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=input_op, old_var=op.outputs[0], new_var=input_var + ) # Remove all the ops at once - block.remove_ops([slice_op]) + block.remove_ops([op]) return True -op_to_removal_fn = {'reshape': remove_reshape, - 'split': remove_split, - 'slice_by_index': remove_slice, - 'slice_by_size': remove_slice, +_SUPPORTED_OPS = { + "add", + "mul", + "floor_div", + "pow", + "real_div", + "sub", + "reshape", + "split", + "slice_by_index", + "slice_by_size", + "pad", + "tile", + "upsample_nearest_neighbor", + "upsample_bilinear", + "resize_bilinear", + "crop", + "linear_activation" +} +op_to_removal_fn = { + "add": remove_elementwise, + "mul": remove_elementwise, + "floor_div": remove_elementwise, + "pow": remove_elementwise, + "real_div": remove_elementwise, + "sub": remove_elementwise, + "reshape": remove_same_shape, + "split": remove_same_shape, + "slice_by_index": remove_same_shape, + "slice_by_size": remove_same_shape, + "pad": remove_same_shape, + "tile": remove_same_shape, + "upsample_nearest_neighbor": remove_same_shape, + "upsample_bilinear": remove_same_shape, + "resize_bilinear": remove_same_shape, + "crop": remove_same_shape, + "linear_activation": remove_linear, } + def match_pattern(op): # abort if op output is a block output if op.outputs[0] in op.enclosing_block.outputs: return None - if op.op_type in {'reshape', 'split', 'slice_by_index', 'slice_by_size'}: - - input_shape = op.x.sym_type - if len(op.outputs) != 1: - return None - output_shape = op.outputs[0].sym_type + if op.op_type in _SUPPORTED_OPS: - if input_shape != output_shape: + if len(op.outputs) != 1: return None return op_to_removal_fn[op.op_type] @@ -86,10 +155,11 @@ def noop_elimination_block(block): return status return False + @register_pass(namespace="common") def noop_elimination(prog): """ - We remove reshape/slice/split if it's a no-op + We remove ops that has no effect. Given: %1 (1, 96, 128, 64, fp32) = ... @@ -108,4 +178,3 @@ def noop_elimination(prog): block_changed = True while block_changed: block_changed = noop_elimination_block(f) - diff --git a/coremltools/converters/mil/mil/passes/pad_conv_connect.py b/coremltools/converters/mil/mil/passes/pad_conv_connect.py new file mode 100644 index 000000000..93177ed69 --- /dev/null +++ b/coremltools/converters/mil/mil/passes/pad_conv_connect.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2020, Apple Inc. All rights reserved. 
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from __future__ import print_function as _ +from __future__ import division as _ +from __future__ import absolute_import as _ + +from coremltools.converters.mil.mil.passes.pass_registry import register_pass +from coremltools.converters.mil.mil import Builder as mb +import numpy as np +import copy + +def match_pattern(op): + ret = set([]) + child_ops = op.outputs[0].child_ops + + for child_op in child_ops: + if child_op.op_type != "transpose": + continue + skip_ops = child_op.outputs[0].child_ops + for skip_op in skip_ops: + if "conv" not in skip_op.op_type: + continue + ret.update([child_op]) + + return ret if len(ret) != 0 else None + + +def try_to_transform(pad_op, transpose_ops, block): + + def _compute_new_pad_values(transpose_op): + if pad_op.inputs["pad"].val is None: + return None + pad_amounts = np.reshape(pad_op.inputs["pad"].val, [-1, 2]) + transpose_axes = transpose_op.inputs["perm"].val + rank_diff = len(transpose_axes) - pad_amounts.shape[0] + pad_amounts_new = copy.deepcopy(pad_amounts) + # append "rank_diff" rows of zeros to the top + pad_amounts_new = np.concatenate( + (np.zeros((2 * rank_diff)).reshape(-1, 2), pad_amounts_new) + ) + pad_amounts_new = pad_amounts_new.astype(pad_amounts.dtype) + pad_amounts = np.concatenate( + (np.zeros((2 * rank_diff)).reshape(-1, 2), pad_amounts) + ) + for i, axis in enumerate(transpose_axes): + pad_amounts_new[i][0] = pad_amounts[axis][0] + pad_amounts_new[i][1] = pad_amounts[axis][1] + + # get the top "rank_diff" rows + top_rows = pad_amounts_new[:rank_diff, :] + if not np.all(top_rows == 0): + return False + # cut "rank_diff" from the top + pad_amounts_new = pad_amounts_new[rank_diff:, :] + pad_amounts_new = pad_amounts_new.flatten() + return pad_amounts_new + + if pad_op.outputs[0] in pad_op.enclosing_block.outputs: + return False + if len(set(pad_op.outputs[0].child_ops)) != len(transpose_ops): + return False + + for transpose_op in transpose_ops: + pad_amounts_new = _compute_new_pad_values(transpose_op) + if pad_amounts_new is None: + continue + + with pad_op.enclosing_block: + new_transpose_var = mb.transpose(x=pad_op.inputs["x"], perm=transpose_op.inputs["perm"].val, before_op=transpose_op) + new_pad_inputs = {"x": new_transpose_var, "pad": pad_amounts_new} + for k, v in pad_op.inputs.items(): + if k not in new_pad_inputs: + new_pad_inputs[k] = v + new_pad_var = mb.pad(before_op=transpose_op, **new_pad_inputs) + pad_op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=transpose_op, old_var=transpose_op.outputs[0], new_var=new_pad_var + ) + + pad_op.enclosing_block.remove_ops(list(transpose_ops) + [pad_op]) + + return True + +def pad_conv_connect_block(block): + fusion_status = False + for op in list(block.operations): + for b in op.blocks: + block_changed = True + while block_changed: + block_changed = pad_conv_connect_block(b) + + if op.op_type != "pad": + continue + + transpose_ops = match_pattern(op) + if transpose_ops is not None: + with block: + fusion_status = try_to_transform(op, transpose_ops, block) + # has to break as the downstream iterator is affected. + if fusion_status: + return fusion_status + return fusion_status + + +@register_pass(namespace="common") +def pad_conv_connect(prog): + """ + When we observe pad -> transpose -> conv, we move the pad to be next to conv. + This allows us to meld pad + conv if possible. + + Given: + %1 = pad(%0, ...) 
+        %2 = transpose(%1, ...)
+        %3 = conv(%2, ...)
+        ...
+
+    Result:
+        %1.a = transpose(%0, ...)
+        %2.a = pad(%1.a, ...)
+        %3 = conv(%2.a, ...)
+        ...
+
+    """
+    for f_name, f in prog.functions.items():
+        block_changed = True
+        while block_changed:
+            block_changed = pad_conv_connect_block(f)
diff --git a/coremltools/converters/mil/mil/passes/reduce_transposes.py b/coremltools/converters/mil/mil/passes/reduce_transposes.py
index 79c414c8c..c2b88edeb 100644
--- a/coremltools/converters/mil/mil/passes/reduce_transposes.py
+++ b/coremltools/converters/mil/mil/passes/reduce_transposes.py
@@ -927,7 +927,11 @@ def _remove_transpose_ops(self, starting_transpose_op):
             # Change the name of the input_var to match the block output if input_var is not changed.
             # If the same input_var is in output twice, we can't rename it twice, therefore we initiate an
             # Identity op to match the name
-            if input_var not in name_changed_vars:
+            if input_var in self.block.inputs.values():
+                with self.block:
+                    input_var = mb.identity(x=input_var, before_op=op, name=output_var.name)
+                    parent_op = None  # set anchor op as None.
+            elif input_var not in name_changed_vars:
                 input_var.name = output_var.name
                 input_var.op.name = output_var.op.name
                 name_changed_vars.update([input_var])
diff --git a/coremltools/converters/mil/mil/passes/test_noop_elimination.py b/coremltools/converters/mil/mil/passes/test_noop_elimination.py
index 2e4c5e130..4c2110254 100644
--- a/coremltools/converters/mil/mil/passes/test_noop_elimination.py
+++ b/coremltools/converters/mil/mil/passes/test_noop_elimination.py
@@ -17,6 +17,69 @@
 import numpy as np
 
 
+@pytest.mark.parametrize("op_type, pos, val", itertools.product(['add', 'mul', 'floor_div', 'pow', 'real_div', 'sub'], ['x', 'y'], [0, 1, [0, 0, 0, 0], [1, 1, 1, 1]]))
+def test_elementwise_elimination(op_type, pos, val):
+    if 'div' in op_type and np.prod(val) == 0:
+        return
+    if 'pow' in op_type and val != 0 and val != 1:
+        return
+
+    test_op = getattr(mb, op_type)
+
+    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
+    def prog(x):
+        if pos == "x":
+            r1 = test_op(x=val, y=x)
+        else:
+            r1 = test_op(x=x, y=val)
+        return mb.relu(x=r1)
+
+    prev_prog, prev_block, block = apply_pass_and_basic_check(
+        prog, "common::noop_elimination"
+    )
+    original_program = [op_type, "relu"]
+    new_program = original_program
+    if op_type in {'add'}:
+        if val == 0 or val == [0, 0, 0, 0]:
+            new_program = ["relu"]
+    elif op_type in {'mul'}:
+        if val == 1 or val == [1, 1, 1, 1]:
+            new_program = ["relu"]
+    elif op_type in {'pow', 'real_div', 'floor_div'}:
+        if pos == 'y' and (val == 1 or val == [1, 1, 1, 1]):
+            new_program = ["relu"]
+    elif op_type in {'sub'}:
+        if pos == 'y' and (val == 0 or val == [0, 0, 0, 0]):
+            new_program = ["relu"]
+
+    assert get_op_types_in_program(prev_prog) == original_program
+    assert get_op_types_in_program(prog) == new_program
+    assert_model_is_valid(
+        prog,
+        {"x": (2, 4)},
+        expected_output_shapes={block.outputs[0].name: (2, 4)},
+    )
+
+
+def test_elementwise_broadcast():
+
+    @mb.program(input_specs=[mb.TensorSpec(shape=[4])])
+    def prog(x):
+        r1 = mb.add(x=x, y=[[0, 0, 0, 0], [0, 0, 0, 0]])
+        return mb.relu(x=r1)
+
+    prev_prog, prev_block, block = apply_pass_and_basic_check(
+        prog, "common::noop_elimination"
+    )
+    original_program = ["add", "relu"]
+
+    assert get_op_types_in_program(prev_prog) == original_program
+    assert get_op_types_in_program(prog) == original_program
+    assert_model_is_valid(
+        prog,
+        {"x": [4]},
+        expected_output_shapes={block.outputs[0].name: (2, 4)},
+    )
+
+
 def
test_reshape_elimination(): @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) def prog(x): @@ -154,3 +217,165 @@ def prog(x): ) +def test_pad_elimination(): + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) + def prog(x): + r1 = mb.pad(x=x, pad=[0, 0, 0, 0]) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["pad", "relu"] + assert get_op_types_in_program(prog) == ["relu"] + assert_model_is_valid( + prog, + {"x": (2, 4)}, + expected_output_shapes={block.outputs[0].name: (2, 4)}, + ) + + +def test_keep_pad(): + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) + def prog(x): + r1 = mb.pad(x=x, pad=[4, 4, 2, 2]) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["pad", "relu"] + assert get_op_types_in_program(prog) == ["pad", "relu"] + assert_model_is_valid( + prog, + {"x": (2, 4)}, + expected_output_shapes={block.outputs[0].name: (10, 8)}, + ) + + +def test_tile_elimination(): + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) + def prog(x): + r1 = mb.tile(x=x, reps=[1, 1]) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["tile", "relu"] + assert get_op_types_in_program(prog) == ["relu"] + assert_model_is_valid( + prog, + {"x": (2, 4)}, + expected_output_shapes={block.outputs[0].name: (2, 4)}, + ) + + +def test_keep_tile(): + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) + def prog(x): + r1 = mb.tile(x=x, reps=[2, 2]) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["tile", "relu"] + assert get_op_types_in_program(prog) == ["tile", "relu"] + assert_model_is_valid( + prog, + {"x": (2, 4)}, + expected_output_shapes={block.outputs[0].name: (4, 8)}, + ) + + +def test_upsample_nearest_neighbor_elimination(): + @mb.program(input_specs=[mb.TensorSpec(shape=(3, 2, 4))]) + def prog(x): + r1 = mb.upsample_nearest_neighbor(x=x) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["upsample_nearest_neighbor", "relu"] + assert get_op_types_in_program(prog) == ["relu"] + assert_model_is_valid( + prog, + {"x": (3, 2, 4)}, + expected_output_shapes={block.outputs[0].name: (3, 2, 4)}, + ) + + +def test_upsample_bilinear_elimination(): + @mb.program(input_specs=[mb.TensorSpec(shape=(3, 2, 4))]) + def prog(x): + r1 = mb.upsample_bilinear(x=x) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["upsample_bilinear", "relu"] + assert get_op_types_in_program(prog) == ["relu"] + assert_model_is_valid( + prog, + {"x": (3, 2, 4)}, + expected_output_shapes={block.outputs[0].name: (3, 2, 4)}, + ) + + +def test_resize_bilinear_elimination(): + @mb.program(input_specs=[mb.TensorSpec(shape=(3, 2, 4))]) + def prog(x): + r1 = mb.resize_bilinear(x=x, target_size_height=2, target_size_width=4) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == 
["resize_bilinear", "relu"] + assert get_op_types_in_program(prog) == ["relu"] + assert_model_is_valid( + prog, + {"x": (3, 2, 4)}, + expected_output_shapes={block.outputs[0].name: (3, 2, 4)}, + ) + + +def test_crop_elimination(): + @mb.program(input_specs=[mb.TensorSpec(shape=(3, 2, 4))]) + def prog(x): + r1 = mb.crop(x=x, crop_height=[0, 0], crop_width=[0, 0]) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["crop", "relu"] + assert get_op_types_in_program(prog) == ["relu"] + assert_model_is_valid( + prog, + {"x": (3, 2, 4)}, + expected_output_shapes={block.outputs[0].name: (3, 2, 4)}, + ) + + +def test_linear_elimination(): + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) + def prog(x): + r1 = mb.linear_activation(x=x, alpha=1.0, beta=0.0) + return mb.relu(x=r1) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::noop_elimination" + ) + assert get_op_types_in_program(prev_prog) == ["linear_activation", "relu"] + assert get_op_types_in_program(prog) == ["relu"] + assert_model_is_valid( + prog, + {"x": (2, 4)}, + expected_output_shapes={block.outputs[0].name: (2, 4)}, + ) + + diff --git a/coremltools/converters/mil/mil/passes/test_pad_conv_pass.py b/coremltools/converters/mil/mil/passes/test_pad_conv_pass.py new file mode 100644 index 000000000..05e881026 --- /dev/null +++ b/coremltools/converters/mil/mil/passes/test_pad_conv_pass.py @@ -0,0 +1,126 @@ +# Copyright (c) 2020, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.testing_utils import ( + assert_op_count_match, + assert_model_is_valid, + get_op_types_in_program, + apply_pass_and_basic_check, +) +import unittest +import pytest + +import numpy as np + +np.random.seed(1984) + + +class PadConvOptimizationPass(unittest.TestCase): + """ + Input graph: + input -----> pad -----> transpose -----> conv -----> transpose ---> out + + Output graph: + input -----> transpose -----> pad ----> conv -----> transpose ----> out + """ + + def test_simple_direct_output(self): + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 16, 20, 24))]) + def prog(x): + x = mb.pad(x=x, pad=[0,0,1,1,1,1,0,0]) + x = mb.transpose(x=x, perm=[0, 3, 1, 2]) + x = mb.conv(x=x, weight=np.random.random([24,24,3,3]), pad_type="valid") + x = mb.transpose(x=x, perm=[0, 2, 3, 1]) + return x + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::pad_conv_connect" + ) + self.assertEqual( + get_op_types_in_program(prev_prog), ["pad", "transpose", "conv", "transpose"] + ) + self.assertEqual(get_op_types_in_program(prog), ["transpose", "pad", "conv", "transpose"]) + assert_model_is_valid( + prog, + {"x": (1, 16, 20, 24)}, + expected_output_shapes={block.outputs[0].name: (1, 16, 20, 24)}, + ) + + """ + Input graph: + input -----> pad -----> transpose -----> conv -----> transpose ---> out + | + | + --------> transpose -----> conv -----> transpose ---> out + + Output graph: + input ---------> transpose -----> pad -----> conv -----> transpose ---> out + | + | + ------> transpose -----> pad -----> conv -----> transpose ---> out + + """ + + def test_pad_transposed_forked_conv(self): + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 16, 20, 24))]) + def prog(x): + pad = 
mb.pad(x=x, pad=[0,0,1,1,1,1,0,0]) + x = mb.transpose(x=pad, perm=[0, 3, 1, 2]) + x = mb.conv(x=x, weight=np.random.random([24,24,3,3]), pad_type="valid") + x = mb.transpose(x=x, perm=[0, 2, 3, 1]) + y = mb.transpose(x=pad, perm=[0, 3, 1, 2]) + y = mb.conv(x=y, weight=np.random.random([24,24,3,3]), pad_type="valid") + y = mb.transpose(x=y, perm=[0, 2, 3, 1]) + return x, y + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::pad_conv_connect" + ) + self.assertEqual( + get_op_types_in_program(prev_prog), ["pad", "transpose", "conv", "transpose", "transpose", "conv", "transpose"] + ) + self.assertEqual(get_op_types_in_program(prog), ["transpose", "pad", "conv", "transpose", "transpose", "pad", "conv", "transpose"]) + assert_model_is_valid( + prog, + {"x": (1, 16, 20, 24)}, + expected_output_shapes={block.outputs[0].name: (1, 16, 20, 24), + block.outputs[1].name: (1, 16, 20, 24)}, + ) + + """ + Input graph: + input -----> pad -----> transpose -----> conv -----> transpose ---> out + | + | + ---------> out + + Output graph: + No change. + """ + + def test_pad_output(self): + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 16, 20, 24))]) + def prog(x): + pad = mb.pad(x=x, pad=[0,0,1,1,1,1,0,0]) + x = mb.transpose(x=pad, perm=[0, 3, 1, 2]) + x = mb.conv(x=x, weight=np.random.random([24,24,3,3]), pad_type="valid") + x = mb.transpose(x=x, perm=[0, 2, 3, 1]) + return x, pad + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::pad_conv_connect" + ) + self.assertEqual( + get_op_types_in_program(prev_prog), ["pad", "transpose", "conv", "transpose"] + ) + self.assertEqual(get_op_types_in_program(prog), ["pad", "transpose", "conv", "transpose"]) + assert_model_is_valid( + prog, + {"x": (1, 16, 20, 24)}, + expected_output_shapes={block.outputs[0].name: (1, 16, 20, 24), + block.outputs[1].name: (1, 18, 22, 24)}, + ) + diff --git a/coremltools/converters/mil/mil/passes/test_reduce_transposes_pass.py b/coremltools/converters/mil/mil/passes/test_reduce_transposes_pass.py index 1b0457ddf..37074e860 100644 --- a/coremltools/converters/mil/mil/passes/test_reduce_transposes_pass.py +++ b/coremltools/converters/mil/mil/passes/test_reduce_transposes_pass.py @@ -21,6 +21,34 @@ class TransposeOptimizationPass(unittest.TestCase): """""" + """ + Input graph: + input -----> transpose(axis=[1,0]) -----> transpose(axis=[1,0]) ---> out + + Output graph: + input -----> relu -----> out + """ + + def test_simple_consecutive_ops_fusion_direct_output(self): + @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))]) + def prog(x): + x = mb.transpose(x=x, perm=[1, 0]) + x = mb.transpose(x=x, perm=[1, 0]) + return x + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "common::reduce_transposes" + ) + self.assertEqual( + get_op_types_in_program(prev_prog), ["transpose", "transpose"] + ) + self.assertEqual(get_op_types_in_program(prog), ["identity"]) + assert_model_is_valid( + prog, + {"x": (10, 20)}, + expected_output_shapes={block.outputs[0].name: (10, 20)}, + ) + """ Input graph: input -----> transpose(axis=[1,0]) -----> transpose(axis=[1,0]) ----> relu ---> out diff --git a/coremltools/converters/mil/mil/types/__init__.py b/coremltools/converters/mil/mil/types/__init__.py index 6a1a43ec4..8e6be9176 100644 --- a/coremltools/converters/mil/mil/types/__init__.py +++ b/coremltools/converters/mil/mil/types/__init__.py @@ -43,6 +43,7 @@ is_scalar, is_tensor, is_tuple, + is_dict, is_str, is_builtin, promote_types, diff --git 
a/coremltools/converters/mil/mil/types/type_mapping.py b/coremltools/converters/mil/mil/types/type_mapping.py index f152498fc..8b4b67e00 100644 --- a/coremltools/converters/mil/mil/types/type_mapping.py +++ b/coremltools/converters/mil/mil/types/type_mapping.py @@ -67,7 +67,7 @@ def np_dtype_to_py_type(np_dtype): # Can't use dict, as hash(np.int32) != hash(val.dtype) if np_dtype in [np.int32, np.int64]: return int - if np_dtype == np.bool: + if np_dtype in [np.bool, np.bool_]: return bool if np_dtype in [np.float32, np.float64]: return float @@ -167,6 +167,15 @@ def is_tuple(t): return False return type_info == "tuple" +def is_dict(t): + if t is None: + return False + try: + type_info = get_type_info(t).name + except TypeError: + return False + return type_info == "dict" + def is_builtin(t): return is_scalar(t) or is_tensor(t) or is_str(t) or is_tuple(t) diff --git a/coremltools/converters/mil/mil/var.py b/coremltools/converters/mil/mil/var.py index c57df4ed3..ddc7d258f 100644 --- a/coremltools/converters/mil/mil/var.py +++ b/coremltools/converters/mil/mil/var.py @@ -163,12 +163,22 @@ def shape_str(self): shape_str = str(self.shape)[:-1] # trim the ")" if self.rank > 1: shape_str += ", " - shape_str += types.builtin_to_string(self.dtype) + ")" + annotation + if types.builtin_to_string(self.dtype) is None: + shape_str += ")" + annotation + else: + shape_str += types.builtin_to_string(self.dtype) + ")" + annotation return shape_str def type_str(self): is_tensor = types.is_tensor(self.sym_type) - return "(Tensor)" if is_tensor else "(Scalar)" + is_list = types.is_list(self.sym_type) + if is_tensor: + type_string = "(Tensor)" + elif is_list: + type_string = "(List)" + else: + type_string = "(Scalar)" + return type_string def set_name(self, name): self.name = name @@ -181,7 +191,7 @@ class ListVar(Var): __slots__ = ["_elem_type", "init_length", "dynamic_length"] def __init__( - self, name, elem_type=None, init_length=None, dynamic_length=True, **kwargs + self, name, elem_type=None, init_length=None, dynamic_length=True, sym_val=None, **kwargs ): """ elem_type (builtin.tensor) @@ -190,11 +200,13 @@ def __init__( dynamic_length (bool): True to allow list to grow. False uses init_length as the fixed size (init_length is runtime length). 
+ + sym_val: value of the list, if available """ super(ListVar, self).__init__( name=name, sym_type=types.list(elem_type, init_length, dynamic_length), - sym_val=None, + sym_val=sym_val, **kwargs ) self._elem_type = elem_type @@ -229,13 +241,18 @@ def shape_str(self): length = str(self.init_length) if self._elem_type == types.unknown: return "List[{}, unknown]".format(length) - elem_shape = self._elem_type.get_shape() - elem_dtype = self._elem_type.get_primitive() - shape_str = str(elem_shape)[:-1] # trim the ")" - if len(elem_shape) > 1: - shape_str += ", " - shape_str += types.builtin_to_string(elem_dtype) + ")" - return "List[{}, {}]".format(length, shape_str) + if self._elem_type == types.str: + return "List[{}, str]".format(length) + elif self._elem_type == types.int64: + return "List[{}, int]".format(length) + else: + elem_shape = self._elem_type.get_shape() + elem_dtype = self._elem_type.get_primitive() + shape_str = str(elem_shape)[:-1] # trim the ")" + if len(elem_shape) > 1: + shape_str += ", " + shape_str += types.builtin_to_string(elem_dtype) + ")" + return "List[{}, {}]".format(length, shape_str) class InternalVar(Var): diff --git a/coremltools/converters/mil/testing_reqs.py b/coremltools/converters/mil/testing_reqs.py index cc84639a3..929e8774f 100644 --- a/coremltools/converters/mil/testing_reqs.py +++ b/coremltools/converters/mil/testing_reqs.py @@ -6,6 +6,7 @@ import os import itertools import numpy as np +from numpy import linalg as la import pytest from coremltools.converters.mil.mil import Builder as mb @@ -21,7 +22,7 @@ ) from .testing_utils import ssa_fn, is_close, random_gen, converter, _converter -backends = _converter.ConverterRegistry.backends.keys() +backends = ['nn_proto'] np.random.seed(1984) diff --git a/coremltools/converters/mil/testing_utils.py b/coremltools/converters/mil/testing_utils.py index 3ca2e12b7..98e1bc9c3 100644 --- a/coremltools/converters/mil/testing_utils.py +++ b/coremltools/converters/mil/testing_utils.py @@ -13,6 +13,7 @@ from coremltools.converters.mil.mil import Program, Function from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY from coremltools._deps import _IS_MACOS +import PIL.Image converter = converter _converter = _converter @@ -115,7 +116,7 @@ def random_gen( r = dtype((rand_max - rand_min) * np.random.random() + rand_min) if not allow_duplicate and r in ret: continue - if np.fabs(np.round(r) - r) > eps_from_int: + if np.issubdtype(dtype, np.integer) or np.fabs(np.round(r) - r) > eps_from_int: ret.append(r) break ret = np.array(ret).reshape(shape) @@ -164,17 +165,13 @@ def is_close(expected, actual, atol=1e-04, rtol=1e-05): def run_core_ml_predict(proto, input_key_values, use_cpu_only=False): model = coremltools.models.MLModel(proto, useCPUOnly=use_cpu_only) - input_key_values = dict( - [ - ( - k, - v.astype(np.float32) - if not np.isscalar(v) and not v.shape == () - else np.array([v], dtype=np.float32), - ) - for k, v in input_key_values.items() - ] - ) + for k, v in input_key_values.items(): + if isinstance(v, PIL.Image.Image): + continue + elif not np.isscalar(v) and not v.shape == (): + input_key_values[k] = v.astype(np.float32) + else: + input_key_values[k] = np.array([v], dtype=np.float32) return model.predict(input_key_values, useCPUOnly=use_cpu_only) @@ -215,7 +212,7 @@ def compare_backend( for o, expected in expected_outputs.items(): msg = ( "Output {} differs. 
useCPUOnly={}.\nInput={}, " - + "Expected={}, Output={}\n" + + "\nExpected={}, \nOutput={}\n" ) assert is_close(expected, pred[o], atol, rtol), msg.format( o, use_cpu_only, input_key_values, expected, pred[o] @@ -229,7 +226,7 @@ def compare_shapes( Inputs: - proto: MLModel proto. - - input_key_values: str -> np.array. Keys must match those in + - input_key_values: str -> np.array or PIL.Image. Keys must match those in input_placeholders. - expected_outputs: dict[str, np.array]. diff --git a/coremltools/converters/onnx/_operators.py b/coremltools/converters/onnx/_operators.py index 832096cf4..618d3cdd0 100644 --- a/coremltools/converters/onnx/_operators.py +++ b/coremltools/converters/onnx/_operators.py @@ -430,7 +430,8 @@ def _get_conv_params(builder, node, graph, err, params_dict, axis=None): else: params_dict["W"] = params_dict["W"].transpose((2, 3, 0, 1)) # type: ignore - if "auto_pad" in node.attrs and not _compare(node.attrs["auto_pad"], "VALID"): + if "auto_pad" in node.attrs and \ + not (_compare(node.attrs["auto_pad"], 'VALID') or _compare(node.attrs["auto_pad"], 'NOTSET')): params_dict["padding_type"] = "same" if _compare(node.attrs["auto_pad"], "SAME_LOWER"): params_dict["same_padding_asymmetry_mode"] = "TOP_LEFT_HEAVY" @@ -820,7 +821,8 @@ def _get_pool_params(builder, node, graph, err, params_dict, axis=None): params_dict["stride_height"] = strides[0] params_dict["stride_width"] = strides[1] - if "auto_pad" in node.attrs and not _compare(node.attrs["auto_pad"], "VALID"): + if "auto_pad" in node.attrs and \ + not (_compare(node.attrs["auto_pad"], 'VALID') or _compare(node.attrs["auto_pad"], 'NOTSET')): params_dict["padding_type"] = "SAME" if _compare(node.attrs["auto_pad"], "SAME_LOWER"): params_dict["same_padding_asymmetry_mode"] = "TOP_LEFT_HEAVY" diff --git a/coremltools/converters/onnx/_operators_nd.py b/coremltools/converters/onnx/_operators_nd.py index f85a81b46..bf092e2fb 100644 --- a/coremltools/converters/onnx/_operators_nd.py +++ b/coremltools/converters/onnx/_operators_nd.py @@ -1606,6 +1606,12 @@ def _convert_pad(builder, node, graph, err): if mode == "constant": pads = node.attrs.get("pads", []) value = node.attrs.get("value", 0.0) + # onnx padding spec: [x1_top, ..., xn_top, x1_bottom, ..., xn_bottom] + # coreml padding spec: [x1_top, x1_bottom, ..., xn_top, xn_bottom] + assert len(pads) % 2 == 0, 'even number of pads expected' + pads_coreml = [None] * len(pads) + pads_coreml[::2] = pads[:len(pads) // 2] + pads_coreml[1::2] = pads[len(pads) // 2:] builder.add_constant_pad( name=node.name, @@ -1613,7 +1619,7 @@ def _convert_pad(builder, node, graph, err): output_name=node.outputs[0], value=value, pad_to_given_output_size_mode=False, - pad_amounts=pads, + pad_amounts=pads_coreml, ) else: _convert_pad_5d(builder, node, graph, err) diff --git a/coremltools/models/neural_network/builder.py b/coremltools/models/neural_network/builder.py index 4ed97bc04..215758935 100644 --- a/coremltools/models/neural_network/builder.py +++ b/coremltools/models/neural_network/builder.py @@ -86,12 +86,12 @@ def _verify_quantization_arguments(weight=bytes(), output_channels=1, **kwargs): raise ValueError( "quant_scale and quant_bias parameters must be provided for linear quantization type" ) - if len(quant_scale) != 1 and len(quant_scale) != output_channels: + if not _np.isscalar(quant_scale) and (len(quant_scale) != 1 and len(quant_scale) != output_channels): raise ValueError( "quant_scale should be of type float or an array of length outputChannels" ) if not int_8_dynamic_quantize: - if 
len(quant_bias) != 1 and len(quant_bias) != output_channels: + if not _np.isscalar(quant_scale) and len(quant_bias) != 1 and len(quant_bias) != output_channels: raise ValueError( "quant_bias should be of type float or an array of length outputChannels" ) @@ -2346,7 +2346,8 @@ def add_convolution( return # Weight assignments - if len(kwargs) > 0: + quantization = len(kwargs) > 0 and ('quantization_type' in kwargs and kwargs.get('quantization_type') != None) + if quantization: _verify_quantization_arguments( weight=W, output_channels=output_channels, **kwargs ) @@ -2377,7 +2378,7 @@ def add_convolution( # Assign weights weights = spec_layer_params.weights - if len(kwargs) == 0: # no quantization + if not quantization: # no quantization weights.floatValue.extend(Wt.flatten()) else: # there is quantization W_bytes = bytes() @@ -4091,7 +4092,7 @@ def add_unary( alpha=1.0, shift=0, scale=1.0, - epsilon=1e-6, + epsilon=None, ): """ Add a Unary layer. Applies the specified function (mode) to all the elements of the input. @@ -4128,6 +4129,14 @@ def add_unary( """ spec_layer = self._add_generic_layer(name, [input_name], [output_name]) spec_layer_params = spec_layer.unary + if epsilon is None: + # Use the default value of epsilon to be 1e-4, instead of 1e-6, if mode = "rsqrt" or "inverse" + if mode == "inverse" or mode == "rsqrt": + epsilon = 1e-4 + elif mode == "log": + epsilon = 1e-45 + else: + epsilon = 1e-6 spec_layer_params.epsilon = epsilon spec_layer_params.alpha = alpha spec_layer_params.shift = shift diff --git a/coremltools/test/api/test_api_examples.py b/coremltools/test/api/test_api_examples.py index eabfa5c3d..752f08985 100644 --- a/coremltools/test/api/test_api_examples.py +++ b/coremltools/test/api/test_api_examples.py @@ -1097,7 +1097,7 @@ def forward(self, x, y): np.testing.assert_allclose(result[name], expected.detach().numpy()) ############################################################################### -# Note: all tests are examples provided to other teams for testing +# Note: all tests are examples provided to other teams for testing # Each test case is expected to be runnable and self-complete. 
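+# For end-to-end usage, see test_convert_tf2_keras and test_convert_torch_traced_model below.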
############################################################################### @@ -1130,13 +1130,13 @@ def test_convert_tf2_keras(tmpdir): @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) @pytest.mark.skipif(ct.utils._python_version() < (3, 0, 0), - reason="PyTorch no longer supports Python 2.7") + reason="PyTorch no longer supports Python 2.7") def test_convert_torch_traced_model(tmpdir): import torch from torch import nn class Network(nn.Module): def __init__(self): - super().__init__() + super(Network, self).__init__() self.hidden = nn.Linear(100, 10) self.output = nn.Linear(10, 2) self.sigmoid = nn.Sigmoid() @@ -1151,7 +1151,7 @@ def forward(self, x): torch_model = Network() torch_model.eval() - example_input = torch.rand(1, 100) + example_input = torch.rand(1, 100) traced_model = torch.jit.trace(torch_model, example_input) model = ct.convert( traced_model, diff --git a/coremltools/test/neural_network/test_nn_builder.py b/coremltools/test/neural_network/test_nn_builder.py index 3c932a7ae..776205f30 100644 --- a/coremltools/test/neural_network/test_nn_builder.py +++ b/coremltools/test/neural_network/test_nn_builder.py @@ -137,6 +137,7 @@ def build_quant_conv_layer( quant_scale=None, quant_bias=None, quant_lut=None, + output_channels=2, ): input_features = [("data", datatypes.Array(1, 2, 2))] output_features = [("out", datatypes.Array(2, 1, 1))] @@ -144,7 +145,7 @@ def build_quant_conv_layer( builder.add_convolution( name="conv", kernel_channels=1, - output_channels=2, + output_channels=output_channels, height=2, width=2, stride_height=1, @@ -196,6 +197,24 @@ def test_linear_quant_convolution_8bit_vector_scalebias(self): expected_out = np.reshape(np.array([8, 44]), (2, 1, 1)) self.assertTrue(np.allclose(out, expected_out)) + @unittest.skip(" Investigate numerical discrepancy during quantization in CoreML") + def test_linear_quant_convolution_8bit_float_scale_and_bias(self): + W = np.array(([[[[1, 248], [248, 248]]]]), dtype=np.uint8) + mlmodel = self.build_quant_conv_layer( + W=W.flatten().tobytes(), + quantization_type="linear", + nbits=8, + quant_scale=[15.346457], + quant_bias=[-3913.3464], + output_channels=1, + ) + data = np.ones((1, 2, 2)) + data_dict = {"data": data} + out = mlmodel.predict(data_dict, useCPUOnly=True)["out"] + # Output should be equal to: (scale*(1+248+248+248)+(4*bias)) + expected_out = np.reshape(np.array([-4220.275]), (1, 1, 1, 1, 1)) + self.assertTrue(np.allclose(out, expected_out)) + def test_lut_quant_convolution_2bit(self): W = np.zeros((2, 2, 1, 2), dtype=np.uint8) W[:, :, :, 0] = 0 diff --git a/coremltools/test/neural_network/test_simple_nn_inference.py b/coremltools/test/neural_network/test_simple_nn_inference.py new file mode 100644 index 000000000..1855ff48d --- /dev/null +++ b/coremltools/test/neural_network/test_simple_nn_inference.py @@ -0,0 +1,43 @@ +import coremltools +import coremltools.models.datatypes as datatypes +from coremltools.models import neural_network as neural_network +import numpy as np +import os +import pytest + +class TestNeuralNetworkPrediction: + + @staticmethod + def test_lrn_model(tmpdir): + + input_dim = (1, 3, 3) + input_features = [("data", datatypes.Array(*input_dim))] + output_features = [("output", datatypes.Array(*input_dim))] + + builder = neural_network.NeuralNetworkBuilder(input_features, output_features) + builder.add_lrn( + name="lrn", + input_name="data", + output_name="output", + alpha=2, + beta=3, + local_size=1, + k=8, + ) + + input = {"data": np.ones((1, 3, 3))} + expected = 1e-3 
* np.ones((1, 3, 3)) + model_path = os.path.join(str(tmpdir), "lrn_model.mlmodel") + coremltools.models.utils.save_spec(builder.spec, model_path) + + try: + model = coremltools.models.MLModel(model_path) + out = model.predict(input, useCPUOnly=True) + except RuntimeError as e: + print(e) + assert str(e) == "Error compiling model: \"The file couldn’t be saved.\"." + else: + assert out['output'].shape == (1, 3, 3) + np.testing.assert_allclose(expected, out['output']) + print("Core ML output", out) + diff --git a/coremltools/version.py b/coremltools/version.py index 62a2ef6e1..ad2e066ed 100644 --- a/coremltools/version.py +++ b/coremltools/version.py @@ -4,4 +4,4 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -__version__ = "4.0b3" # VERSION_STRING +__version__ = "4.0b4" # VERSION_STRING diff --git a/mlmodel/format/Model.proto b/mlmodel/format/Model.proto index cd8d86990..737233f2e 100644 --- a/mlmodel/format/Model.proto +++ b/mlmodel/format/Model.proto @@ -224,6 +224,27 @@ message SerializedModel { * - NLP Gazeteer * - NLP WordEmbedding * + * 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) + * - Model Deployment + * - Model Encryption + * - Unified converter API with PyTorch and Tensorflow 2 Support in coremltools 4 + * - MIL builder for neural networks and composite ops in coremltools 4 + * - New layers in neural network: + * - CumSum + * - OneHot + * - ClampedReLu + * - ArgSort + * - SliceBySize + * - Convolution3D + * - Pool3D + * - Bilinear Upsample with align corners and fractional factors + * - PixelShuffle + * - MatMul with int8 weights and int8 activations + * - Concat interleave + * - See NeuralNetwork.proto + * - Enhanced Xcode model view with interactive previews + * - Enhanced Xcode Playground support for Core ML models + * */ message Model { int32 specificationVersion = 1; diff --git a/mlmodel/src/NeuralNetworkBuffer.cpp b/mlmodel/src/NeuralNetworkBuffer.cpp index a58cdca7e..90f94d553 100644 --- a/mlmodel/src/NeuralNetworkBuffer.cpp +++ b/mlmodel/src/NeuralNetworkBuffer.cpp @@ -7,6 +7,7 @@ // #include "NeuralNetworkBuffer.hpp" + #include #include #include @@ -16,118 +17,121 @@ namespace NNBuffer { - /* - * getOpenMode - Returns open model as per the mode provided - */ - static std::ios_base::openmode getOpenMode(bufferMode mode) - { - return (mode == bufferMode::read) - ? (std::fstream::in | std::ios::binary) - : (std::fstream::in | std::fstream::out | std::ios::binary - | (mode == bufferMode::write ? std::ios::trunc : std::ios::app)); +/* + * getOpenMode - Returns open model as per the mode provided + */ +static std::ios_base::openmode getOpenMode(BufferMode mode) +{ + return (mode == BufferMode::Read) + ? (std::fstream::in | std::ios::binary) + : (std::fstream::in | std::fstream::out | std::ios::binary + | (mode == BufferMode::Write ? 
std::ios::trunc : std::ios::app)); + +} +/* + * NeuralNetworkBuffer - NeuralNetworkBuffer + */ +NeuralNetworkBuffer::NeuralNetworkBuffer(const std::string& bufferFilePath, BufferMode mode) + : bufferFilePath(bufferFilePath), + bufferStream(bufferFilePath, getOpenMode(mode)) +{ + if (!bufferStream) { + throw std::runtime_error(std::string("Could not open buffer file '" + bufferFilePath + "': ") + std::strerror(errno) + '.'); } +} - /* - * NeuralNetworkBuffer - NeuralNetworkBuffer - */ - NeuralNetworkBuffer::NeuralNetworkBuffer(const std::string& bufferFilePath, bufferMode mode) - : bufferFilePath(bufferFilePath), - bufferStream(bufferFilePath, getOpenMode(mode)) - { - if (!bufferStream) { - throw std::runtime_error(std::string("Could not open buffer file '" + bufferFilePath + "': ") + std::strerror(errno) + '.'); - } +/* + * NeuralNetworkBuffer - NeuralNetworkBuffer + */ +NeuralNetworkBuffer::~NeuralNetworkBuffer() = default; + +/* + * NeuralNetworkBuffer - addBuffer + * Writes given data into buffer file + * Writes in following order + * [Length of data, data type, data] + * Number of bytes written = Length_Of_Data * Size_Of_Data_Type + */ +template +uint64_t NeuralNetworkBuffer::AddBuffer(const std::vector& buffer) { + bufferStream.seekp(0, std::ios::end); + if (!bufferStream.good()) { + throw std::runtime_error(std::string("Could not seek to end of data file: ") + std::strerror(errno) + '.'); } - /* - * NeuralNetworkBuffer - NeuralNetworkBuffer - */ - NeuralNetworkBuffer::~NeuralNetworkBuffer() = default; - - /* - * NeuralNetworkBuffer - addBuffer - * Writes given data into buffer file - * Writes in following order - * [Length of data, data type, data] - * Number of bytes written = Length_Of_Data * Size_Of_Data_Type - */ - template - uint64_t NeuralNetworkBuffer::addBuffer(const std::vector& buffer) { - bufferStream.seekp(0, std::ios::end); - if (!bufferStream.good()) { - throw std::runtime_error(std::string("Could not seek to end of data file: ") + std::strerror(errno) + '.'); - } - - // Get offset - auto offset = bufferStream.tellp(); - - // Write length, size of data type and buffer - int64_t lenOfBuffer = static_cast(buffer.size()); - int64_t sizeOfBlock = sizeof(T); - - bufferStream.write((char*)&lenOfBuffer, sizeof(lenOfBuffer)); - if (bufferStream.fail()) { - throw std::runtime_error(std::string("Could not write length of data file: ") + std::strerror(errno) + '.'); - } - - bufferStream.write((char*)&sizeOfBlock, sizeof(sizeOfBlock)); - if (bufferStream.fail()) { - throw std::runtime_error(std::string("Could not write size of data block: ") + std::strerror(errno) + '.'); - } - - bufferStream.write((char*)&buffer[0], static_cast(sizeOfBlock * lenOfBuffer)); - if (bufferStream.fail()) { - throw std::runtime_error(std::string("Could not write data to data file: ") + std::strerror(errno) + '.'); - } - - return static_cast(offset); + // Get offset + auto offset = bufferStream.tellp(); + + // Write length, size of data type and buffer + int64_t lenOfBuffer = static_cast(buffer.size()); + int64_t sizeOfBlock = sizeof(T); + + bufferStream.write((char*)&lenOfBuffer, sizeof(lenOfBuffer)); + if (bufferStream.fail()) { + throw std::runtime_error(std::string("Could not write length of data file: ") + std::strerror(errno) + '.'); } - /* - * NeuralNetworkBuffer - getBuffer - * Reads data from given offset - */ - template - void NeuralNetworkBuffer::getBuffer(const uint64_t offset, std::vector& buffer) { - int64_t lenOfBuffer = 0; - int64_t sizeOfBlock = 0; - - 
bufferStream.seekg(static_cast(offset), std::ios::beg); - if (!bufferStream.good()) { - throw std::runtime_error(std::string("Could not seek to beginning of data file: ") + std::strerror(errno) + '.'); - } - - // Read length of buffer and size of each block - bufferStream.read((char*)&lenOfBuffer, sizeof(lenOfBuffer)); - if (bufferStream.fail()) { - throw std::runtime_error(std::string("Could not read length of data file: ") + std::strerror(errno) + '.'); - } - - bufferStream.read((char*)&sizeOfBlock, sizeof(sizeOfBlock)); - if (bufferStream.fail()) { - throw std::runtime_error(std::string("Could not read size of data block: ") + std::strerror(errno) + '.'); - } - - // TODO: assert if sizeOfBlock != sizeof(T) or resize accordingly. - // Resize buffer to fit buffer - buffer.resize(static_cast::size_type>(lenOfBuffer)); - - // Read buffer - bufferStream.read((char*)&buffer[0], static_cast(sizeOfBlock * lenOfBuffer)); - if (bufferStream.fail()) { - throw std::runtime_error(std::string("Could not read data from data file: ") + std::strerror(errno) + '.'); - } + bufferStream.write((char*)&sizeOfBlock, sizeof(sizeOfBlock)); + if (bufferStream.fail()) { + throw std::runtime_error(std::string("Could not write size of data block: ") + std::strerror(errno) + '.'); } - // Explicit include templated functions - template uint64_t NeuralNetworkBuffer::addBuffer(const std::vector&); - template uint64_t NeuralNetworkBuffer::addBuffer(const std::vector&); - template uint64_t NeuralNetworkBuffer::addBuffer(const std::vector&); - template uint64_t NeuralNetworkBuffer::addBuffer(const std::vector&); + bufferStream.write((char*)&buffer[0], static_cast(sizeOfBlock * lenOfBuffer)); + if (bufferStream.fail()) { + throw std::runtime_error(std::string("Could not write data to data file: ") + std::strerror(errno) + '.'); + } + + return static_cast(offset); +} + +/* + * NeuralNetworkBuffer - getBuffer + * Reads data from given offset + */ +template +void NeuralNetworkBuffer::GetBuffer(uint64_t offset, std::vector& buffer) { + int64_t lenOfBuffer = 0; + int64_t sizeOfBlock = 0; + + bufferStream.seekg(static_cast(offset), std::ios::beg); + if (!bufferStream.good()) { + throw std::runtime_error(std::string("Could not seek to beginning of data file: ") + std::strerror(errno) + '.'); + } + + // Read length of buffer and size of each block + bufferStream.read((char*)&lenOfBuffer, sizeof(lenOfBuffer)); + if (bufferStream.fail()) { + throw std::runtime_error(std::string("Could not read length of data file: ") + std::strerror(errno) + '.'); + } + + bufferStream.read((char*)&sizeOfBlock, sizeof(sizeOfBlock)); + if (bufferStream.fail()) { + throw std::runtime_error(std::string("Could not read size of data block: ") + std::strerror(errno) + '.'); + } + + // TODO: rdar://67747690 assert if sizeOfBlock != sizeof(T) or resize accordingly. 
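+    // Each record has the layout [int64 element count][int64 element size][payload],
+    // mirroring what AddBuffer writes above; in a well-formed file sizeOfBlock == sizeof(T).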
+ // Resize buffer to fit buffer + buffer.resize(static_cast::size_type>(lenOfBuffer)); + + // Read buffer + bufferStream.read((char*)&buffer[0], static_cast(sizeOfBlock * lenOfBuffer)); + if (bufferStream.fail()) { + throw std::runtime_error(std::string("Could not read data from data file: ") + std::strerror(errno) + '.'); + } +} + +// Explicit include templated functions +template uint64_t NeuralNetworkBuffer::AddBuffer(const std::vector&); +template uint64_t NeuralNetworkBuffer::AddBuffer(const std::vector&); +template uint64_t NeuralNetworkBuffer::AddBuffer(const std::vector&); +template uint64_t NeuralNetworkBuffer::AddBuffer(const std::vector&); +template uint64_t NeuralNetworkBuffer::AddBuffer(const std::vector&); + +template void NeuralNetworkBuffer::GetBuffer(const uint64_t, std::vector&); +template void NeuralNetworkBuffer::GetBuffer(const uint64_t, std::vector&); +template void NeuralNetworkBuffer::GetBuffer(const uint64_t, std::vector&); +template void NeuralNetworkBuffer::GetBuffer(const uint64_t, std::vector&); +template void NeuralNetworkBuffer::GetBuffer(const uint64_t, std::vector&); - template void NeuralNetworkBuffer::getBuffer(const uint64_t, std::vector&); - template void NeuralNetworkBuffer::getBuffer(const uint64_t, std::vector&); - template void NeuralNetworkBuffer::getBuffer(const uint64_t, std::vector&); - template void NeuralNetworkBuffer::getBuffer(const uint64_t, std::vector&); } diff --git a/mlmodel/src/NeuralNetworkBuffer.hpp b/mlmodel/src/NeuralNetworkBuffer.hpp index b4244a605..5ddbbe003 100644 --- a/mlmodel/src/NeuralNetworkBuffer.hpp +++ b/mlmodel/src/NeuralNetworkBuffer.hpp @@ -6,46 +6,50 @@ // Copyright © 2019 Apple Inc. All rights reserved. // -#ifndef NeuralNetworkBuffer_hpp -#define NeuralNetworkBuffer_hpp +#pragma once +#include #include #include -#include namespace NNBuffer { - // - // NeuralNetworkBuffer - Network parameter read-write management to file - // Current management policy - // Each parameter is written to binary file in following order. - // [Length of data (size_t), Data type of data (size_t), data (length of data * size of data type)] - // - - enum bufferMode { - write=0, - append, - read - }; - - class NeuralNetworkBuffer { - private: - std::string bufferFilePath; - std::fstream bufferStream; - - public: - // Must be constructed with file path to store parameters - NeuralNetworkBuffer(const std::string&, bufferMode mode=bufferMode::write); - ~NeuralNetworkBuffer(); - - // Stores given buffer and returns offset in buffer file - template - uint64_t addBuffer(const std::vector&); - - // Reads buffer from given offset and stores in vector - // passed by reference. - // Note that, this routine resizes the given vector. - template - void getBuffer(const uint64_t, std::vector&); - }; -} -#endif /* NeuralNetworkBuffer_hpp */ + +enum class BufferMode { + Write=0, + Append, + Read +}; + +// +// NeuralNetworkBuffer - Network parameter read-write management to file +// Current management policy +// Each parameter is written to binary file in following order. 
+// [Length of data (size_t), Data type of data (size_t), data (length of data * size of data type)] +// +class NeuralNetworkBuffer { +public: + // Must be constructed with file path to store parameters + NeuralNetworkBuffer(const std::string& bufferFilePath, BufferMode mode=BufferMode::Write); + ~NeuralNetworkBuffer(); + + NeuralNetworkBuffer(const NeuralNetworkBuffer&) = delete; + NeuralNetworkBuffer(NeuralNetworkBuffer&&) = delete; + NeuralNetworkBuffer& operator=(const NeuralNetworkBuffer&) = delete; + NeuralNetworkBuffer& operator=(NeuralNetworkBuffer&&) = delete; + + // Stores given buffer and returns offset in buffer file + template + uint64_t AddBuffer(const std::vector& buffer); + + // Reads buffer from given offset and stores in vector + // passed by reference. + // Note that, this routine resizes the given vector. + template + void GetBuffer(uint64_t offset, std::vector& buffer); + +private: + std::string bufferFilePath; + std::fstream bufferStream; +}; + +} // namespace NNBuffer diff --git a/mlmodel/src/Validation/InterfaceValidators.cpp b/mlmodel/src/Validation/InterfaceValidators.cpp index 64d8c65f9..33d577eb7 100644 --- a/mlmodel/src/Validation/InterfaceValidators.cpp +++ b/mlmodel/src/Validation/InterfaceValidators.cpp @@ -494,7 +494,8 @@ namespace CoreML { // only for NeuralNetwork models with Spec 5 (iOS 14) onwards. if (format.Type_case() != Specification::Model::kNeuralNetwork && format.Type_case() != Specification::Model::kNeuralNetworkRegressor && - format.Type_case() != Specification::Model::kNeuralNetworkClassifier) { + format.Type_case() != Specification::Model::kNeuralNetworkClassifier && + format.Type_case() != Specification::Model::kSerializedModel ) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Default optional values are only allowed for neural networks."); } @@ -542,6 +543,7 @@ namespace CoreML { case Specification::Model::kNeuralNetwork: case Specification::Model::kNeuralNetworkRegressor: case Specification::Model::kNeuralNetworkClassifier: + case Specification::Model::kSerializedModel: r = validateOptionalNN(format.description()); break; case Specification::Model::kTreeEnsembleRegressor: diff --git a/reqs/test_tf2.pip b/reqs/test_tf2.pip index 957ad2d78..b2dd59aa9 100644 --- a/reqs/test_tf2.pip +++ b/reqs/test_tf2.pip @@ -1,5 +1,5 @@ -tensorflow==2.1.0; python_version < '3.8' -tensorflow==2.2.0; python_version >= '3.8' +tensorflow==2.1.0; python_version <= '2.7' +tensorflow==2.3.0; python_version >= '3.5' tensorflow-addons==0.7.1; python_version == '2.7' tensorflow-addons==0.8.3; python_version > '2.7' and python_version < '3.8' diff --git a/scripts/build.sh b/scripts/build.sh index 266d9ade2..0aea2b115 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -79,7 +79,7 @@ cd ${BUILD_DIR} ADDITIONAL_CMAKE_OPTIONS="" if [[ "$OSTYPE" == "darwin"* ]]; then NUM_PROCS=$(sysctl -n hw.ncpu) - ADDITIONAL_CMAKE_OPTIONS=" -DCMAKE_OSX_DEPLOYMENT_TARGET=10.13" + ADDITIONAL_CMAKE_OPTIONS="-DCMAKE_OSX_DEPLOYMENT_TARGET=10.13" else NUM_PROCS=$(nproc) fi diff --git a/scripts/build_docs.sh b/scripts/build_docs.sh index 832dfab46..82e809c90 100755 --- a/scripts/build_docs.sh +++ b/scripts/build_docs.sh @@ -29,7 +29,7 @@ print_help() { echo echo " --wheel-path=* Specify which wheel to use to make docs." echo " --python=* Python to use for configuration." - echo " --version=* ReadMe ersion to upload to. Default is the installed coremltools version." + echo " --version=* ReadMe version to upload to. Default is the installed coremltools version." 
echo " --upload Upload these docs with the current coremltools version." echo " --release Release the uploaded docs with the current coremltools version." echo " --from-source-version=* If a version must be created, use this as the base to copy from.\ diff --git a/setup.py b/setup.py index c08317d8d..7f3d6ae48 100755 --- a/setup.py +++ b/setup.py @@ -78,6 +78,7 @@ "scipy", 'enum34;python_version < "3.4"', "tqdm", + "packaging", 'typing;python_version < "3.5"', ], entry_points={"console_scripts": ["coremlconverter = coremltools:_main"]},