forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
op_registration_validator.py
124 lines (98 loc) · 4.75 KB
/
op_registration_validator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""
Validate ORT kernel registrations.
"""
import argparse
import os
import sys
import typing
import op_registration_utils
from logger import get_logger
# Module-level logger used by RegistrationValidator and by the command-line entry point below.
log = get_logger("op_registration_validator")
# Deprecated operators whose final registration is expected to carry an end version.
# Each value is the opset in which the operator was deprecated, so the end version of that
# operator's last registration must be exactly (value - 1).
#
# LayerNormalization, MeanVarianceNormalization and ThresholdedRelu were in contrib ops and
# incorrectly registered using the kOnnxDomain. They became official ONNX operators later and are
# registered there now. That leaves entries in the contrib ops registrations with end versions for
# when the contrib op was 'deprecated' and became an official op.
deprecated_ops = dict(
    [
        ("kOnnxDomain:Scatter", 11),
        ("kOnnxDomain:Upsample", 10),
        # former contrib ops (see the note above)
        ("kOnnxDomain:LayerNormalization", 17),
        ("kOnnxDomain:MeanVarianceNormalization", 9),
        ("kOnnxDomain:ThresholdedRelu", 10),
    ]
)
class RegistrationValidator(op_registration_utils.RegistrationProcessor):
    """Validate the opset version ranges of kernel registrations.

    Registrations are fed to process_registration() one at a time (presumably by
    op_registration_utils.process_kernel_registration_file, which receives this processor).
    For each "domain:operator" key the most recent (start_version, end_version) pair is kept and
    each new registration is checked to continue the previous one without a gap.
    validate_last_registrations() then checks that every operator ends with an unversioned
    (open-ended) registration unless it is exempt. ok() reports whether any problem was found.
    """

    # Operators whose latest registration may have an end version because an EP does not yet
    # support newer opsets. Special handling for ArgMin/ArgMax, which CUDA EP doesn't yet support
    # for opset 12+.
    # TODO remove once CUDA EP supports ArgMin/ArgMax for opset 12+
    _OPS_WITH_INCOMPLETE_SUPPORT = frozenset(["kOnnxDomain:ArgMin", "kOnnxDomain:ArgMax"])

    def __init__(self):
        # "domain:operator" -> (start_version, end_version) of the most recent registration seen.
        # end_version is None for an unversioned registration.
        self.last_op_registrations: typing.Dict[str, typing.Tuple[int, typing.Optional[int]]] = {}
        # Becomes True once any invalid registration has been reported.
        self.failed = False

    def process_registration(
        self,
        lines: typing.List[str],
        domain: str,
        operator: str,
        start_version: int,
        end_version: typing.Optional[int] = None,
        type: typing.Optional[str] = None,  # shadows builtin `type`; kept since callers may pass it by keyword
    ):
        """Check one kernel registration against the previous registration of the same operator.

        Problems are logged and recorded in self.failed; nothing is raised.

        :param lines: Source lines of the registration (not used by this validator).
        :param domain: Domain of the operator, e.g. "kOnnxDomain".
        :param operator: Operator name.
        :param start_version: First opset covered by this registration.
        :param end_version: Last opset covered, or None for an unversioned registration.
        :param type: Type of a typed registration (not used by this validator).
        """
        key = domain + ":" + operator
        prev_start, prev_end = self.last_op_registrations.get(key, (None, None))

        if prev_start is not None:
            # a typed registration where the to/from matches for each entry so nothing to update
            if prev_start == start_version and prev_end == end_version:
                return

            if prev_end is None:
                # previous registration was unversioned but should have been if we are seeing another registration
                log.error(
                    "Invalid registration for {}. Registration for opset {} has no end version but was "
                    "superceeded by version {}.".format(key, prev_start, start_version)
                )
                self.failed = True
            elif prev_end != start_version - 1:
                # previous registration end opset is not adjacent to the start of the next registration
                log.error(
                    "Invalid registration for {}. Registration for opset {} should have end version of {}".format(
                        key, prev_start, start_version - 1
                    )
                )
                self.failed = True

        # Always track the newest registration, even after reporting an error. Otherwise later
        # registrations of this operator (and validate_last_registrations) are compared against a
        # stale entry, producing follow-on errors that don't match what was actually registered.
        self.last_op_registrations[key] = (start_version, end_version)

    def ok(self):
        """Return True if no invalid registration has been reported."""
        return not self.failed

    def validate_last_registrations(self):
        """Check that each operator's latest registration is unversioned unless exempt.

        An operator whose latest registration has an end version is reported (and self.failed set)
        unless it is in deprecated_ops with that end version equal to the deprecation opset minus
        one, or it is listed in _OPS_WITH_INCOMPLETE_SUPPORT (which only logs a warning).
        """
        for key, (_, opset_to) in self.last_op_registrations.items():
            if opset_to is None:
                # unversioned last registration: the expected state
                continue

            # deprecated ops are expected to end at the opset before the one they were deprecated in
            if key in deprecated_ops and opset_to == deprecated_ops[key] - 1:
                continue

            if key in self._OPS_WITH_INCOMPLETE_SUPPORT:
                log.warning("Allowing missing unversioned registration for op with incomplete support: {}".format(key))
                continue

            log.error("Missing unversioned registration for {}".format(key))
            self.failed = True
def main(argv: typing.Optional[typing.List[str]] = None) -> int:
    """Validate the CPU and CUDA EP kernel registration files.

    Every registration file is validated, rather than stopping at the first failing one, so a
    single run reports all problems.

    :param argv: Command-line arguments; None means sys.argv[1:] (argparse default).
    :return: 0 if every file validated successfully, -1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Script to validate operator kernel registrations.")
    parser.add_argument(
        "--ort_root",
        type=str,
        help="Path to ONNXRuntime repository root. " "Inferred from the location of this script if not provided.",
    )
    args = parser.parse_args(argv)

    ort_root = os.path.abspath(args.ort_root) if args.ort_root else ""
    include_cuda = True  # validate CPU and CUDA EP registrations

    registration_files = op_registration_utils.get_kernel_registration_files(ort_root, include_cuda)

    failed_files = []
    for file in registration_files:
        log.info("Processing {}".format(file))

        # a fresh validator per file, so registrations are tracked and validated per file
        processor = RegistrationValidator()
        processed = op_registration_utils.process_kernel_registration_file(file, processor)
        processor.validate_last_registrations()

        # NOTE(review): assumes process_kernel_registration_file returns False when it could not
        # process the file (e.g. file not found). Only an explicit False counts as failure, so any
        # other return value keeps the previous behaviour of ignoring it.
        if processed is False or not processor.ok():
            failed_files.append(file)

    if failed_files:
        log.error(
            "Kernel registration validation failed for: {}".format(", ".join(str(f) for f in failed_files))
        )
        return -1

    return 0


if __name__ == "__main__":
    sys.exit(main())