forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
symbolic_shape_infer.py
executable file
·2625 lines (2384 loc) · 114 KB
/
symbolic_shape_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import argparse
import logging
import numpy as np
import onnx
import sympy
from onnx import helper, numpy_helper, shape_inference
from packaging import version
# Refuse to load against an ONNX package older than 1.8.0.
assert version.parse(onnx.__version__) >= version.parse("1.8.0")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_attribute(node, attr_name, default_value=None):
    """Return the value of the attribute named ``attr_name`` on ``node``.

    The first attribute whose ``name`` matches is decoded with
    ``helper.get_attribute_value``. ``default_value`` is returned when no
    attribute of ``node`` carries that name.
    """
    for candidate in node.attribute:
        if candidate.name == attr_name:
            return helper.get_attribute_value(candidate)
    return default_value
def get_dim_from_proto(dim):
    """Return the populated member of ``dim``'s ``value`` oneof, or ``None``.

    ``None`` is returned when ``WhichOneof("value")`` does not yield a field
    name, i.e. when the dimension carries no value at all.
    """
    field_name = dim.WhichOneof("value")
    if type(field_name) is str:
        return getattr(dim, field_name)
    return None
def is_sequence(type_proto):
    """Tell whether ``type_proto`` holds a sequence type rather than a tensor type.

    Only ``tensor_type`` and ``sequence_type`` are accepted; any other
    content of the ``value`` oneof trips the assertion.
    """
    kind = type_proto.WhichOneof("value")
    assert kind in ("tensor_type", "sequence_type")
    return kind == "sequence_type"
def get_shape_from_type_proto(type_proto):
    """Return the tensor shape stored in ``type_proto`` as a list of dims.

    Each entry is whatever ``get_dim_from_proto`` yields for that dimension.
    ``None`` is returned when the tensor type has no ``shape`` field at all;
    that is distinct from an empty list, which denotes a scalar.
    Sequence types are rejected by assertion.
    """
    assert not is_sequence(type_proto)
    tensor_type = type_proto.tensor_type
    if not tensor_type.HasField("shape"):
        return None  # no shape is different from a shape without dims (scalar)
    return [get_dim_from_proto(dim) for dim in tensor_type.shape.dim]
def get_elem_type_from_type_proto(type_proto):
    """Return the element type of a tensor, or of the tensors inside a sequence."""
    tensor_type = (
        type_proto.sequence_type.elem_type.tensor_type if is_sequence(type_proto) else type_proto.tensor_type
    )
    return tensor_type.elem_type
def get_shape_from_value_info(vi):
    """Return the shape recorded in value info ``vi``, or ``None`` if unavailable.

    ``None`` is returned when ``vi`` has no type set, or when it is a
    sequence whose element type is not a tensor. For a sequence of tensors
    the element tensor's shape is returned.
    """
    value_type = vi.type
    if value_type.WhichOneof("value") is None:
        return None
    if not is_sequence(value_type):
        return get_shape_from_type_proto(value_type)
    elem_type = value_type.sequence_type.elem_type
    if "tensor_type" == elem_type.WhichOneof("value"):
        return get_shape_from_type_proto(elem_type)
    return None
def make_named_value_info(name):
    """Create an ``onnx.ValueInfoProto`` that carries only ``name``."""
    value_info = onnx.ValueInfoProto()
    value_info.name = name
    return value_info
def get_shape_from_sympy_shape(sympy_shape):
    """Convert a list of sympy dims into plain ``int`` / ``str`` / ``None`` entries.

    ``None`` stays ``None``, literal dims become ``int`` and everything else
    is rendered with ``str``.
    """
    plain_shape = []
    for dim in sympy_shape:
        if dim is None:
            plain_shape.append(None)
        elif is_literal(dim):
            plain_shape.append(int(dim))
        else:
            plain_shape.append(str(dim))
    return plain_shape
def is_literal(dim):
    """Tell whether ``dim`` is a concrete number rather than a symbolic value.

    Exact ``int``, ``np.int64``, ``np.int32`` and ``sympy.Integer`` instances
    qualify, as does any object exposing a truthy ``is_number`` attribute.
    """
    if type(dim) in (int, np.int64, np.int32, sympy.Integer):
        return True
    return hasattr(dim, "is_number") and dim.is_number
def handle_negative_axis(axis, rank):
    """Map ``axis`` from the range ``[-rank, rank)`` onto ``[0, rank)``.

    An out-of-range ``axis`` trips the assertion.
    """
    assert -rank <= axis < rank
    if axis < 0:
        return rank + axis
    return axis
def get_opset(mp, domain=None):
    """Return the opset version that model ``mp`` imports for ``domain``.

    ``domain`` may be a single domain or a list of acceptable domains. When
    it is falsy, the default ONNX domain spellings ``""``, ``"onnx"`` and
    ``"ai.onnx"`` are used. The version of the first matching entry in
    ``mp.opset_import`` is returned, or ``None`` when nothing matches.
    """
    domains = domain or ["", "onnx", "ai.onnx"]
    if type(domains) is not list:
        domains = [domains]
    return next((entry.version for entry in mp.opset_import if entry.domain in domains), None)
def as_scalar(x):
    """Unwrap a one-element list or a numpy array into a scalar.

    A ``list`` must hold exactly one element (asserted). A numpy array is
    unwrapped with ``item()``. Any other value is returned unchanged.
    """
    kind = type(x)
    if kind is np.ndarray:
        return x.item()
    if kind is not list:
        return x
    assert len(x) == 1
    return x[0]
def as_list(x, keep_none):
    """Coerce ``x`` into a list.

    A ``list`` is returned as is and a numpy array is converted with
    ``list()``. ``None`` is passed through when ``keep_none`` is truthy.
    Every other value is wrapped in a single-element list.
    """
    kind = type(x)
    if kind is list:
        return x
    if kind is np.ndarray:
        return list(x)
    if keep_none and x is None:
        return None
    return [x]
def sympy_reduce_product(x):
    """Multiply the entries of list ``x`` together, starting from ``sympy.Integer(1)``.

    Anything that is not a ``list`` is returned unchanged.
    """
    if type(x) is not list:
        return x
    product = sympy.Integer(1)
    for factor in x:
        product = product * factor
    return product
class SymbolicShapeInference:
def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
    """Build the op dispatch tables and reset the bookkeeping state.

    Args:
        int_max: stored as ``int_max_``; ``_infer_symbolic_compute_ops``
            uses it as the magnitude beyond which a literal operand of
            Max/Min is dropped.
        auto_merge: stored as ``auto_merge_``; enables merging of
            conflicting symbolic dims (see ``_merge_symbols`` and
            ``_add_suggested_merge``).
        guess_output_rank: stored as ``guess_output_rank_`` and forwarded
            to subgraph inference. NOTE(review): its consumer is not in
            the visible part of the file.
        verbose: verbosity level, compared against numeric thresholds
            before emitting log messages.
        prefix: text inserted into generated symbolic dim names;
            ``_onnx_infer_subgraph`` extends it for nested inference.
    """
    # op_type -> bound handler that refines shape/value info for that op.
    self.dispatcher_ = {
        "Add": self._infer_symbolic_compute_ops,
        "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
        "AveragePool": self._infer_Pool,
        "BatchNormalization": self._infer_BatchNormalization,
        "Cast": self._infer_Cast,
        "CategoryMapper": self._infer_CategoryMapper,
        "Compress": self._infer_Compress,
        "Concat": self._infer_Concat,
        "ConcatFromSequence": self._infer_ConcatFromSequence,
        "Constant": self._infer_Constant,
        "ConstantOfShape": self._infer_ConstantOfShape,
        "Conv": self._infer_Conv,
        "CumSum": self._pass_on_shape_and_type,
        "Div": self._infer_symbolic_compute_ops,
        "Einsum": self._infer_Einsum,
        "Expand": self._infer_Expand,
        "Equal": self._infer_symbolic_compute_ops,
        "Floor": self._infer_symbolic_compute_ops,
        "Gather": self._infer_Gather,
        "GatherElements": self._infer_GatherElements,
        "GatherND": self._infer_GatherND,
        "Identity": self._pass_on_shape_and_type,
        "If": self._infer_If,
        "Loop": self._infer_Loop,
        "MatMul": self._infer_MatMul,
        "MatMulInteger16": self._infer_MatMulInteger,
        "MaxPool": self._infer_Pool,
        "Max": self._infer_symbolic_compute_ops,
        "Min": self._infer_symbolic_compute_ops,
        "Mul": self._infer_symbolic_compute_ops,
        "NonMaxSuppression": self._infer_NonMaxSuppression,
        "NonZero": self._infer_NonZero,
        "OneHot": self._infer_OneHot,
        "Pad": self._infer_Pad,
        "Range": self._infer_Range,
        "Reciprocal": self._pass_on_shape_and_type,
        "ReduceSum": self._infer_ReduceSum,
        "ReduceProd": self._infer_ReduceProd,
        "Reshape": self._infer_Reshape,
        "Resize": self._infer_Resize,
        "Round": self._pass_on_shape_and_type,
        "Scan": self._infer_Scan,
        "ScatterElements": self._infer_ScatterElements,
        "SequenceAt": self._infer_SequenceAt,
        "SequenceInsert": self._infer_SequenceInsert,
        "Shape": self._infer_Shape,
        "Size": self._infer_Size,
        "Slice": self._infer_Slice,
        "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
        "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
        "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
        "Split": self._infer_Split,
        "SplitToSequence": self._infer_SplitToSequence,
        "Squeeze": self._infer_Squeeze,
        "Sub": self._infer_symbolic_compute_ops,
        "Tile": self._infer_Tile,
        "TopK": self._infer_TopK,
        "Transpose": self._infer_Transpose,
        "Unsqueeze": self._infer_Unsqueeze,
        "Where": self._infer_symbolic_compute_ops,
        "ZipMap": self._infer_ZipMap,
        "Neg": self._infer_symbolic_compute_ops,
        # contrib ops:
        "Attention": self._infer_Attention,
        "BiasGelu": self._infer_BiasGelu,
        "MultiHeadAttention": self._infer_MultiHeadAttention,
        "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
        "FastGelu": self._infer_FastGelu,
        "Gelu": self._infer_Gelu,
        "GemmFastGelu": self._infer_GemmFastGelu,
        "LayerNormalization": self._infer_LayerNormalization,
        "LongformerAttention": self._infer_LongformerAttention,
        "PythonOp": self._infer_PythonOp,
        "SimplifiedLayerNormalization": self._infer_LayerNormalization,
        "SkipLayerNormalization": self._infer_SkipLayerNormalization,
        "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
        "GroupNorm": self._infer_GroupNorm,
        "BiasSplitGelu": self._infer_BiasSplitGelu,
        "BiasAdd": self._infer_BiasAdd,
        "NhwcConv": self._infer_NhwcConv,
    }
    # Second table keyed by ATen operator name.
    # NOTE(review): the code that consults it is outside the visible range.
    self.aten_op_dispatcher_ = {
        "embedding": self._infer_Gather,
        "bitwise_or": self._infer_aten_bitwise_or,
        "diagonal": self._infer_aten_diagonal,
        "max_pool2d_with_indices": self._infer_aten_pool2d,
        "max": self._infer_aten_minmax,
        "min": self._infer_aten_minmax,
        "multinomial": self._infer_aten_multinomial,
        "unfold": self._infer_aten_unfold,
        "argmax": self._infer_aten_argmax,
        "avg_pool2d": self._infer_aten_pool2d,
        "_adaptive_avg_pool2d": self._infer_aten_pool2d,
        "numpy_T": self._infer_Transpose,
        "native_group_norm": self._infer_aten_group_norm,
        "upsample_nearest1d": self._infer_aten_upsample,
        "upsample_nearest2d": self._infer_aten_upsample,
        "upsample_nearest3d": self._infer_aten_upsample,
        "upsample_bilinear2d": self._infer_aten_upsample,
    }
    # Loop flag polled by callers (``while ...run_`` in _onnx_infer_subgraph).
    self.run_ = True
    # symbol (str) -> symbol or int it should be replaced by.
    self.suggested_merge_ = {}
    # symbol name or expression text -> sympy object.
    self.symbolic_dims_ = {}
    # Symbols preferred as merge targets by _add_suggested_merge.
    self.input_symbols_ = {}
    self.auto_merge_ = auto_merge
    self.guess_output_rank_ = guess_output_rank
    self.verbose_ = verbose
    self.int_max_ = int_max
    # Counter appended to ``prefix_`` when inferring nested subgraphs.
    self.subgraph_id_ = 0
    self.prefix_ = prefix
def _add_suggested_merge(self, symbols, apply=False):
    """Record that every entry of ``symbols`` denotes the same dimension.

    One entry is chosen as the merge target and all others are mapped to it
    in ``suggested_merge_``. The target is picked in this order: a literal,
    a symbol in ``input_symbols_``, a symbol whose ``symbolic_dims_`` entry
    is a plain ``sympy.Symbol``, and finally the shortest name.

    Args:
        symbols: names registered in ``symbolic_dims_`` and/or literals.
        apply: when true and auto merge is enabled, the merge is applied to
            the model's value infos right away.
    """
    assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols])
    symbols = set(symbols)
    # Symbols that already have a merge target are represented by that target.
    for merged_from, merged_to in self.suggested_merge_.items():
        if merged_from in symbols:
            symbols.remove(merged_from)
            symbols.add(merged_to)

    # if there is literal, map to it first
    map_to = next((s for s in symbols if is_literal(s)), None)
    # when no literals, map to input symbolic dims, then existing symbolic dims
    if map_to is None:
        map_to = next((s for s in symbols if s in self.input_symbols_), None)
    if map_to is None:
        map_to = next((s for s in symbols if type(self.symbolic_dims_[s]) == sympy.Symbol), None)
    # when nothing to map to, use the shorter one
    if map_to is None:
        if self.verbose_ > 0:
            logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols)))
        map_to = min(list(symbols), key=len)
        symbols.remove(map_to)

    target_is_literal = is_literal(map_to)
    for symbol in symbols:
        if symbol == map_to:
            continue
        if target_is_literal and is_literal(symbol):
            assert int(map_to) == int(symbol)
        self.suggested_merge_[symbol] = int(map_to) if target_is_literal else map_to
        # Redirect earlier merges that pointed at ``symbol``.
        for merged_from, merged_to in self.suggested_merge_.items():
            if merged_to == symbol:
                self.suggested_merge_[merged_from] = map_to
    if apply and self.auto_merge_:
        self._apply_suggested_merge()
def _apply_suggested_merge(self, graph_input_only=False):
if not self.suggested_merge_:
return
for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)):
for d in i.type.tensor_type.shape.dim:
if d.dim_param in self.suggested_merge_:
v = self.suggested_merge_[d.dim_param]
if is_literal(v):
d.dim_value = int(v)
else:
d.dim_param = v
def _preprocess(self, in_mp):
    """Copy ``in_mp`` into ``out_mp_`` and index its inputs and initializers.

    After the call ``graph_inputs_`` maps input names to their value infos,
    ``initializers_`` maps initializer names to their tensors, and
    ``known_vi_`` holds value infos for both. The initializer value infos
    are synthesized from each tensor's data type and dims.
    """
    self.out_mp_ = onnx.ModelProto()
    self.out_mp_.CopyFrom(in_mp)
    graph = self.out_mp_.graph
    self.graph_inputs_ = {graph_input.name: graph_input for graph_input in graph.input}
    self.initializers_ = {tensor.name: tensor for tensor in graph.initializer}
    self.known_vi_ = {graph_input.name: graph_input for graph_input in graph.input}
    # Initializers are added after inputs, so they win on a name clash.
    for tensor in graph.initializer:
        self.known_vi_[tensor.name] = helper.make_tensor_value_info(
            tensor.name, tensor.data_type, list(tensor.dims)
        )
def _merge_symbols(self, dims):
if not all([type(d) == str for d in dims]):
if self.auto_merge_:
unique_dims = list(set(dims))
is_int = [is_literal(d) for d in unique_dims]
assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong
if sum(is_int) == 1:
int_dim = is_int.index(1)
if self.verbose_ > 0:
logger.debug(
"dim {} has been merged with value {}".format(
unique_dims[:int_dim] + unique_dims[int_dim + 1 :],
unique_dims[int_dim],
)
)
self._check_merged_dims(unique_dims, allow_broadcast=False)
return unique_dims[int_dim]
else:
if self.verbose_ > 0:
logger.debug("dim {} has been mergd with dim {}".format(unique_dims[1:], unique_dims[0]))
return dims[0]
else:
return None
if all([d == dims[0] for d in dims]):
return dims[0]
merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims]
if all([d == merged[0] for d in merged]):
assert merged[0] in self.symbolic_dims_
return merged[0]
else:
return None
# broadcast from right to left, and merge symbolic dims if needed
def _broadcast_shapes(self, shape1, shape2):
new_shape = []
rank1 = len(shape1)
rank2 = len(shape2)
new_rank = max(rank1, rank2)
for i in range(new_rank):
dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
if dim1 == 1 or dim1 == dim2:
new_dim = dim2
elif dim2 == 1:
new_dim = dim1
else:
new_dim = self._merge_symbols([dim1, dim2])
if not new_dim:
# warning about unsupported broadcast when not auto merge
# note that auto merge has the risk of incorrectly merge symbols while one of them being 1
# for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
if self.auto_merge_:
self._add_suggested_merge([dim1, dim2], apply=True)
else:
logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2))
new_shape = [new_dim] + new_shape
return new_shape
def _get_shape(self, node, idx):
name = node.input[idx]
if name in self.known_vi_:
vi = self.known_vi_[name]
return get_shape_from_value_info(vi)
else:
assert name in self.initializers_
return list(self.initializers_[name].dims)
def _get_shape_rank(self, node, idx):
return len(self._get_shape(node, idx))
def _get_sympy_shape(self, node, idx):
sympy_shape = []
for d in self._get_shape(node, idx):
if type(d) == str:
sympy_shape.append(
self.symbolic_dims_[d]
if d in self.symbolic_dims_
else sympy.Symbol(d, integer=True, nonnegative=True)
)
else:
assert None != d
sympy_shape.append(d)
return sympy_shape
def _get_value(self, node, idx):
name = node.input[idx]
assert name in self.sympy_data_ or name in self.initializers_
return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name])
def _try_get_value(self, node, idx):
if idx >= len(node.input):
return None
name = node.input[idx]
if name in self.sympy_data_ or name in self.initializers_:
return self._get_value(node, idx)
return None
def _update_computed_dims(self, new_sympy_shape):
    """Register computed dim expressions and apply suggested merges in place.

    Literal and string dims are left untouched. For every other dim, keyed
    by its ``str()`` form: if a non-literal merge target exists, the entry
    in ``new_sympy_shape`` is replaced by that target's sympy object;
    without any merge target the expression is added to ``symbolic_dims_``
    unless already present.
    """
    for i, new_dim in enumerate(new_sympy_shape):
        if is_literal(new_dim) or type(new_dim) == str:
            continue
        str_dim = str(new_dim)
        if str_dim not in self.suggested_merge_:
            # add new_dim if it's a computational expression
            if str_dim not in self.symbolic_dims_:
                self.symbolic_dims_[str_dim] = new_dim
            continue
        merge_target = self.suggested_merge_[str_dim]
        if is_literal(merge_target):
            continue  # no need to create dim for literals
        new_sympy_shape[i] = self.symbolic_dims_[merge_target]
def _onnx_infer_single_node(self, node):
    """Create value infos for ``node``'s outputs, using ONNX shape inference where possible.

    For ops not in the skip list, ``node`` is wrapped in a temporary
    one-node graph whose inputs are the currently known value infos, and
    ``shape_inference.infer_shapes`` is run on it. Every output of ``node``
    then gets a new entry in ``out_mp_.graph.value_info``, which is also
    registered in ``known_vi_``. For skipped ops that entry carries only
    the output name.
    """
    # skip onnx shape inference for some ops, as they are handled in _infer_*
    skip_infer = node.op_type in [
        "If",
        "Loop",
        "Scan",
        "SplitToSequence",
        "ZipMap",  # contrib ops
        "Attention",
        "BiasGelu",
        "EmbedLayerNormalization",
        "FastGelu",
        "Gelu",
        "GemmFastGelu",
        "LayerNormalization",
        "LongformerAttention",
        "SimplifiedLayerNormalization",
        "SkipLayerNormalization",
        "SkipSimplifiedLayerNormalization",
        "PythonOp",
        "MultiHeadAttention",
        "GroupNorm",
        "BiasSplitGelu",
        "BiasAdd",
        "NhwcConv",
    ]
    if not skip_infer:
        # Only pass initializers that satisfy the following condition:
        # (1) Operator need value of some input for shape inference.
        #     For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output.
        # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec.
        # (3) The initializer is not in graph input. This means the node input is "constant" in inference.
        initializers = []
        # NOTE(review): get_opset returns None when no default-domain opset
        # is imported, and the comparison below would then raise TypeError.
        if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]:
            initializers = [
                self.initializers_[name]
                for name in node.input
                if (name in self.initializers_ and name not in self.graph_inputs_)
            ]
        # run single node inference with self.known_vi_ shapes
        tmp_graph = helper.make_graph(
            [node],
            "tmp",
            [self.known_vi_[i] for i in node.input if i],  # empty names are skipped
            [make_named_value_info(i) for i in node.output],
            initializers,
        )
        self.tmp_mp_.graph.CopyFrom(tmp_graph)
        self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
    for i_o in range(len(node.output)):
        o = node.output[i_o]
        # The new entry lives in the output model and is shared with known_vi_.
        vi = self.out_mp_.graph.value_info.add()
        if not skip_infer:
            vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
        else:
            vi.name = o
        self.known_vi_[o] = vi
def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
    """Run symbolic shape inference on ``subgraph`` of ``node`` with a nested instance.

    A temporary graph is built from the subgraph's nodes. Its inputs are
    the subgraph's own inputs plus every value info known to this instance
    that the subgraph does not shadow. After inference, the subgraph's
    inputs (optionally), outputs, value infos and nodes are replaced by the
    inferred ones, and symbolic dims new to this instance that appear in
    the subgraph outputs are copied into ``symbolic_dims_``.

    Args:
        node: the node owning ``subgraph``.
        subgraph: the graph to infer; modified in place.
        use_node_input: when true, the subgraph inputs are replaced by the
            first ``len(node.input)`` inferred inputs.
        inc_subgraph_id: when true, ``subgraph_id_`` is incremented after
            the nested instance has been created.

    Returns:
        The nested ``SymbolicShapeInference`` instance.
    """
    if self.verbose_ > 2:
        logger.debug(
            "Inferencing subgraph of node {} with output({}...): {}".format(node.name, node.output[0], node.op_type)
        )
    # node inputs are not passed directly to the subgraph
    # it's up to the node dispatcher to prepare subgraph input
    # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
    # besides, inputs in subgraph could shadow implicit inputs
    subgraph_inputs = set([i.name for i in list(subgraph.initializer) + list(subgraph.input)])
    subgraph_implicit_input = set([name for name in self.known_vi_.keys() if not name in subgraph_inputs])
    tmp_graph = helper.make_graph(
        list(subgraph.node),
        "tmp",
        list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input],
        [make_named_value_info(i.name) for i in subgraph.output],
    )
    # Outer initializers visible to the subgraph go first, then its own.
    tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input])
    tmp_graph.initializer.extend(subgraph.initializer)
    self.tmp_mp_.graph.CopyFrom(tmp_graph)
    # The nested instance gets a prefix derived from the current subgraph id.
    symbolic_shape_inference = SymbolicShapeInference(
        self.int_max_,
        self.auto_merge_,
        self.guess_output_rank_,
        self.verbose_,
        prefix=self.prefix_ + "_" + str(self.subgraph_id_),
    )
    if inc_subgraph_id:
        self.subgraph_id_ += 1
    all_shapes_inferred = False
    symbolic_shape_inference._preprocess(self.tmp_mp_)
    # The nested instance starts from a copy of the merges known so far.
    symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
    # Repeat until the nested instance clears its run_ flag.
    # NOTE(review): _infer_impl is outside the visible range; its result is
    # stored but not used further in this method.
    while symbolic_shape_inference.run_:
        all_shapes_inferred = symbolic_shape_inference._infer_impl(self.sympy_data_.copy())
    symbolic_shape_inference._update_output_from_vi()
    if use_node_input:
        # if subgraph uses node input, it needs to update to merged dims
        subgraph.ClearField("input")
        subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)])
    subgraph.ClearField("output")
    subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
    subgraph.ClearField("value_info")
    subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info)
    subgraph.ClearField("node")
    subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
    # for new symbolic dims from subgraph output, add to main graph symbolic dims
    subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output]
    subgraph_new_symbolic_dims = set(
        [d for s in subgraph_shapes if s for d in s if type(d) == str and not d in self.symbolic_dims_]
    )
    new_dims = {}
    for d in subgraph_new_symbolic_dims:
        assert d in symbolic_shape_inference.symbolic_dims_
        new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
    self.symbolic_dims_.update(new_dims)
    return symbolic_shape_inference
def _get_int_values(self, node, broadcast=False):
values = [self._try_get_value(node, i) for i in range(len(node.input))]
if all([v is not None for v in values]):
# some shape compute is in floating point, cast to int for sympy
for i, v in enumerate(values):
if type(v) != np.ndarray:
continue
if len(v.shape) > 1:
new_v = None # ignore value for rank > 1
elif len(v.shape) == 0:
new_v = int(v.item())
else:
assert len(v.shape) == 1
new_v = [int(vv) for vv in v]
values[i] = new_v
values_len = [len(v) if type(v) == list else 0 for v in values]
max_len = max(values_len)
if max_len >= 1 and broadcast:
# broadcast
for i, v in enumerate(values):
if v is None:
continue # don't broadcast if value is unknown
if type(v) == list:
if len(v) < max_len:
values[i] = v * max_len
else:
assert len(v) == max_len
else:
values[i] = [v] * max_len
return values
def _compute_on_sympy_data(self, node, op_func):
assert len(node.output) == 1
values = self._get_int_values(node, broadcast=True)
if all([v is not None for v in values]):
is_list = [type(v) == list for v in values]
as_list = any(is_list)
if as_list:
self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)]
else:
self.sympy_data_[node.output[0]] = op_func(values)
def _pass_on_sympy_data(self, node):
assert len(node.input) == 1 or node.op_type in [
"Reshape",
"Unsqueeze",
"Squeeze",
]
self._compute_on_sympy_data(node, lambda x: x[0])
def _pass_on_shape_and_type(self, node):
    """Give the first output the element type and shape of the first input."""
    output_name = node.output[0]
    elem_type = get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type)
    input_shape = self._get_shape(node, 0)
    self.known_vi_[output_name].CopyFrom(helper.make_tensor_value_info(output_name, elem_type, input_shape))
def _new_symbolic_dim(self, prefix, dim):
    """Return the symbolic dim named ``"<prefix>_d<dim>"``.

    If a merge is already suggested for that name, the merge target is
    returned instead: a ``sympy.Integer`` for a literal target, the target
    itself otherwise. Without a suggested merge, a new non-negative integer
    ``sympy.Symbol`` is created and registered in ``symbolic_dims_``.
    """
    dim_name = f"{prefix}_d{dim}"
    if dim_name not in self.suggested_merge_:
        symbol = sympy.Symbol(dim_name, integer=True, nonnegative=True)
        self.symbolic_dims_[dim_name] = symbol
        return symbol
    target = self.suggested_merge_[dim_name]
    if is_literal(target):
        return sympy.Integer(int(target))
    return target
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
    """Create a symbolic dim for axis ``dim`` of output ``out_idx`` of ``node``.

    The name is built from the op type, this instance's prefix, the node's
    position in ``out_mp_.graph.node`` and the output index.
    """
    node_position = list(self.out_mp_.graph.node).index(node)
    name_prefix = f"{node.op_type}{self.prefix_}_{node_position}_o{out_idx}_"
    return self._new_symbolic_dim(name_prefix, dim)
def _new_symbolic_shape(self, rank, node, out_idx=0):
    """Return ``rank`` symbolic dims, one per axis of output ``out_idx`` of ``node``."""
    shape = []
    for axis in range(rank):
        shape.append(self._new_symbolic_dim_from_output(node, out_idx, axis))
    return shape
def _compute_conv_pool_shape(self, node, channels_last=False):
    """Compute the sympy output shape of a convolution or pooling node.

    Args:
        node: the node. With more than one input, the second input is
            taken as the weight, and its shape provides the kernel size
            and the output channel count. Otherwise the ``kernel_shape``
            attribute is used.
        channels_last: when true, the spatial axes are the ``rank`` axes
            before the last one; otherwise they are the last ``rank`` axes.

    Returns:
        The output shape as a list of sympy dims. If no spatial dim is
        symbolic and the output value info already has a non-empty shape,
        the spatial dims are copied from it. Otherwise each spatial dim is
        ``(in + total_pad - effective_kernel) // stride + 1``, using
        ``ceiling`` of the true division when ``ceil_mode`` is set.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    if len(node.input) > 1:
        W_shape = self._get_sympy_shape(node, 1)
        rank = len(W_shape) - 2  # number of spatial axes
        kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
        sympy_shape[3 if channels_last else 1] = W_shape[0]
    else:
        W_shape = None
        kernel_shape = get_attribute(node, "kernel_shape")
        rank = len(kernel_shape)
    assert len(sympy_shape) == rank + 2

    # only need to symbolic shape inference if input has symbolic dims in spatial axes
    spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
    is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
    if not any(is_symbolic_dims):
        shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
        if len(shape) > 0:
            assert len(sympy_shape) == len(shape)
            if channels_last:
                sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
            else:
                sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
            return sympy_shape

    dilations = get_attribute(node, "dilations", [1] * rank)
    strides = get_attribute(node, "strides", [1] * rank)
    effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
    pads = get_attribute(node, "pads")
    if pads is None:
        pads = [0] * (2 * rank)
        auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
        if auto_pad != "VALID" and auto_pad != "NOTSET":
            try:
                # Use the spatial dims, not the trailing ``rank`` dims: with
                # channels_last the trailing dims include the channel axis,
                # which would pair each stride with the wrong axis.
                residual = [sympy.Mod(d, s) for d, s in zip(spatial_shape, strides)]
                total_pads = [
                    max(0, (k - s) if r == 0 else (k - r))
                    for k, s, r in zip(effective_kernel_shape, strides, residual)
                ]
            except TypeError:  # sympy may throw TypeError: cannot determine truth value of Relational
                total_pads = [
                    max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)
                ]  # assuming no residual if sympy throws error
        elif auto_pad == "VALID":
            total_pads = []
        else:
            total_pads = [0] * rank
    else:
        assert len(pads) == 2 * rank
        # ``pads`` holds the begin values first, then the end values.
        total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]

    ceil_mode = get_attribute(node, "ceil_mode", 0)
    for i in range(rank):
        spatial_axis = -rank + i + (-1 if channels_last else 0)
        effective_input_size = sympy_shape[spatial_axis]
        if len(total_pads) > 0:
            effective_input_size = effective_input_size + total_pads[i]
        if ceil_mode:
            strided_kernel_positions = sympy.ceiling(
                (effective_input_size - effective_kernel_shape[i]) / strides[i]
            )
        else:
            strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
        sympy_shape[spatial_axis] = strided_kernel_positions + 1
    return sympy_shape
def _check_merged_dims(self, dims, allow_broadcast=True):
if allow_broadcast:
dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
if not all([d == dims[0] for d in dims]):
self._add_suggested_merge(dims, apply=True)
def _compute_matmul_shape(self, node, output_dtype=None):
    """Infer the output value info of a matrix multiplication of inputs 0 and 1.

    A rank-1 operand contributes no axis to the result. For two operands of
    rank 2 or more, the leading axes are broadcast and the last two axes
    follow the matrix product rule. The two contracted dims are checked
    and, if they differ, merged.

    Args:
        node: the node; both input ranks must be positive (asserted).
        output_dtype: element type of the output; defaults to the element
            type of the first input.
    """
    lhs_shape = self._get_shape(node, 0)
    rhs_shape = self._get_shape(node, 1)
    lhs_rank, rhs_rank = len(lhs_shape), len(rhs_shape)
    assert lhs_rank > 0 and rhs_rank > 0
    # Axis contracted on each side; a vector is contracted along axis 0.
    lhs_reduce_dim = 0 if lhs_rank == 1 else -1
    rhs_reduce_dim = 0 if rhs_rank == 1 else -2
    if lhs_rank == 1 and rhs_rank == 1:
        new_shape = []
    elif lhs_rank == 1:
        new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]]
    elif rhs_rank == 1:
        new_shape = lhs_shape[:lhs_reduce_dim]
    else:
        batch_shape = self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2])
        new_shape = batch_shape + [lhs_shape[-2]] + [rhs_shape[-1]]
    # merge reduce dim
    self._check_merged_dims(
        [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
        allow_broadcast=False,
    )
    if output_dtype is None:
        # infer output_dtype from input type when not specified
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    output_name = node.output[0]
    self.known_vi_[output_name].CopyFrom(helper.make_tensor_value_info(output_name, output_dtype, new_shape))
def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
    """
    Update ``dst_type`` in place so that it is compatible with ``src_type``.

    If the destination has no shape, the source tensor type is copied over.
    Otherwise every dim that differs between the two is replaced: by a new
    symbolic dim for a tensor type, or by an empty dim for a sequence type.
    Raises ValueError when the element types differ.
    """

    def tensor_type_of(type_proto):
        if is_sequence(type_proto):
            return type_proto.sequence_type.elem_type.tensor_type
        return type_proto.tensor_type

    dst_tensor_type = tensor_type_of(dst_type)
    src_tensor_type = tensor_type_of(src_type)
    if dst_tensor_type.elem_type != src_tensor_type.elem_type:
        node_id = node.name or node.op_type
        raise ValueError(
            f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
        )
    if not dst_tensor_type.HasField("shape"):
        dst_tensor_type.CopyFrom(src_tensor_type)
        return
    dim_pairs = zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)
    for di, (dst_dim, src_dim) in enumerate(dim_pairs):
        if dst_dim != src_dim:
            # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type
            # for sequence_type, clear the dimension
            new_dim = onnx.TensorShapeProto.Dimension()
            if not is_sequence(dst_type):
                new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di))
            dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
def _infer_ArrayFeatureExtractor(self, node):
    """Set the output shape to the data shape minus its last axis, followed by the indices shape.

    The output keeps the element type of the first input.
    """
    data_shape = self._get_shape(node, 0)
    indices_shape = self._get_shape(node, 1)
    output_name = node.output[0]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    self.known_vi_[output_name].CopyFrom(
        helper.make_tensor_value_info(output_name, elem_type, data_shape[:-1] + indices_shape)
    )
def _infer_symbolic_compute_ops(self, node):
funcs = {
"Add": lambda l: l[0] + l[1],
"Div": lambda l: l[0] // l[1], # integer div in sympy
"Equal": lambda l: l[0] == l[1],
"Floor": lambda l: sympy.floor(l[0]),
"Max": lambda l: l[1]
if is_literal(l[0]) and int(l[0]) < -self.int_max_
else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
"Min": lambda l: l[1]
if is_literal(l[0]) and int(l[0]) > self.int_max_
else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
"Mul": lambda l: l[0] * l[1],
"Sub": lambda l: l[0] - l[1],
"Where": lambda l: l[1] if l[0] else l[2],
"Neg": lambda l: -l[0],
}
assert node.op_type in funcs
self._compute_on_sympy_data(node, funcs[node.op_type])
def _infer_Cast(self, node):
    """Forward the known value of the input, if any, to the output unchanged."""
    self._pass_on_sympy_data(node)
def _infer_CategoryMapper(self, node):
    """Set the output to the input shape with the element type flipped.

    A STRING input yields an INT64 output; any other input type yields
    STRING.
    """
    input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    is_string_input = input_type == onnx.TensorProto.STRING
    output_type = onnx.TensorProto.INT64 if is_string_input else onnx.TensorProto.STRING
    output_name = node.output[0]
    self.known_vi_[output_name].CopyFrom(
        helper.make_tensor_value_info(output_name, output_type, self._get_shape(node, 0))
    )
def _infer_Compress(self, node):
    """Infer the Compress output shape using a new symbolic dim for the compressed axis.

    Without an ``axis`` attribute the output is 1-D with that new dim. With
    ``axis``, the output has the input shape with that axis replaced.
    """
    input_shape = self._get_shape(node, 0)
    # create a new symbolic dimension for Compress output
    compress_len = str(self._new_symbolic_dim_from_output(node))
    axis = get_attribute(node, "axis")
    if axis is None:
        # when axis is not specified, input is flattened before compress so output is 1D
        output_shape = [compress_len]
    else:
        input_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
        output_shape = input_shape
    output_name = node.output[0]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    self.known_vi_[output_name].CopyFrom(helper.make_tensor_value_info(output_name, elem_type, output_shape))
def _infer_Concat(self, node):
    """Infer the value (when known) and the shape of a Concat output.

    If any input has a known value and all of them turn out to be known,
    the concatenated value is stored in ``sympy_data_``; this path requires
    ``axis == 0`` (asserted). The output shape starts from the first input:
    the concat axis is the sum over all inputs, and on every other axis
    disagreeing dims are reconciled through ``_merge_symbols``.
    """
    if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]):
        values = self._get_int_values(node)
        if all([v is not None for v in values]):
            assert 0 == get_attribute(node, "axis")
            self.sympy_data_[node.output[0]] = []
            for i in range(len(node.input)):
                value = values[i]
                # Lists are spliced in, scalars are appended as one element.
                if type(value) == list:
                    self.sympy_data_[node.output[0]].extend(value)
                else:
                    self.sympy_data_[node.output[0]].append(value)
    sympy_shape = self._get_sympy_shape(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
    for i_idx in range(1, len(node.input)):
        input_shape = self._get_sympy_shape(node, i_idx)
        # Inputs with an empty shape do not contribute to the concat axis.
        if input_shape:
            sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
    self._update_computed_dims(sympy_shape)
    # merge symbolic dims for non-concat axes
    for d in range(len(sympy_shape)):
        if d == axis:
            continue
        # NOTE(review): shapes are re-read on every axis rather than cached.
        # _merge_symbols can apply merges that rewrite dims of the value
        # infos held in known_vi_, so a cached copy could go stale.
        dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)]
        if all([d == dims[0] for d in dims]):
            continue
        merged = self._merge_symbols(dims)
        if type(merged) == str:
            sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
        else:
            sympy_shape[d] = merged
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(sympy_shape),
        )
    )
def _infer_ConcatFromSequence(self, node):
    """Rebuild the output value info of a ConcatFromSequence node.

    A new symbolic dimension (from ``_new_symbolic_dim_from_output``) is used
    for the concatenated axis. When the ``new_axis`` attribute is truthy the
    symbolic dimension is inserted at ``axis``; otherwise it replaces the
    existing entry at ``axis`` in the shape reported for input 0.

    The element type is read from the tensor element type of the sequence
    type recorded for input 0.
    """
    element_shape = self._get_shape(node, 0)
    inserts_axis = bool(get_attribute(node, "new_axis"))
    # inserting an axis makes the output rank one larger than the element rank
    output_rank = len(element_shape) + (1 if inserts_axis else 0)
    axis = handle_negative_axis(get_attribute(node, "axis"), output_rank)
    symbolic_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))

    if inserts_axis:
        result_shape = [*element_shape[:axis], symbolic_dim, *element_shape[axis:]]
    else:
        element_shape[axis] = symbolic_dim
        result_shape = element_shape

    out_vi = self.known_vi_[node.output[0]]
    sequence_type = self.known_vi_[node.input[0]].type.sequence_type
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            sequence_type.elem_type.tensor_type.elem_type,
            result_shape,
        )
    )
def _infer_Constant(self, node):
    """Record the value of a Constant node in ``self.sympy_data_``.

    The node's ``value`` attribute is converted with ``numpy_helper.to_array``
    and stored under the name of the node's first output.
    """
    constant_array = numpy_helper.to_array(get_attribute(node, "value"))
    self.sympy_data_[node.output[0]] = constant_array
def _infer_ConstantOfShape(self, node):
    """Infer the output shape (and, when possible, the values) of ConstantOfShape.

    The requested shape is taken from the first entry of ``_get_int_values``:

    * If it is available, it is wrapped into a list when it is not one
      already and passed to ``_update_computed_dims``. When the output element
      type is INT64 and every dimension is a literal, the filled tensor is also
      stored in ``self.sympy_data_`` under the output name.
    * If it is ``None``, a new symbolic shape is created whose rank is the
      length of the 1-D shape input.

    The output value info keeps its existing element type and receives the
    inferred shape.
    """
    sympy_shape = self._get_int_values(node)[0]
    vi = self.known_vi_[node.output[0]]
    if sympy_shape is not None:
        if not isinstance(sympy_shape, list):
            sympy_shape = [sympy_shape]
        self._update_computed_dims(sympy_shape)
        # update sympy data if output type is int, and shape is known
        if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(is_literal(x) for x in sympy_shape):
            self.sympy_data_[node.output[0]] = np.ones(
                [int(x) for x in sympy_shape], dtype=np.int64
            ) * numpy_helper.to_array(get_attribute(node, "value", 0))
    else:
        # create new dynamic shape
        # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length
        sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node)
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            vi.type.tensor_type.elem_type,
            get_shape_from_sympy_shape(sympy_shape),
        )
    )
def _infer_Conv(self, node):
    """Set the output shape of a Conv node.

    The shape comes from ``_compute_conv_pool_shape`` and is passed to
    ``_update_computed_dims``; the element type already recorded on the output
    value info is kept.
    """
    out_dims = self._compute_conv_pool_shape(node)
    self._update_computed_dims(out_dims)

    out_vi = self.known_vi_[node.output[0]]
    refreshed = helper.make_tensor_value_info(
        node.output[0],
        out_vi.type.tensor_type.elem_type,
        get_shape_from_sympy_shape(out_dims),
    )
    out_vi.CopyFrom(refreshed)
def _infer_NhwcConv(self, node):
    """Set the output value info of an NhwcConv node.

    The shape comes from ``_compute_conv_pool_shape`` called with
    ``channels_last=True`` and is passed to ``_update_computed_dims``; the
    element type is copied from the value info of input 0.
    """
    out_dims = self._compute_conv_pool_shape(node, channels_last=True)
    self._update_computed_dims(out_dims)

    out_vi = self.known_vi_[node.output[0]]
    refreshed = helper.make_tensor_value_info(
        node.output[0],
        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
        get_shape_from_sympy_shape(out_dims),
    )
    out_vi.CopyFrom(refreshed)
def _infer_Einsum(self, node):
# ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
equation = get_attribute(node, "equation")
equation = equation.replace(b" ", b"")
mid_index = equation.find(b"->")
left_equation = equation[:mid_index] if mid_index != -1 else equation
num_operands = 0
num_ellipsis = 0
num_ellipsis_indices = 0
letter_to_dim = {}
terms = left_equation.split(b",")
for term in terms:
ellipsis_index = term.find(b"...")
shape = self._get_shape(node, num_operands)
rank = len(shape)
if ellipsis_index != -1:
if num_ellipsis == 0:
num_ellipsis_indices = rank - len(term) + 3
num_ellipsis = num_ellipsis + 1
for i in range(1, rank + 1):
letter = term[-i]
if letter != 46: # letter != b'.'
dim = shape[-i]
if letter not in letter_to_dim.keys():
letter_to_dim[letter] = dim
elif type(dim) != sympy.Symbol:
letter_to_dim[letter] = dim
num_operands = num_operands + 1
new_sympy_shape = []
from collections import OrderedDict
num_letter_occurrences = OrderedDict()
if mid_index != -1:
right_equation = equation[mid_index + 2 :]
right_ellipsis_index = right_equation.find(b"...")
if right_ellipsis_index != -1:
for i in range(num_ellipsis_indices):
new_sympy_shape.append(shape[i])
for c in right_equation:
if c != 46: # c != b'.'
new_sympy_shape.append(letter_to_dim[c])
else:
for i in range(num_ellipsis_indices):
new_sympy_shape.append(shape[i])
for c in left_equation:
if c != 44 and c != 46: # c != b',' and c != b'.':
if c in num_letter_occurrences:
num_letter_occurrences[c] = num_letter_occurrences[c] + 1
else:
num_letter_occurrences[c] = 1
for key, value in num_letter_occurrences.items():
if value == 1:
new_sympy_shape.append(letter_to_dim[key])
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]