Skip to content

Commit

Permalink
register Realize Rewrite for global avg pool and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 11, 2019
1 parent db61cac commit c48f82b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/relay/pass/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);

RELAY_REGISTER_OP("nn.global_avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);

Expr CastHintRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relay/test_pass_auto_quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import testing


def test_mul_rewrite():
data = relay.var("data", shape=(1, 16, 64, 64))
conv = relay.nn.conv2d(data, relay.var("weight"),
kernel_size=(3, 3),
padding=(1, 1),
channels=16)
act = relay.nn.relu(data=conv)
pool = relay.nn.global_avg_pool2d(data=act)
f = relay.Function(relay.analysis.free_vars(act), act * pool)
mod, params = testing.create_workload(f)

with relay.quantize.qconfig(skip_conv_layers=[]):
qmod = relay.quantize.quantize(mod, params)

relay.build(qmod, "llvm", params=params)


if __name__ == "__main__":
test_mul_rewrite()

0 comments on commit c48f82b

Please sign in to comment.