This repository has been archived by the owner on Nov 11, 2023. It is now read-only.

Added switch --dont_save_keras_model to avoid writing model to disk #120

Merged
PINTO0309 merged 2 commits into PINTO0309:main on Aug 26, 2022

Conversation

travisjayday (Contributor)

Problem

During TFLite conversion, the openvino2tensorflow.py script calls

tf.lite.TFLiteConverter.from_keras_model(model)
or
tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

Both of these functions either write the model to disk or require it to have been saved there already. Saving a model to disk takes a long time because of I/O bottlenecks.

Solution

To preserve backward compatibility, the new behavior is opt-in: I added a switch --dont_save_keras_model that skips saving the model to disk during TFLite conversion.
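
An illustrative invocation (the paths are placeholders; the other flags shown already exist in openvino2tensorflow):

openvino2tensorflow \
  --model_path openvino/FP32/model.xml \
  --output_no_quant_float32_tflite \
  --dont_save_keras_model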

Implementation

Inspired by this SO post, we can create a TFLiteConverter from concrete functions, avoiding the need to save the model to disk. We call

tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])

instead of the above constructors.
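
A minimal sketch of this path, assuming model is the in-memory tf.keras.Model built by the converter (single input shown for brevity):

import tensorflow as tf

# Wrap the in-memory Keras model in a tf.function and trace it once.
run_model = tf.function(lambda x: model(x))
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)

# Build the converter from the traced function; no SavedModel is written to disk.
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()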

Results

Model conversion time decreased from ~70s to ~25s for my particular model (MobileNetV3, 224x224x3), and the relative savings should grow as model size increases.

@PINTO0309 (Owner) left a comment


Thanks. There are several areas that I would like to see corrected.

Comment on lines 7226 to 7228
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)
@PINTO0309 (Owner)

It would be better to accommodate models with two or more inputs.
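
For example, one way to accommodate an arbitrary number of inputs (an illustrative sketch, not the code as merged):

# Build one TensorSpec per model input instead of hard-coding inputs[0].
run_model = tf.function(lambda *inputs: model(list(inputs)))
concrete_func = run_model.get_concrete_function(
    *[tf.TensorSpec(inp.shape, inp.dtype) for inp in model.inputs]
)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])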

Comment on lines 7223 to 7225
def get_TFLiteConverter():
    if dont_save_keras_model:
        run_model = tf.function(lambda x : model(x))
@PINTO0309 (Owner)

Limit the variable scope of model and dont_save_keras_model to within the function. It is a bit verbose, but they should be passed in as function arguments.
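
A minimal sketch of the suggested refactor (the exact signature is an assumption; the point is that the dependencies arrive as arguments):

def get_TFLiteConverter(model, dont_save_keras_model, model_output_path):
    # Dependencies are passed in explicitly rather than captured from outer scope.
    if dont_save_keras_model:
        run_model = tf.function(lambda *inputs: model(list(inputs)))
        concrete_func = run_model.get_concrete_function(
            *[tf.TensorSpec(inp.shape, inp.dtype) for inp in model.inputs]
        )
        return tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    return tf.lite.TFLiteConverter.from_saved_model(model_output_path)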

Comment on lines 7376 to 7382
 if output_integer_quant_tflite:
     try:
         print(f'{Color.REVERCE}Integer Quantization started{Color.RESET}', '=' * 56)
-        converter = tf.lite.TFLiteConverter.from_saved_model(model_output_path)
+        converter = get_TFLiteConverter()
         converter.experimental_new_quantizer = use_experimental_new_quantizer
         converter._experimental_disable_per_channel = not use_per_channel
         converter.optimizations = [tf.lite.Optimize.DEFAULT]
@PINTO0309 (Owner)

There are some models that cannot be converted by tf.lite.TFLiteConverter.from_keras_model(model). Therefore, either leave this part unchanged, or fall back to tf.lite.TFLiteConverter.from_saved_model(model_output_path) if an exception occurs during conversion. In that case, the output_saved_model flag must be True.
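
A sketch of that fallback behavior (illustrative only; the exception handling details are an assumption):

try:
    # Preferred path: convert from the traced concrete function.
    converter = get_TFLiteConverter()
    tflite_model = converter.convert()
except Exception:
    # Some models only convert via the SavedModel path; this requires that
    # the model was already written to disk (output_saved_model=True).
    converter = tf.lite.TFLiteConverter.from_saved_model(model_output_path)
    tflite_model = converter.convert()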

Comment on lines 7395 to 7400
 if output_full_integer_quant_tflite:
     try:
         print(f'{Color.REVERCE}Full Integer Quantization started{Color.RESET}', '=' * 51)
-        converter = tf.lite.TFLiteConverter.from_saved_model(model_output_path)
+        converter = get_TFLiteConverter()
         converter.experimental_new_quantizer = use_experimental_new_quantizer
         converter._experimental_disable_per_channel = not use_per_channel
@PINTO0309 (Owner)

There are some models that cannot be converted by tf.lite.TFLiteConverter.from_keras_model(model). Therefore, either leave this part unchanged, or fall back to tf.lite.TFLiteConverter.from_saved_model(model_output_path) if an exception occurs during conversion. In that case, the output_saved_model flag must be True.

@PINTO0309 (Owner) left a comment

Additional review content.

@@ -67,6 +67,7 @@ def convert(
     output_h5,
     output_weight_and_json,
     output_pb,
+    dont_save_keras_model,
@PINTO0309 (Owner)

To be precise, dont_convert_from_keras_model is more accurate in meaning than dont_save_keras_model, since we are not outputting a Keras model, only reading Keras model objects.

@@ -7661,6 +7672,7 @@ def main():
     parser.add_argument('--output_h5', action='store_true', help='.h5 output switch')
     parser.add_argument('--output_weight_and_json', action='store_true', help='weight of h5 and json output switch')
     parser.add_argument('--output_pb', action='store_true', help='.pb output switch')
+    parser.add_argument('--dont_save_keras_model', action='store_true', help='use TFLiteConverter.from_concrete_functions to avoid saving Keras model to disk during TFLite conversion')
@PINTO0309 (Owner)

To be precise, dont_convert_from_keras_model is more accurate in meaning than dont_save_keras_model, since we are not outputting a Keras model, only reading Keras model objects.

@@ -7852,7 +7865,7 @@ def main():

     del package_list
     os.makedirs(model_output_path, exist_ok=True)
-    convert(model, model_output_path, output_saved_model, output_h5, output_weight_and_json, output_pb,
+    convert(model, model_output_path, output_saved_model, output_h5, output_weight_and_json, output_pb, dont_save_keras_model,
@PINTO0309 (Owner)

To be precise, dont_convert_from_keras_model is more accurate in meaning than dont_save_keras_model, since we are not outputting a Keras model, only reading Keras model objects.

@travisjayday (Contributor, Author)

travisjayday commented Aug 23, 2022

@PINTO0309, do you have a model with multiple inputs/outputs? I'm trying to validate two-input models with the concrete-function TFLite converter, but I'm having trouble obtaining such a model.

I created this simple OpenVINO model (model.zip) that has two inputs and two outputs. However, openvino2tensorflow is throwing errors about a dimension mismatch. Here is the code to generate the model:

import torch
from torch import nn

class TwoInputsNet(nn.Module):
  def __init__(self):
    super(TwoInputsNet, self).__init__()
    self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)  # set up your layer here
    self.fc1 = nn.Linear(in_features=28, out_features=10)  # set up first FC layer
    self.fc2 = nn.Linear(in_features=2868, out_features=10)  # set up the other FC layer

  def forward(self, input1, input2):
    c = self.conv(input1)
    f = self.fc1(input2)
    # now we can reshape `c` and `f` to 2D and concat them
    combined = torch.cat((c.view(c.size(0), -1),
                          f.view(f.size(0), -1)), dim=1)
    out = self.fc2(combined)
    out1 = self.fc2(combined)
    return out, out1

model = TwoInputsNet()
model.eval()
in1 = torch.rand((1, 3, 28, 28))
in2 = torch.rand((1, 3, 28, 28))

out = model(in1, in2)
torch.onnx.export(
    model,
    (in1, in2),
    './model/model.onnx',
    opset_version=13,
    input_names=['input_0', 'input_1'],
    output_names=['output_0', 'output_1'],
)
print(len(out))
print(out)

and here is the openvino2tensorflow error:

ValueError: Dimensions must be equal, but are 3 and 28 for '{{node tf.linalg.matmul/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=true](Placeholder, tf.linalg.matmul/MatMul/b)' with input shapes: [1,28,28,3], [10,28].
ERROR: Please refer to 6-7 in the README first. https://github.com/PINTO0309/openvino2tensorflow

@PINTO0309 (Owner)

There was insufficient consideration of the case where a MatMul immediately follows a 4D input layer. However, since that problem is essentially unrelated to the one this pull request addresses, the change can be verified with a simpler multiplication-only model, as shown below.

model.zip

import torch
from torch import nn
import onnx
from onnxsim import simplify

class TwoInputsNet(nn.Module):
    def forward(self, input1, input2):
        a = input1 * 2.0
        b = input2 * 3.0

        out1 = a
        out2 = a + b

        return out1, out2

model = TwoInputsNet()
model.eval()
in1 = torch.rand((1, 3, 28, 28))
in2 = torch.rand((1, 3, 28, 28))

out = model(in1, in2)

onnx_file = 'model.onnx'
torch.onnx.export(
    model,
    (in1, in2),
    onnx_file,
    opset_version=11,
    input_names=['input_0', 'input_1'],
    output_names=['output_0', 'output_1'],
)
print(len(out))
print(out)

model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)

"""
mo \
--framework onnx \
--input_model model.onnx \
--data_type FP32 \
--output_dir openvino/FP32 \
--model_name model

openvino2tensorflow \
--model_path openvino/FP32/model.xml \
--output_saved_model \
--output_pb \
--non_verbose \
--output_no_quant_float32_tflite
"""

@travisjayday (Contributor, Author)

OK @PINTO0309, I made the changes you suggested. Multiple inputs are now handled correctly, with try/except logic in case the concrete-function conversion fails. What do you think?

@travisjayday requested a review from PINTO0309 on August 24, 2022 at 18:58
@PINTO0309 merged commit b14e5b6 into PINTO0309:main on Aug 26, 2022
@PINTO0309 (Owner)

LGTM. Thanks for your contribution. 👍
