diff --git a/src/relay/op/dyn/nn/upsampling.cc b/src/relay/op/dyn/nn/upsampling.cc
index 9ed3298142af..8a28475eacd5 100644
--- a/src/relay/op/dyn/nn/upsampling.cc
+++ b/src/relay/op/dyn/nn/upsampling.cc
@@ -22,13 +22,14 @@
  * \brief upsampling operator
  */
 
-#include "../../nn/upsampling.h"
+#include "upsampling.h"
 
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/tir/data_layout.h>
 
+#include <utility>
 #include <vector>
 
 #include "../../op_common.h"
@@ -48,7 +49,6 @@ bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (scale_h == nullptr) return false;
   if (scale_w == nullptr) return false;
 
-  CHECK_EQ(data->shape.size(), 4);
   CHECK_EQ(scale_h->shape.size(), 0);
   CHECK_EQ(scale_w->shape.size(), 0);
   static const Layout kNCHW("NCHW");
diff --git a/src/relay/op/dyn/nn/upsampling.h b/src/relay/op/dyn/nn/upsampling.h
new file mode 100644
index 000000000000..79ed65bba36b
--- /dev/null
+++ b/src/relay/op/dyn/nn/upsampling.h
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file src/relay/op/dyn/nn/upsampling.h
+ * \brief implementation of the InferCorrectLayout pass for dynamic upsampling
+ */
+
+#ifndef TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_
+#define TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../../op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+template <typename T>
+Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
+                                                   const Array<Layout>& new_in_layouts,
+                                                   const Array<Layout>& old_in_layouts,
+                                                   const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard "const" qualifier here.
+  T* params = const_cast<T*>(attrs.as<T>());
+  if (new_in_layouts.defined()) {
+    CHECK_GT(new_in_layouts.size(), 0);
+
+    Layout raw_layout(params->layout);
+    Layout input = new_in_layouts[0];
+    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
+        input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
+        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
+        (input.IndexOf(LayoutAxis::Get('D')) == -1 ||
+         (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
+          !input.Contains(LayoutAxis::Get('d'))))) {
+      params->layout = input.name();  // modify self to follow the input layout
+    }
+  }
+
+  Layout inferred_layout(params->layout);
+  Layout param_layout("NCHW");
+  return Array<Array<Layout> >{{inferred_layout, param_layout, param_layout}, {inferred_layout}};
+}
+
+}  // namespace dyn
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index a7ae9f77fcb7..58c279d750ec 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -673,6 +673,45 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_alter_layout_nchw_dyn_upsampling_op():
+    """Test upsampling operators"""
+
+    def before():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var("weight", shape=(32, 32, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.nn.upsampling(y, scale_h=relay.const(2), scale_w=relay.const(2))
+        y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW16c"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var("weight")
+        x = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(
+            x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        y = relay.nn.upsampling(y, scale_h=relay.const(2), scale_w=relay.const(2), layout="NCHW16c")
+        y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout="NCHW16c")
+        y = relay.layout_transform(y, "NCHW16c", "NCHW")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = before()
+        a = run_opt_pass(a, transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 @tvm.testing.uses_gpu
 def test_alter_layout_strided_slice():
     """Test rewriting strided_slice during alter_iop_layout"""
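
Note for reviewers: the hunks above only move `UpsamplingInferCorrectLayout` into a header and relax the rank check; they do not show where the template is attached to the operator. As a rough sketch (assuming the existing `RELAY_REGISTER_OP("dyn.nn.upsampling")` block in `upsampling.cc` and the `UpSamplingAttrs` attribute struct), the hook is registered through the `FInferCorrectLayout` op attribute roughly like this:

```cpp
// Sketch only: excerpt of the op registration in src/relay/op/dyn/nn/upsampling.cc.
// Instantiating the header's template with UpSamplingAttrs lets AlterOpLayout /
// ConvertLayout query the op for the layouts it accepts and produces.
RELAY_REGISTER_OP("dyn.nn.upsampling")
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   UpsamplingInferCorrectLayout<UpSamplingAttrs>);
```

With this hook in place, the new test's `NCHW16c` rewrite works end to end: when `alter_conv2d` switches the producing conv2d to `NCHW16c`, the pass calls `UpsamplingInferCorrectLayout`, which sees that `H` and `W` are still unsplit in the incoming layout and adopts it (`params->layout = input.name()`), so no extra `layout_transform` is inserted between the conv2d and the upsampling.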