forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
keras.py
1580 lines (1410 loc) · 58.6 KB
/
keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, import-self, import-outside-toplevel
"""Keras frontend."""
import dis
import sys
import numpy as np
import tvm
from tvm.ir import IRModule, TensorType, TupleType
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable, new_var
__all__ = ["from_keras"]
def _check_data_format(keras_layer):
if hasattr(keras_layer, ("data_format")):
if keras_layer.data_format != "channels_last":
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
def _get_pad_pair(input1d, kernel1d, stride1d):
out1d = (input1d + stride1d - 1) // stride1d
pad = np.maximum((out1d - 1) * stride1d + kernel1d - input1d, 0)
pad_before = pad // 2
pad_after = pad - pad_before
return [pad_before, pad_after]
def _get_elu(inexpr, alpha):
    """Build ELU of ``inexpr`` from two ReLUs.

    Evaluates ``relu(x) - alpha * relu(1 - exp(x))``. For ``x > 0`` this is
    ``x``; for ``x <= 0`` it is ``alpha * (exp(x) - 1)``.
    ``alpha`` is expected to be a Relay expression (callers pass a constant).
    """
    one = _expr.const(1.0, dtype="float32")
    negative_part = _op.nn.relu(one - _op.exp(inexpr))
    positive_part = _op.nn.relu(inexpr)
    return _op.negative(alpha) * negative_part + positive_part
def _as_list(arr):
"""Force being a list, ignore if already is."""
if isinstance(arr, list):
return arr
return [arr]
def _convert_recurrent_activation(inexpr, keras_layer):
    """Apply ``keras_layer.recurrent_activation`` to ``inexpr``.

    The activation is identified by its function ``__name__`` and converted
    via ``_convert_activation`` with no expression table and no data layout.
    """
    recurrent_fn = keras_layer.recurrent_activation
    return _convert_activation(inexpr, recurrent_fn.__name__, None, None, None)
def _convert_activation(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras activation to a Relay expression.

    Parameters
    ----------
    inexpr : relay.Expr
        Input expression.
    keras_layer : str or keras layer
        Either an activation name (e.g. ``"relu"``) or a layer whose
        ``activation`` attribute is the activation function.
    etab : ExprTable
        Unused; kept for the common converter signature.
    data_layout : str
        Only consulted for ``softmax``: axis 1 for ``"NCHW"``, else -1.
    input_shape : tuple, optional
        Unused.

    Returns
    -------
    relay.Expr

    Raises
    ------
    tvm.error.OpNotImplemented
        For activation names not handled here.
    """
    if isinstance(keras_layer, str):
        act_type = keras_layer
    elif sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type == "linear":
        if isinstance(keras_layer, str):
            return inexpr
        # A layer object may carry ``alpha``/``beta``; the result is
        # x * alpha + beta, defaulting to the identity (1.0, 0.0).
        alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1.0
        beta = keras_layer.beta if hasattr(keras_layer, "beta") else 0.0
        alpha = _expr.const(alpha, dtype="float32")
        beta = _expr.const(beta, dtype="float32")
        return _op.add(_op.multiply(inexpr, alpha), beta)
    if act_type == "softmax":
        axis = 1 if data_layout == "NCHW" else -1
        return _op.nn.softmax(inexpr, axis)
    if act_type == "sigmoid":
        return _op.sigmoid(inexpr)
    if act_type == "tanh":
        return _op.tanh(inexpr)
    if act_type == "relu":
        return _op.nn.relu(inexpr)
    if act_type == "softplus":
        # log(1 + exp(x)) overflows to inf once exp(x) exceeds the float32
        # range (x > ~88). The identity
        #   log(1 + exp(x)) = max(x, 0) + log(1 + exp(-|x|))
        # keeps every exp() argument <= 0, so it cannot overflow.
        one = _expr.const(1.0, dtype="float32")
        return _op.add(
            _op.nn.relu(inexpr),
            _op.log(_op.add(_op.exp(_op.negative(_op.abs(inexpr))), one)),
        )
    if act_type == "elu":
        alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1.0
        alpha = _expr.const(alpha, dtype="float32")
        return _get_elu(inexpr, alpha)
    if act_type == "selu":
        # Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
        alpha = (
            keras_layer.alpha
            if hasattr(keras_layer, "alpha")
            else 1.6732632423543772848170429916717
        )
        gamma = (
            keras_layer.gamma
            if hasattr(keras_layer, "gamma")
            else 1.0507009873554804934193349852946
        )
        alpha = _expr.const(alpha, dtype="float32")
        gamma = _expr.const(gamma, dtype="float32")
        return gamma * _get_elu(inexpr, alpha)
    if act_type == "relu6":
        return _op.clip(inexpr, a_min=0.0, a_max=6.0)
    if act_type == "softsign":
        return inexpr / (_expr.const(1.0, dtype="float32") + _op.abs(inexpr))
    if act_type == "hard_sigmoid":
        # Piecewise-linear sigmoid: clip(0.2 * x + 0.5, 0, 1).
        x = (_expr.const(0.2, dtype="float32") * inexpr) + _expr.const(0.5, dtype="float32")
        return _op.clip(x, a_min=0.0, a_max=1.0)
    raise tvm.error.OpNotImplemented(f"Operator {act_type} is not supported in frontend Keras.")
def _convert_advanced_activation(inexpr, keras_layer, etab, data_layout, input_shape=None):
    """Convert a Keras advanced-activation layer to Relay.

    Handles ``Softmax``, ``ReLU``, ``LeakyReLU``, ``ELU``, ``PReLU`` and
    ``ThresholdedReLU``, selected by the layer's class name.

    Parameters
    ----------
    inexpr : relay.Expr
        Input expression.
    keras_layer : keras layer
        The activation layer.
    etab : ExprTable
        Receives the PReLU ``alpha`` weights.
    data_layout : str
        ``"NCHW"`` remaps the Softmax axis and transposes PReLU weights.
    input_shape : tuple, optional
        Keras-side input shape; defaults to ``keras_layer.input_shape``.

    Returns
    -------
    relay.Expr

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        For a Softmax over a list of axes.
    tvm.error.OpAttributeInvalid
        For NaN ReLU threshold / LeakyReLU or ELU alpha.
    tvm.error.OpNotImplemented
        For other layer types.
    """
    act_type = type(keras_layer).__name__
    if input_shape is None:
        input_shape = keras_layer.input_shape
    if act_type == "Softmax":
        axis = keras_layer.axis
        dims = len(input_shape)
        if isinstance(axis, list):
            raise tvm.error.OpAttributeUnImplemented(f"Softmax with axes {axis} is not supported.")
        if data_layout == "NCHW":
            if axis == -1:
                axis = 1
            else:
                axis = axis + 1 if axis < dims - 1 else 1
        return _op.nn.softmax(inexpr, axis=axis)
    if act_type == "ReLU":
        if np.isnan(keras_layer.threshold).any():
            raise tvm.error.OpAttributeInvalid("The threshold value of a ReLU cannot be None.")
        # Keras ReLU:
        #   f(x) = max_value                         for x >= max_value
        #   f(x) = x                                 for threshold <= x < max_value
        #   f(x) = negative_slope * (x - threshold)  otherwise
        # The condition must be evaluated inside the graph (``where``): a Relay
        # expression used in a Python ``if`` is always truthy.
        threshold_value = float(keras_layer.threshold)
        max_value = keras_layer.max_value
        negative_slope_value = float(getattr(keras_layer, "negative_slope", None) or 0.0)
        if threshold_value == 0.0 and negative_slope_value == 0.0:
            # Plain (optionally capped) ReLU; ``is None`` so max_value=0 is honoured.
            if max_value is None:
                return _op.nn.relu(inexpr)
            return _op.clip(inexpr, a_min=0.0, a_max=float(max_value))
        threshold = _expr.const(threshold_value, dtype="float32")
        negative_slope = _expr.const(negative_slope_value, dtype="float32")
        out = _op.where(
            _op.greater_equal(inexpr, threshold),
            inexpr,
            _op.multiply(negative_slope, _op.subtract(inexpr, threshold)),
        )
        if max_value is not None:
            out = _op.minimum(out, _expr.const(float(max_value), dtype="float32"))
        return out
    if act_type == "LeakyReLU":
        if np.isnan(keras_layer.alpha).any():
            raise tvm.error.OpAttributeInvalid("The alpha value of a LeakyReLU cannot be None.")
        return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
    if act_type == "ELU":
        if np.isnan(keras_layer.alpha).any():
            raise tvm.error.OpAttributeInvalid("The alpha value of a ELU cannot be None.")
        alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1.0
        alpha = _expr.const(alpha, dtype="float32")
        return _get_elu(inexpr, alpha)
    if act_type == "PReLU":
        assert hasattr(keras_layer, "alpha"), "alpha required for PReLU."
        _check_data_format(keras_layer)
        size = len(keras_layer.alpha.shape)
        if data_layout == "NCHW":
            # Move the last (channel) axis of the weights to the front.
            alpha = etab.new_const(keras_layer.get_weights()[0].transpose(np.roll(range(size), 1)))
        else:
            alpha = etab.new_const(keras_layer.get_weights()[0])
        # relu(x) - alpha * relu(-x) == x for x > 0, alpha * x otherwise.
        return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
    if act_type == "ThresholdedReLU":
        theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0
        return _op.multiply(
            inexpr, _op.greater(inexpr, _expr.const(theta, dtype="float32")).astype("float32")
        )
    raise tvm.error.OpNotImplemented(f"Operator {act_type} is not supported in frontend Keras.")
def _convert_merge(
    inexpr, keras_layer, _, input_shape=None, data_layout=None
):  # pylint: disable=unused-argument
    """Convert a Keras merge layer to Relay.

    Supported layer classes: ``Dot``, ``Subtract``, ``Add``, ``Multiply``,
    ``Minimum``, ``Maximum`` and ``Average``.

    Parameters
    ----------
    inexpr : list of relay.Expr
        The layer's inputs. For ``Dot``, entries may be replaced in place by
        transposed expressions (the list object itself is mutated).
    keras_layer : keras layer
        The merge layer; its class name selects the operation.
    _ : object
        Unused (the slot other converters use for the expression table).
    input_shape, data_layout : optional
        Unused by this converter.
        NOTE(review): these two are declared in the opposite order from most
        other converters in this file (``data_layout, input_shape``). That is
        harmless here because neither is read; confirm against the caller.

    Returns
    -------
    relay.Expr

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        For ``Dot`` axes that are not an int or a 2-element list drawn from {1, 2}.
    tvm.error.OpNotImplemented
        For unsupported merge layer types.
    """
    merge_type = type(keras_layer).__name__
    ret = inexpr[0]
    if merge_type == "Dot":
        axes = keras_layer.axes
        # A single int means "the same axis on both inputs".
        if isinstance(keras_layer.axes, int):
            axes = [keras_layer.axes, keras_layer.axes]
        # Only a list is accepted here; a tuple falls through to the error below.
        if isinstance(axes, list):
            if len(axes) != 2:
                raise tvm.error.OpAttributeUnImplemented(
                    f"Dot with axes {keras_layer.axes} is not supported."
                )
            for i, axis in enumerate(axes):
                if axis not in [1, 2]:
                    raise tvm.error.OpAttributeUnImplemented(
                        f"Dot with axes {keras_layer.axes} is not supported."
                    )
                # Swap the last two axes of this (rank-3) input; this rebinds
                # inexpr[i] in the list that was passed in.
                if axes[i] == 2:
                    inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
        else:
            raise tvm.error.OpAttributeUnImplemented(
                f"Dot with axes {keras_layer.axes} is not supported."
            )
        # NOTE(review): this transpose-then-batch_matmul-then-transpose scheme
        # never consults data_layout; confirm it reproduces Keras ``Dot`` for
        # the tensor layout actually used by the caller.
        ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
        ret = _op.transpose(ret_dot, axes=[0, 2, 1])
    elif merge_type == "Subtract":
        # Note: this assert is stripped when Python runs with -O.
        assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
        ret = _op.subtract(ret, inexpr[1])
    elif merge_type in ["Add", "Multiply", "Minimum", "Maximum"]:
        op_map = {
            "Add": _op.add,
            "Multiply": _op.multiply,
            "Minimum": _op.minimum,
            "Maximum": _op.maximum,
        }
        # Left fold of the binary op over all inputs.
        for i in range(1, len(inexpr)):
            ret = op_map[merge_type](ret, inexpr[i])
    elif merge_type == "Average":
        # Sum all inputs, then divide by the input count.
        for i in range(1, len(inexpr)):
            ret = _op.add(ret, inexpr[i])
        ret = ret / _expr.const(len(inexpr), dtype="float32")
    else:
        raise tvm.error.OpNotImplemented(
            f"Operator {merge_type} is not supported in frontend Keras."
        )
    return ret
def _convert_permute(
    inexpr, keras_layer, _, input_shape=None, data_layout=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``Permute`` layer to a Relay transpose.

    ``keras_layer.dims`` describes only the non-batch axes, so batch axis 0
    is prepended and stays in place.
    """
    batch_axis = (0,)
    permutation = batch_axis + keras_layer.dims
    return _op.transpose(inexpr, axes=permutation)
def _convert_embedding(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``Embedding`` layer into a Relay ``take`` lookup.

    The layer's first weight array becomes the lookup table. ``inexpr`` holds
    the indices, which are cast to int32 and gathered along axis 0.
    """
    table = etab.new_const(keras_layer.get_weights()[0])
    ids = inexpr.astype("int32")
    return _op.take(table, ids, axis=0)
def _convert_dense(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``Dense`` layer (plus its fused activation) to Relay.

    For inputs of rank > 2, only a ``(1, 1, n)`` shape is accepted, with
    ``None`` dims counted as 1 (the RNN case). The leading axis is squeezed
    before ``nn.dense`` and restored afterwards.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If a rank > 2 input is not of shape ``(1, 1, n)``.
    """
    weights = keras_layer.get_weights()
    kernel = weights[0]
    # Keras stores the kernel as (in, units); Relay dense wants (units, in).
    params = {"weight": etab.new_const(kernel.transpose([1, 0])), "units": kernel.shape[1]}
    if input_shape is None:
        input_shape = keras_layer.input_shape
    rank = len(input_shape)
    squeezed = rank > 2
    if squeezed:
        input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
        if rank != 3 or input_shape[0] != 1 or input_shape[1] != 1:
            raise tvm.error.OpAttributeInvalid(
                f"Input shape {input_shape} is not valid for operator Dense."
            )
        inexpr = _op.squeeze(inexpr, axis=[0])
    out = _op.nn.dense(data=inexpr, **params)
    if keras_layer.use_bias:
        out = _op.nn.bias_add(out, etab.new_const(weights[1]))
    # defuse activation
    if sys.version_info.major < 3:
        act_name = keras_layer.activation.func_name
    else:
        act_name = keras_layer.activation.__name__
    if act_name != "linear":
        out = _convert_activation(out, act_name, etab, data_layout)
    if squeezed:
        out = _op.expand_dims(out, axis=0)
    return out
def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=None):
    """Convert Keras ``Conv1D`` / ``Conv1DTranspose`` layers to Relay.

    Parameters
    ----------
    inexpr : relay.Expr
        Input expression.
    keras_layer : keras layer
        The 1-D convolution layer (channels-last data format required).
    etab : ExprTable
        Receives the kernel and bias constants.
    data_layout : str
        Must be ``"NWC"``. The Keras kernel is then used unchanged with kernel
        layout ``"WIO"`` (convolution) or ``"WOI"`` (transposed convolution).
    input_shape : tuple, optional
        Keras-side input shape, defaulting to ``keras_layer.input_shape``;
        ``input_shape[1]`` is the width used for ``padding="same"``.

    Returns
    -------
    relay.Expr

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        For a data layout other than ``"NWC"`` or a padding mode other than
        ``"valid"``/``"same"``.
    """
    is_deconv = type(keras_layer).__name__ == "Conv1DTranspose"
    if input_shape is None:
        input_shape = keras_layer.input_shape
    _check_data_format(keras_layer)
    weightList = keras_layer.get_weights()
    weight = weightList[0]
    if data_layout != "NWC":
        kernel_layout = "IOW" if is_deconv else "OIW"
        msg = (
            f"Kernel layout with {kernel_layout} is not supported for operator Convolution1D "
            f"in frontend Keras."
        )
        raise tvm.error.OpAttributeUnImplemented(msg)
    if is_deconv:
        kernel_layout = "WOI"
        kernel_w, n_filters, _ = weight.shape
    else:
        kernel_layout = "WIO"
        kernel_w, _, n_filters = weight.shape
    dilation_rate = keras_layer.dilation_rate
    if isinstance(dilation_rate, (list, tuple)):
        dilation = [dilation_rate[0]]
    else:
        dilation = [dilation_rate]
    dilated_kernel_w = (kernel_w - 1) * dilation[0] + 1
    stride_w = keras_layer.strides[0]
    params = {
        "weight": etab.new_const(weight),
        "kernel_size": [kernel_w],
        "strides": [stride_w],
        "dilation": dilation,
        "padding": [0],
        "data_layout": data_layout,
        "kernel_layout": kernel_layout,
    }
    params["channels"] = n_filters
    if is_deconv:
        # Forward output_padding as the 2-D and 3-D converters do; otherwise
        # the Relay output length differs from Keras when it is set.
        output_padding = getattr(keras_layer, "output_padding", None)
        if output_padding:
            params["output_padding"] = output_padding
    if keras_layer.padding == "valid":
        pass
    elif keras_layer.padding == "same":
        # calculate the padding values from the dilated kernel extent
        in_w = input_shape[1]
        pad_w = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
        params["padding"] = [pad_w[0], pad_w[1]]
    else:
        msg = (
            f"Padding with {keras_layer.padding} is not supported for operator Convolution1D "
            f"in frontend Keras."
        )
        raise tvm.error.OpAttributeUnImplemented(msg)
    if is_deconv:
        out = _op.nn.conv1d_transpose(data=inexpr, **params)
    else:
        out = _op.nn.conv1d(data=inexpr, **params)
    channel_axis = -1 if data_layout == "NWC" else 1
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias, channel_axis)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != "linear":
        out = _convert_activation(out, act_type, etab, data_layout)
    return out
def _convert_convolution(inexpr, keras_layer, etab, data_layout, input_shape=None):
    """Convert Keras ``Conv2D``, ``Conv2DTranspose`` and ``DepthwiseConv2D`` to Relay.

    Parameters
    ----------
    inexpr : relay.Expr
        Input expression in ``data_layout``.
    keras_layer : keras layer
        The convolution layer (channels-last data format required).
    etab : ExprTable
        Receives the kernel and bias constants.
    data_layout : str
        With ``"NHWC"`` the Keras kernel is used as-is (``HWIO``/``HWOI``).
        Otherwise an ``OIHW``/``IOHW`` kernel layout is used; for ``"NCHW"``
        the kernel is transposed to match.
    input_shape : tuple, optional
        Keras-side input shape, defaulting to ``keras_layer.input_shape``;
        indices 1 and 2 (H, W) are read for ``padding="same"``.

    Returns
    -------
    relay.Expr

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        For padding modes other than ``"valid"`` and ``"same"``.
    """
    _check_data_format(keras_layer)
    is_deconv = type(keras_layer).__name__ == "Conv2DTranspose"
    is_depthconv = type(keras_layer).__name__ == "DepthwiseConv2D"
    weightList = keras_layer.get_weights()
    weight = weightList[0]
    if input_shape is None:
        input_shape = keras_layer.input_shape
    if data_layout == "NHWC":
        kernel_layout = "HWOI" if (is_depthconv or is_deconv) else "HWIO"
    else:
        kernel_layout = "IOHW" if is_deconv else "OIHW"
    if is_deconv:
        kernel_h, kernel_w, n_filters, in_channels = weight.shape
        if kernel_layout == "IOHW":
            weight = weight.transpose([3, 2, 0, 1])
    elif is_depthconv:
        kernel_h, kernel_w, in_channels, depth_mult = weight.shape
        if kernel_layout == "OIHW":
            weight = weight.transpose([2, 3, 0, 1])
    elif data_layout == "NCHW":
        kernel_h, kernel_w, in_channels, n_filters = weight.shape
        weight = weight.transpose([3, 2, 0, 1])
    else:
        kernel_h, kernel_w, in_channels, n_filters = weight.shape
    if isinstance(keras_layer.dilation_rate, (list, tuple)):
        dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
    else:
        dilation = [keras_layer.dilation_rate, keras_layer.dilation_rate]
    dilated_kernel_h = (kernel_h - 1) * dilation[0] + 1
    dilated_kernel_w = (kernel_w - 1) * dilation[1] + 1
    stride_h, stride_w = keras_layer.strides
    params = {
        "weight": etab.new_const(weight),
        "kernel_size": [kernel_h, kernel_w],
        "strides": [stride_h, stride_w],
        "dilation": dilation,
        "padding": [0, 0],
        "data_layout": data_layout,
        "kernel_layout": kernel_layout,
    }
    if is_depthconv:
        params["channels"] = in_channels * depth_mult
        params["groups"] = in_channels
    else:
        params["channels"] = n_filters
        # Grouped Conv2D: the kernel's input-channel axis already holds
        # in_channels / groups, so only the group count must be forwarded.
        groups = 1 if is_deconv else getattr(keras_layer, "groups", 1)
        if groups > 1:
            params["groups"] = groups
    if is_deconv and keras_layer.output_padding:
        params["output_padding"] = keras_layer.output_padding
    if keras_layer.padding == "valid":
        pass
    elif keras_layer.padding == "same":
        # NOTE(review): this forward-convolution padding formula is also used
        # for Conv2DTranspose; for odd input sizes it may not yield Keras's
        # in * stride output size — verify before relying on it.
        in_h = input_shape[1]
        in_w = input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
        params["padding"] = (pad_t, pad_l, pad_b, pad_r)
    else:
        msg = (
            f"Padding with {keras_layer.padding} is not supported for operator Convolution "
            f"in frontend Keras."
        )
        raise tvm.error.OpAttributeUnImplemented(msg)
    if is_deconv:
        out = _op.nn.conv2d_transpose(data=inexpr, **params)
    else:
        out = _op.nn.conv2d(data=inexpr, **params)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        if data_layout == "NCHW":
            out = _op.nn.bias_add(out, bias)
        else:
            out = _op.nn.bias_add(out, bias, axis=-1)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != "linear":
        out = _convert_activation(out, act_type, etab, data_layout)
    return out
def _convert_convolution3d(inexpr, keras_layer, etab, data_layout, input_shape=None):
    """Convert Keras ``Conv3D`` / ``Conv3DTranspose`` layers to Relay.

    Parameters
    ----------
    inexpr : relay.Expr
        Input expression.
    keras_layer : keras layer
        The 3-D convolution layer (channels-last data format required).
    etab : ExprTable
        Receives the kernel and bias constants.
    data_layout : str
        Must be ``"NDHWC"``. The Keras kernel is used unchanged with kernel
        layout ``"DHWIO"`` (convolution) or ``"DHWOI"`` (transposed).
    input_shape : tuple, optional
        Keras-side input shape, defaulting to ``keras_layer.input_shape``;
        indices 1-3 (D, H, W) are read for ``padding="same"``.

    Returns
    -------
    relay.Expr

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        For a data layout other than ``"NDHWC"`` or a padding mode other than
        ``"valid"``/``"same"``.
    """
    _check_data_format(keras_layer)
    is_deconv = type(keras_layer).__name__ == "Conv3DTranspose"
    weightList = keras_layer.get_weights()
    weight = weightList[0]
    if input_shape is None:
        input_shape = keras_layer.input_shape
    if data_layout != "NDHWC":
        kernel_layout = "IODHW" if is_deconv else "OIDHW"
        msg = (
            f"Kernel layout with {kernel_layout} is not supported for operator Convolution3D "
            f"in frontend Keras."
        )
        raise tvm.error.OpAttributeUnImplemented(msg)
    if is_deconv:
        kernel_layout = "DHWOI"
        kernel_d, kernel_h, kernel_w, n_filters, _ = weight.shape
    else:
        kernel_layout = "DHWIO"
        kernel_d, kernel_h, kernel_w, _, n_filters = weight.shape
    dilation_rate = keras_layer.dilation_rate
    if isinstance(dilation_rate, (list, tuple)):
        dilation = [dilation_rate[0], dilation_rate[1], dilation_rate[2]]
    else:
        dilation = [dilation_rate, dilation_rate, dilation_rate]
    dilated_kernel_d = (kernel_d - 1) * dilation[0] + 1
    dilated_kernel_h = (kernel_h - 1) * dilation[1] + 1
    dilated_kernel_w = (kernel_w - 1) * dilation[2] + 1
    stride_d, stride_h, stride_w = keras_layer.strides
    params = {
        "weight": etab.new_const(weight),
        "kernel_size": [kernel_d, kernel_h, kernel_w],
        "strides": [stride_d, stride_h, stride_w],
        "dilation": dilation,
        "padding": [0, 0, 0],
        "data_layout": data_layout,
        "kernel_layout": kernel_layout,
    }
    params["channels"] = n_filters
    if is_deconv and keras_layer.output_padding:
        params["output_padding"] = keras_layer.output_padding
    if keras_layer.padding == "valid":
        pass
    elif keras_layer.padding == "same":
        # calculate the padding values; Relay expects (front..., back...) order
        in_d = input_shape[1]
        in_h = input_shape[2]
        in_w = input_shape[3]
        pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
        pad_h = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
        pad_w = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
        params["padding"] = [pad_d[0], pad_h[0], pad_w[0], pad_d[1], pad_h[1], pad_w[1]]
    else:
        msg = (
            f"Padding with {keras_layer.padding} is not supported for operator Convolution3D "
            f"in frontend Keras."
        )
        raise tvm.error.OpAttributeUnImplemented(msg)
    if is_deconv:
        out = _op.nn.conv3d_transpose(data=inexpr, **params)
    else:
        out = _op.nn.conv3d(data=inexpr, **params)
    channel_axis = -1 if data_layout == "NDHWC" else 1
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias, channel_axis)
    # defuse activation; pass the layout like the other converters do
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != "linear":
        out = _convert_activation(out, act_type, etab, data_layout)
    return out
def _convert_separable_convolution(inexpr, keras_layer, etab, data_layout, input_shape=None):
    """Convert a Keras ``SeparableConv2D`` layer to Relay.

    Emitted as a depthwise ``conv2d`` (``groups = in_channels``) followed by
    a 1x1 pointwise ``conv2d``, then the optional bias and activation.

    Parameters
    ----------
    inexpr : relay.Expr
        Input expression in ``data_layout``.
    keras_layer : keras layer
        The separable convolution layer (channels-last data format required).
    etab : ExprTable
        Receives the depthwise/pointwise kernels and bias.
    data_layout : str
        ``"NHWC"`` uses Keras kernels as-is; otherwise kernels are transposed
        to ``OIHW``.
    input_shape : tuple, optional
        Keras-side input shape, defaulting to ``keras_layer.input_shape``;
        indices 1 and 2 (H, W) are read for ``padding="same"``.

    Returns
    -------
    relay.Expr

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        For padding modes other than ``"valid"`` and ``"same"``.
    """
    _check_data_format(keras_layer)
    if data_layout == "NHWC":
        kernel_layout = "HWOI"
    else:
        kernel_layout = "OIHW"
    if input_shape is None:
        input_shape = keras_layer.input_shape
    weightList = keras_layer.get_weights()
    # depthwise conv
    kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
    stride_h, stride_w = keras_layer.strides
    if kernel_layout == "OIHW":
        weight0 = weightList[0].transpose([2, 3, 0, 1])
    else:
        weight0 = weightList[0]
    if isinstance(keras_layer.dilation_rate, (list, tuple)):
        dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
    else:
        dilation = [keras_layer.dilation_rate, keras_layer.dilation_rate]
    # "same" padding must use the dilated kernel extent (as _convert_convolution
    # does); the raw kernel size under-pads whenever dilation > 1.
    dilated_kernel_h = (kernel_h - 1) * dilation[0] + 1
    dilated_kernel_w = (kernel_w - 1) * dilation[1] + 1
    params0 = {
        "weight": etab.new_const(weight0),
        "channels": in_channels * depth_mult,
        "groups": in_channels,
        "kernel_size": [kernel_h, kernel_w],
        "strides": [stride_h, stride_w],
        "dilation": dilation,
        "padding": [0, 0],
        "data_layout": data_layout,
        "kernel_layout": kernel_layout,
    }
    if keras_layer.padding == "valid":
        pass
    elif keras_layer.padding == "same":
        in_h = input_shape[1]
        in_w = input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
        params0["padding"] = (pad_t, pad_l, pad_b, pad_r)
    else:
        msg = (
            f"Padding with {keras_layer.padding} is not supported for operator Separable "
            f"Convolution in frontend Keras."
        )
        raise tvm.error.OpAttributeUnImplemented(msg)
    depthconv = _op.nn.conv2d(data=inexpr, **params0)
    # pointwise conv
    if kernel_layout == "OIHW":
        weight1 = weightList[1].transpose([3, 2, 0, 1])
    else:
        weight1 = weightList[1]
        kernel_layout = "HWIO"
    params1 = {
        "weight": etab.new_const(weight1),
        "channels": weightList[1].shape[3],
        "groups": 1,
        "kernel_size": [1, 1],
        "strides": [1, 1],
        "dilation": [1, 1],
        "data_layout": data_layout,
        "kernel_layout": kernel_layout,
    }
    out = _op.nn.conv2d(data=depthconv, **params1)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[2])
        if data_layout == "NCHW":
            out = _op.nn.bias_add(out, bias)
        else:
            out = _op.nn.bias_add(out, bias, axis=-1)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != "linear":
        out = _convert_activation(out, act_type, etab, data_layout)
    return out
def _convert_flatten(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``Flatten`` layer to ``nn.batch_flatten``.

    For ``"NCHW"`` the (4-D) tensor is first transposed back to NHWC, so that
    a following Dense sees elements in Keras's channels-last order.
    """
    _check_data_format(keras_layer)
    to_flatten = _op.transpose(inexpr, axes=[0, 2, 3, 1]) if data_layout == "NCHW" else inexpr
    return _op.nn.batch_flatten(to_flatten)
def _convert_pooling(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert Keras 2-D pooling layers to Relay.

    Handles ``MaxPooling2D``, ``AveragePooling2D``, ``GlobalMaxPooling2D`` and
    ``GlobalAveragePooling2D``. Global pools are flattened to 2-D unless the
    layer keeps its dimensions (``keepdims``).

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        For padding modes other than ``"valid"`` and ``"same"``.
    tvm.error.OpNotImplemented
        For other pooling layer types.
    """
    _check_data_format(keras_layer)
    pool_type = type(keras_layer).__name__
    if input_shape is None:
        input_shape = keras_layer.input_shape
    if pool_type in ("GlobalMaxPooling2D", "GlobalAveragePooling2D"):
        # global pool in keras = global pool + flatten in relay (unless keepdims)
        global_pool_params = {"layout": data_layout}
        if pool_type == "GlobalMaxPooling2D":
            pooled = _op.nn.global_max_pool2d(inexpr, **global_pool_params)
        else:
            pooled = _op.nn.global_avg_pool2d(inexpr, **global_pool_params)
        # Max pooling previously always flattened, ignoring keepdims=True.
        keep_dims = getattr(keras_layer, "keepdims", None)
        if keep_dims is None:
            keep_dims = len(keras_layer.input.shape) == len(keras_layer.output.shape)
        if keep_dims:
            return pooled
        return _convert_flatten(pooled, keras_layer, etab, data_layout)
    pool_h, pool_w = keras_layer.pool_size
    stride_h, stride_w = keras_layer.strides
    params = {
        "pool_size": [pool_h, pool_w],
        "strides": [stride_h, stride_w],
        "padding": [0, 0],
        "layout": data_layout,
    }
    if keras_layer.padding == "valid":
        pass
    elif keras_layer.padding == "same":
        in_h = input_shape[1]
        in_w = input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
        params["padding"] = [pad_t, pad_l, pad_b, pad_r]
    else:
        raise tvm.error.OpAttributeUnImplemented(
            f"Padding with {keras_layer.padding} is not supported in operator Pooling."
        )
    if pool_type == "MaxPooling2D":
        return _op.nn.max_pool2d(inexpr, **params)
    if pool_type == "AveragePooling2D":
        # Padded positions are excluded from the average.
        params["count_include_pad"] = False
        return _op.nn.avg_pool2d(inexpr, **params)
    raise tvm.error.OpNotImplemented(f"Operator {keras_layer} is not supported for frontend Keras.")
def _convert_pooling3d(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert Keras ``MaxPooling3D`` / ``AveragePooling3D`` layers to Relay.

    The input is transposed with axes (0, 4, 1, 2, 3), pooled with Relay
    layout ``"NCDHW"``, and transposed back with axes (0, 2, 3, 4, 1).

    NOTE(review): that round-trip treats ``inexpr`` as NDHWC whatever
    ``data_layout`` says (the ``"layout"`` entry initialised from
    ``data_layout`` is overwritten below). Confirm callers never pass a
    channels-first tensor here.

    Raises
    ------
    tvm.error.OpNotImplemented
        For pooling types other than the two above.
    tvm.error.OpAttributeUnImplemented
        For padding modes other than ``"valid"`` and ``"same"``.
    """
    _check_data_format(keras_layer)
    pool_type = type(keras_layer).__name__
    if input_shape is None:
        input_shape = keras_layer.input_shape
    if pool_type not in ["MaxPooling3D", "AveragePooling3D"]:
        raise tvm.error.OpNotImplemented(
            f"Operator {keras_layer} is not supported for frontend Keras."
        )
    pool_d1, pool_d2, pool_d3 = keras_layer.pool_size
    stride_d1, stride_d2, stride_d3 = keras_layer.strides
    params = {
        "pool_size": [pool_d1, pool_d2, pool_d3],
        "strides": [stride_d1, stride_d2, stride_d3],
        "padding": [0, 0, 0],
        "layout": data_layout,
    }
    if keras_layer.padding == "valid":
        pass
    elif keras_layer.padding == "same":
        # Spatial extents come from the Keras-side (channels-last) input_shape.
        in_d1 = input_shape[1]
        in_d2 = input_shape[2]
        in_d3 = input_shape[3]
        pad_d1 = _get_pad_pair(in_d1, pool_d1, stride_d1)
        pad_d2 = _get_pad_pair(in_d2, pool_d2, stride_d2)
        pad_d3 = _get_pad_pair(in_d3, pool_d3, stride_d3)
        # Relay order: all "before" pads, then all "after" pads.
        params["padding"] = [pad_d1[0], pad_d2[0], pad_d3[0], pad_d1[1], pad_d2[1], pad_d3[1]]
    else:
        raise tvm.error.OpAttributeUnImplemented(
            f"Padding with {keras_layer.padding} is not supported in operator Pooling3D."
        )
    # Move channels to axis 1, pool in NCDHW, then move channels back last.
    out = _op.transpose(inexpr, axes=(0, 4, 1, 2, 3))
    params["layout"] = "NCDHW"
    if pool_type == "MaxPooling3D":
        out = _op.nn.max_pool3d(out, **params)
    elif pool_type == "AveragePooling3D":
        out = _op.nn.avg_pool3d(out, **params)
    return _op.transpose(out, axes=(0, 2, 3, 4, 1))
def _convert_global_pooling3d(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert ``GlobalMaxPooling3D`` / ``GlobalAveragePooling3D`` to Relay.

    The pooled tensor is passed through ``_convert_flatten`` to collapse it
    to 2-D with ``batch_flatten``.

    Raises
    ------
    tvm.error.OpNotImplemented
        For other global pooling types.
    """
    _check_data_format(keras_layer)
    pool_type = type(keras_layer).__name__
    global_pool_params = {"layout": data_layout}
    if pool_type == "GlobalMaxPooling3D":
        out = _op.nn.global_max_pool3d(inexpr, **global_pool_params)
    elif pool_type == "GlobalAveragePooling3D":
        out = _op.nn.global_avg_pool3d(inexpr, **global_pool_params)
    else:
        raise tvm.error.OpNotImplemented(
            f"Operator {keras_layer} is not supported for frontend Keras."
        )
    # Arguments follow _convert_flatten's signature order
    # (inexpr, keras_layer, etab, data_layout, input_shape). _convert_flatten
    # only transposes for data_layout == "NCHW", which a 3-D layout string
    # never matches, so the pooled tensor is flattened as-is.
    return _convert_flatten(out, keras_layer, etab, data_layout, input_shape)
def _convert_upsample(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert Keras ``UpSampling1D`` / ``UpSampling2D`` to ``nn.upsampling``.

    For ``UpSampling2D``, an ``interpolation`` of ``"nearest"`` maps to
    ``"nearest_neighbor"``; any other value maps to ``"bilinear"``.

    Raises
    ------
    tvm.error.OpNotImplemented
        For other upsampling layer types.
    """
    _check_data_format(keras_layer)
    upsample_type = type(keras_layer).__name__
    if upsample_type == "UpSampling1D":
        params = {"scale_h": keras_layer.size}
    elif upsample_type == "UpSampling2D":
        scale_h, scale_w = keras_layer.size
        params = {"scale_h": scale_h, "scale_w": scale_w}
        if hasattr(keras_layer, "interpolation"):
            is_nearest = keras_layer.interpolation == "nearest"
            params["method"] = "nearest_neighbor" if is_nearest else "bilinear"
    else:
        raise tvm.error.OpNotImplemented(
            f"Operator {upsample_type} is not supported for frontend Keras."
        )
    params["layout"] = data_layout
    return _op.nn.upsampling(inexpr, **params)
def _convert_upsample3d(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``UpSampling3D`` layer to ``nn.upsampling3d``.

    The (depth, height, width) factors come from ``keras_layer.size``, and the
    coordinate transformation mode is fixed to ``"asymmetric"``.
    """
    _check_data_format(keras_layer)
    scale_d, scale_h, scale_w = keras_layer.size
    return _op.nn.upsampling3d(
        inexpr,
        scale_d=scale_d,
        scale_h=scale_h,
        scale_w=scale_w,
        layout=data_layout,
        coordinate_transformation_mode="asymmetric",
    )
def _convert_cropping(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``Cropping2D`` layer to a Relay ``strided_slice``.

    ``input_shape`` is read as Keras-side (N, H, W, C). Axes that are not
    cropped use an int32-max end so they are kept whole.

    Raises
    ------
    tvm.error.OpNotImplemented
        For cropping layers other than ``Cropping2D``.
    """
    _check_data_format(keras_layer)
    crop_type = type(keras_layer).__name__
    if input_shape is None:
        input_shape = keras_layer.input_shape
    if crop_type != "Cropping2D":
        raise tvm.error.OpNotImplemented(
            f"Operator {crop_type} is not supported for frontend Keras."
        )
    _, in_h, in_w, _ = input_shape
    (crop_t, crop_b), (crop_l, crop_r) = keras_layer.cropping
    keep_all = np.iinfo(np.int32).max
    stop_h = in_h - crop_b
    stop_w = in_w - crop_r
    if data_layout == "NHWC":
        begin, end = [0, crop_t, crop_l, 0], [keep_all, stop_h, stop_w, keep_all]
    else:
        begin, end = [0, 0, crop_t, crop_l], [keep_all, keep_all, stop_h, stop_w]
    return _op.strided_slice(inexpr, begin=begin, end=end)
def _convert_batchnorm(inexpr, keras_layer, etab, data_layout, input_shape=None):
    """Convert a Keras ``BatchNormalization`` layer to Relay ``nn.batch_norm``.

    Uses the layer's moving mean/variance. Missing ``gamma`` (``scale=False``)
    or ``beta`` (``center=False``) are replaced by ones / zeros.

    Parameters
    ----------
    inexpr : relay.Expr
        Input expression.
    keras_layer : keras layer
        The batch normalization layer.
    etab : ExprTable
        Receives gamma, beta, moving mean and moving variance.
    data_layout : str
        For ``"NCHW"``/``"NCDHW"`` the channel axis is 1. For other layouts
        the Relay tensor shares the Keras layout, so the layer's own ``axis``
        (default -1) is used, normalised to a non-negative index.
    input_shape : tuple, optional
        Keras-side input shape (with batch); defaults to
        ``keras_layer.input_shape``.

    Returns
    -------
    relay.Expr
        The normalized output only.

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        If the layer normalizes over more than one axis.
    """
    if input_shape is None:
        input_shape = keras_layer.input_shape
    if data_layout in ("NCHW", "NCDHW"):
        axis = 1
    else:
        # Previously every channels-last input of rank != 4 used axis 1, which
        # is wrong for e.g. (N, T, C) or NDHWC inputs.
        keras_axis = getattr(keras_layer, "axis", -1)
        if isinstance(keras_axis, (list, tuple)):
            if len(keras_axis) != 1:
                raise tvm.error.OpAttributeUnImplemented(
                    f"BatchNormalization over axes {keras_axis} is not supported."
                )
            keras_axis = keras_axis[0]
        axis = int(keras_axis) % len(input_shape)
    weights = keras_layer.get_weights()
    params = {"scale": False, "center": False, "epsilon": keras_layer.epsilon, "axis": axis}
    # Weights are ordered [gamma?, beta?, moving_mean, moving_var].
    idx = 0
    if keras_layer.scale:
        params["scale"] = True
        params["gamma"] = etab.new_const(weights[idx])
        idx += 1
    if keras_layer.center:
        params["center"] = True
        params["beta"] = etab.new_const(weights[idx])
        idx += 1
    moving_mean = weights[idx]
    moving_var = weights[idx + 1]
    params["moving_mean"] = etab.new_const(moving_mean)
    params["moving_var"] = etab.new_const(moving_var)
    # in case beta or gamma is not defined
    if "beta" not in params:
        params["beta"] = etab.new_const(np.zeros(moving_mean.shape))
    if "gamma" not in params:
        params["gamma"] = etab.new_const(np.ones(moving_mean.shape))
    result, _, _ = _op.nn.batch_norm(inexpr, **params)
    return result
def _convert_padding(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``ZeroPadding2D`` layer to Relay ``nn.pad``.

    ``padding`` may be an int (same on all sides), a tuple of two ints
    (height, width), or a tuple of two tuples ((top, bottom), (left, right)).

    Raises
    ------
    tvm.error.OpNotImplemented
        For padding layers other than ``ZeroPadding2D``.
    tvm.error.OpAttributeInvalid
        For any other form of ``padding``.
    """
    _check_data_format(keras_layer)
    padding_type = type(keras_layer).__name__
    padding = keras_layer.padding
    if padding_type != "ZeroPadding2D":
        msg = f"Operator {padding_type} is not supported in frontend Keras."
        raise tvm.error.OpNotImplemented(msg)
    if isinstance(padding, int):
        top = bottom = left = right = padding
    elif isinstance(padding, tuple) and isinstance(padding[0], int):
        # Symmetric (height, width) padding.
        top, left = padding
        bottom, right = padding
    elif isinstance(padding, tuple) and isinstance(padding[0], tuple):
        top, bottom = padding[0]
        left, right = padding[1]
    else:
        msg = f'Value {str(padding)} in attribute "padding" of operator Padding is not valid.'
        raise tvm.error.OpAttributeInvalid(msg)
    spatial = ((top, bottom), (left, right))
    if data_layout == "NCHW":
        pad_width = ((0, 0), (0, 0)) + spatial
    else:
        pad_width = ((0, 0),) + spatial + ((0, 0),)
    return _op.nn.pad(data=inexpr, pad_width=pad_width)
def _convert_padding3d(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``ZeroPadding3D`` layer to Relay ``nn.pad``.

    Only the normalized form ``((d0, d1), (h0, h1), (w0, w1))`` is accepted,
    i.e. a tuple whose first element is itself a tuple.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If ``padding`` is not in that form.
    """
    _check_data_format(keras_layer)
    padding = keras_layer.padding
    if not (isinstance(padding, tuple) and isinstance(padding[0], tuple)):
        msg = f'Value {str(padding)} in attribute "padding" of operator ZeroPadding3D is not valid.'
        raise tvm.error.OpAttributeInvalid(msg)
    spatial = tuple((axis_pad[0], axis_pad[1]) for axis_pad in (padding[0], padding[1], padding[2]))
    untouched = (0, 0)
    if data_layout == "NCDHW":
        pad_width = (untouched, untouched) + spatial
    else:
        pad_width = (untouched,) + spatial + (untouched,)
    return _op.nn.pad(data=inexpr, pad_width=pad_width)
def _convert_concat(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``Concatenate`` layer to Relay ``concatenate``.

    ``keras_layer.axis`` refers to the channels-last Keras tensor. For
    ``"NCHW"`` it is remapped to the channels-first position: the last axis
    becomes 1, axes 1..dims-2 shift right by one, and batch axis 0 stays 0.

    Parameters
    ----------
    inexpr : relay.Expr or list of relay.Expr
        Inputs to concatenate.
    input_shape : list of tuple, optional
        Keras-side input shapes; the rank is taken from the first one.
    """
    _check_data_format(keras_layer)
    if input_shape is None:
        input_shape = keras_layer.input_shape
    axis = keras_layer.axis
    dims = len(input_shape[0])
    if data_layout == "NCHW":  # need_transpose
        # Normalise first so explicit and negative axes map the same way.
        # The old `axis < dims` test sent the last axis (e.g. 3 for 4-D) to
        # the out-of-range axis `dims`.
        if axis < 0:
            axis += dims
        if axis == dims - 1:
            axis = 1
        elif axis > 0:
            axis += 1
    return _op.concatenate(_as_list(inexpr), axis=axis)
def _convert_reshape(
    inexpr, keras_layer, etab, data_layout, input_shape=None
):  # pylint: disable=unused-argument
    """Convert a Keras ``Reshape`` layer to Relay.

    ``target_shape`` excludes the batch axis, so ``-1`` is prepended. For
    ``"NCHW"`` with a rank > 3 input or a target of rank > 2, the tensor is
    transposed to channels-last, reshaped there (matching Keras element
    order), and its last axis is then moved back to position 1.
    """
    _check_data_format(keras_layer)
    if input_shape is None:
        input_shape = keras_layer.input_shape
    target = keras_layer.target_shape  # no batch
    new_shape = (-1,) + target
    if not (data_layout == "NCHW" and (len(input_shape) > 3 or len(target) > 2)):
        return _op.reshape(inexpr, newshape=new_shape)
    # Perform reshape in original NHWC format.
    to_channels_last = [0, *range(2, len(input_shape)), 1]
    reshaped = _op.reshape(_op.transpose(inexpr, to_channels_last), newshape=new_shape)
    to_channels_first = [0, -1, *range(1, len(new_shape) - 1)]
    return _op.transpose(reshaped, axes=to_channels_first)
def _convert_lstm(
inexpr, keras_layer, etab, data_layout, input_shape=None
): # pylint: disable=unused-argument
_check_data_format(keras_layer)
if input_shape is None:
input_shape = keras_layer.input_shape
if not isinstance(inexpr, list):
buf = np.zeros((1, keras_layer.units), "float32")
c_op = etab.new_const(buf)
h_op = etab.new_const(buf)
inexpr = [inexpr, h_op, c_op]
in_data = inexpr[0]
next_h = inexpr[1]