Skip to content

Commit

Permalink
add logs for detecting oom layers
Browse files Browse the repository at this point in the history
  • Loading branch information
ranhomri committed Nov 17, 2024
1 parent 3c4aef1 commit 3699592
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 19 deletions.
4 changes: 3 additions & 1 deletion onnx2kerastl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import tensorflow as tf
from keras.models import Model

from onnx2kerastl.logger_helpers import log_oom_layers

from .customonnxlayer import onnx_custom_objects_map
from .exceptions import UnsupportedLayer, OnnxUnsupported
from .layers import AVAILABLE_CONVERTERS
Expand Down Expand Up @@ -94,7 +96,7 @@ def onnx_to_keras(onnx_model, input_names, name_policy=None, verbose=True, chang
logger = logging.getLogger('onnx2keras')

logger.info('Converter is called.')

log_oom_layers(onnx_model, logger)
onnx_weights = onnx_model.graph.initializer
onnx_inputs = onnx_model.graph.input
onnx_outputs = [i.name for i in onnx_model.graph.output]
Expand Down
18 changes: 1 addition & 17 deletions onnx2kerastl/elementwise_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,7 @@ def convert_elementwise_add(node, params, layers, lambda_func, node_name, keras_
name=keras_name
)(variable_input)
else:
# Both inputs are constants; perform addition at runtime without embedding the result
constant_value_0 = tf.constant(input_0)
constant_value_1 = tf.constant(input_1)

# Create a dummy input to satisfy Keras's requirements
dummy_input = keras.Input(shape=(1,), name='dummy_input')

# Define a function outside the Lambda to ensure serialization
def add_constants(x):
return constant_value_0 + constant_value_1 + x * 0 # x * 0 ensures connection to input

layers[node_name] = keras.layers.Lambda(
add_constants,
name=keras_name
)(dummy_input)

layers['dummy_input'] = dummy_input
layers[node_name] = input_0 + input_1



Expand Down
46 changes: 46 additions & 0 deletions onnx2kerastl/logger_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
def log_oom_layers(model, logger):
    """Log detailed DEBUG-level information about an ONNX model to help
    locate layers that may cause out-of-memory failures during conversion.

    Emits: general model metadata, graph inputs/outputs (name, shape,
    element type), initializers (raw bytes included only for tiny tensors),
    and every node with its inputs, outputs and attributes.

    :param model: a parsed ``onnx.ModelProto`` instance.
    :param logger: a ``logging.Logger`` to write the records to.
    :return: None
    """
    graph = model.graph

    # General model metadata. Lazy %-style arguments are used throughout so
    # message formatting is skipped entirely when DEBUG logging is disabled.
    logger.debug("IR_VERSION: %s", model.ir_version)
    logger.debug("PRODUCER: %s", model.producer_name)
    logger.debug("PRODUCER_VERSION: %s", model.producer_version)
    logger.debug("DOMAIN: %s", model.domain)
    logger.debug("MODEL_VERSION: %s", model.model_version)
    logger.debug("OPSET_IMPORT: %s",
                 str({opset.domain: opset.version for opset in model.opset_import}))

    # Graph inputs: name, shape dims (flattened to one line) and the ONNX
    # element-type enum value.
    for input_tensor in graph.input:
        logger.debug(
            "Input: %s; Shape: %s; Type: %s",
            input_tensor.name,
            str(input_tensor.type.tensor_type.shape.dim).replace('\n', ''),
            input_tensor.type.tensor_type.elem_type,
        )

    # Graph outputs, same fields as inputs.
    for output_tensor in graph.output:
        logger.debug(
            "Output: %s; Shape: %s; Type: %s",
            output_tensor.name,
            str(output_tensor.type.tensor_type.shape.dim).replace('\n', ''),
            output_tensor.type.tensor_type.elem_type,
        )

    # Initializers (weights). Raw bytes are included only for tiny tensors
    # (< 50 bytes); large blobs would flood the log — and are the likely
    # OOM suspects this log exists to surface.
    for initializer in graph.initializer:
        raw_data_log = initializer.raw_data if len(initializer.raw_data) < 50 else ""
        logger.debug(
            "Initializer: %s; Shape: %s; Type: %s; Raw Data: %s",
            initializer.name,
            initializer.dims,
            initializer.data_type,
            raw_data_log,
        )

    # Per-node breakdown: op type, name, connectivity and attributes.
    for node in graph.node:
        logger.debug("Node: %s; Name: %s", node.op_type, node.name)
        logger.debug("    Inputs: %s", node.input)
        logger.debug("    Outputs: %s", node.output)
        logger.debug("    Attributes:")
        for attr in node.attribute:
            logger.debug("        - %s: %s", attr.name, attr)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "onnx2kerastl"
version = "0.0.152"
version = "0.0.153"
description = ""
authors = ["dorhar <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit 3699592

Please sign in to comment.