Skip to content

Commit

Permalink
Simplify onnx_parser API (#727)
Browse files (browse the repository at this point in the history)
* simplify onnx_parser API
* fix pickling issue
  • Loading branch information
xadupre authored Sep 2, 2021
1 parent d1d1ada commit c503962
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 27 deletions.
6 changes: 6 additions & 0 deletions .azure-pipelines/linux-CI-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ jobs:
vmImage: 'ubuntu-latest'
strategy:
matrix:
Py39-OnnxLatest-Sk0242:
python.version: '3.9'
numpy.version: '>=1.18.3'
onnx.version: ''
onnxrt.version: '-i https://test.pypi.org/simple/ ort-nightly'
sklearn.version: '==0.24.2'
Py38-Onnx170-Sk0240:
python.version: '3.8'
numpy.version: '>=1.18.3'
Expand Down
7 changes: 7 additions & 0 deletions .azure-pipelines/win32-CI-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ jobs:
vmImage: 'vs2017-win2016'
strategy:
matrix:
Py39-Onnx1101-nightlyRT-Sk0242:
python.version: '3.9'
onnx.version: 'onnx==1.10.1'
numpy.version: 'numpy>=1.18.1'
scipy.version: 'scipy'
onnxrt.version: '-i https://test.pypi.org/simple/ ort-nightly'
sklearn.version: '==0.24.2'
Py38-Onnx170-nightlyRT-Sk0240:
python.version: '3.8'
onnx.version: 'onnx==1.7.0'
Expand Down
15 changes: 13 additions & 2 deletions skl2onnx/_parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
import numpy as np

from sklearn import pipeline
Expand Down Expand Up @@ -105,9 +106,19 @@ def _parse_sklearn_simple_model(scope, model, inputs, custom_parsers=None):
this_operator.inputs = inputs

if hasattr(model, 'onnx_parser'):
parser_names = model.onnx_parser(scope=scope, inputs=inputs)
parser_names = model.onnx_parser()
if parser_names is not None:
names = parser_names()
try:
names = parser_names(scope=scope, inputs=inputs)
except TypeError as e:
warnings.warn(
"Calling parser %r for model type %r failed due to %r. "
"This warnings will become an exception in version 1.11. "
"The parser signature should parser(scope=None, "
"inputs=None)." % (
parser_names, e, type(model)),
DeprecationWarning)
names = parser_names()
if names is not None:
for name in names:
var = scope.declare_local_variable(
Expand Down
20 changes: 20 additions & 0 deletions skl2onnx/algebra/onnx_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ def get_output_name(self, i=0):
raise IndexError("Can only return the first item.")
return self.onx_op.get_output_name(self.index)

def get_output(self, i=0):
    """
    Returns the output of the wrapped operator designated
    by ``self.index``.

    :param i: position of the requested output, only ``0`` is
        accepted because this object exposes a single output
    :return: the value returned by
        ``self.onx_op.get_output(self.index)``
    :raises IndexError: if *i* is different from ``0``
    """
    if i == 0:
        # Delegate to the wrapped operator using the stored index.
        return self.onx_op.get_output(self.index)
    raise IndexError("Can only return the first item.")

@property
def outputs(self):
"""
Expand Down Expand Up @@ -483,6 +491,18 @@ def get_output_name(self, i, scope=None):
self._set_output_names_(getattr(self, 'scope', None) or scope, None)
return self.output_names_[i]

def get_output(self, i, scope=None):
    """
    Returns output *i* itself (not its name, see
    ``get_output_name`` for that).

    :param i: index of the requested output
    :param scope: not used by this method
        (NOTE(review): presumably kept so the signature mirrors
        ``get_output_name(self, i, scope=None)`` -- confirm)
    :return: ``self.state.computed_outputs_[i]`` when a state is
        available, otherwise ``self.output_names_[i]``, or ``None``
        when neither ``state`` nor ``output_names_`` is set
    :raises RuntimeError: if ``self.output_names_[i]`` is neither a
        tuple nor a ``Variable``
    """
    if self.state is not None:
        return self.state.computed_outputs_[i]
    if self.output_names_ is None:
        # Nothing is known about the outputs yet; the original code
        # fell through and returned None implicitly.
        return None
    res = self.output_names_[i]
    if not isinstance(res, (tuple, Variable)):
        raise RuntimeError(
            "Unable to retrieve output %r from %r."
            "" % (i, self))
    return res

def _set_output_names_(self, scope, operator):
"Called by add_to."
if operator is not None:
Expand Down
22 changes: 8 additions & 14 deletions skl2onnx/algebra/onnx_operator_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,21 @@ def to_onnx_operator(self, inputs=None, outputs=None,
"""
raise NotImplementedError()

def onnx_parser(self, scope=None, inputs=None):
def onnx_parser(self):
"""
Returns a parser for this model.
If not overloaded, it calls the converter to guess the number
of outputs. If it still fails, it fetches the parser
mapped to the first *scikit-learn* parent
it can find.
"""
if inputs:
self.parsed_inputs_ = inputs
try:
op = self.to_onnx_operator(inputs=inputs, outputs=None)
except NotImplementedError:
self._find_sklearn_parent()
return None
def parser(scope=None, inputs=None):
try:
op = self.to_onnx_operator(inputs=inputs, outputs=None)
except NotImplementedError:
self._find_sklearn_parent()
return None

def parser():
names = []
while True:
try:
Expand Down Expand Up @@ -147,12 +145,8 @@ def onnx_shape_calculator(self):
"Class '{}' should have an attribute 'op_version'.".format(
self.__class__.__name__))

inputs = getattr(self, "parsed_inputs_", None)
try:
if inputs:
op = self.to_onnx_operator(inputs=inputs)
else:
op = self.to_onnx_operator()
op = self.to_onnx_operator()
except NotImplementedError:
parent = self._find_sklearn_parent()
name = sklearn_operator_name_map.get(
Expand Down
18 changes: 7 additions & 11 deletions tests/test_algebra_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def test_base_api(self):
except RuntimeError as e:
assert "Method enumerate_initial_types is missing" in str(e)

@unittest.skipIf(StrictVersion(onnx.__version__) <= StrictVersion("1.7.0"),
reason="checm_model crashes")
def test_custom_scaler(self):
mat = np.array([[0., 1.], [0., 1.], [2., 2.]])
tr = CustomOpTransformerShape(op_version=TARGET_OPSET)
Expand All @@ -81,14 +83,13 @@ def test_custom_scaler(self):

matf = mat.astype(np.float32)
model_onnx = tr.to_onnx(matf)
# Next instructions fails...
# Field 'shape' of type is required but missing.
# onnx.checker.check_model(model_onnx)

onnx.checker.check_model(model_onnx)
dump_data_and_model(
mat.astype(np.float32), tr, model_onnx,
basename="CustomTransformerAlgebra")

@unittest.skipIf(StrictVersion(onnx.__version__) <= StrictVersion("1.7.0"),
reason="checm_model crashes")
def test_custom_scaler_pipeline_right(self):
pipe = make_pipeline(
StandardScaler(),
Expand All @@ -100,12 +101,7 @@ def test_custom_scaler_pipeline_right(self):

matf = mat.astype(np.float32)
model_onnx = to_onnx(pipe, matf, target_opset=TARGET_OPSET)
# Next instructions fails...
# Field 'shape' of type is required but missing.
# onnx.checker.check_model(model_onnx)

# use assert_consistent_outputs
# calls dump_data_and_model
onnx.checker.check_model(model_onnx)
dump_data_and_model(
mat.astype(np.float32), pipe, model_onnx,
basename="CustomTransformerPipelineRightAlgebra")
Expand All @@ -125,7 +121,7 @@ def test_custom_scaler_pipeline_left(self):
try:
model_onnx = to_onnx(pipe, matf, target_opset=TARGET_OPSET)
except RuntimeError as e:
assert "cannot be infered" in str(e)
assert "inputs should contain one name" in str(e)

pipe = make_pipeline(
CustomOpTransformerShape(op_version=TARGET_OPSET),
Expand Down
7 changes: 7 additions & 0 deletions tests/test_utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,18 @@ def lambda_original(): return model.transform(dataone) # noqa

dest = os.path.join(folder, basename + ".model.pkl")
names.append(dest)
load_pickle = True
with open(dest, "wb") as f:
try:
pickle.dump(model, f)
except AttributeError as e:
print("[dump_data_and_model] cannot pickle model '{}'"
" due to {}.".format(dest, e))
load_pickle = False
if load_pickle and os.path.exists(dest):
# Test unpickle works.
with open(dest, "rb") as f:
pickle.load(f)

if dump_error_log:
error_dump = os.path.join(folder, basename + ".err")
Expand Down Expand Up @@ -894,6 +900,7 @@ def make_report_backend(folder, as_df=False):
benched = 0
files = os.listdir(folder)
for name in files:
print(name)
if name.endswith(".expected.pkl"):
model = name.split(".")[0]
if model not in res:
Expand Down

0 comments on commit c503962

Please sign in to comment.