
Fix legacy codepath detection feature for decorated HybridBlocks (#19143)

The `if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward` check for detecting legacy Blocks yields false positives on Gluon 2 Blocks that are wrapped with a class decorator: the decorator replaces the method with a wrapper function, so the identity comparison no longer sees `HybridBlock.hybrid_forward`. This causes hybridization to silently fail on Gluon 2 Blocks that use a class decorator such as `@use_np`. The fix unwraps decorated methods with `inspect.unwrap` before comparing.
leezu authored Sep 15, 2020
1 parent 2697573 commit 179262b
Showing 2 changed files with 28 additions and 20 deletions.
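
To see why the old check misfires: a class decorator that rewraps methods leaves a new function object on the class, so the identity comparison against `HybridBlock.hybrid_forward` no longer matches even when the user never overrode it. The sketch below is a minimal, self-contained illustration; `use_np_like` is a hypothetical stand-in for a method-rewrapping decorator in the spirit of `mx.util.use_np`, not MXNet's actual implementation:

```python
import functools
import inspect


class HybridBlock:
    def hybrid_forward(self, F, x):
        # Gluon 1 entry point; Gluon 2 Blocks do not override it
        raise NotImplementedError


def use_np_like(cls):
    # Hypothetical stand-in for a method-rewrapping class decorator such
    # as mx.util.use_np: it replaces every public method (inherited ones
    # included) with a new wrapper function set on the subclass.
    for name in dir(cls):
        if name.startswith('__'):
            continue
        fn = getattr(cls, name)
        if not callable(fn):
            continue

        @functools.wraps(fn)  # records the original in wrapper.__wrapped__
        def wrapper(*args, _fn=fn, **kwargs):
            return _fn(*args, **kwargs)

        setattr(cls, name, wrapper)
    return cls


@use_np_like
class MyGluon2Block(HybridBlock):
    pass  # intentionally does not override hybrid_forward


net = MyGluon2Block()
# Old check: the identity test sees the wrapper and wrongly concludes
# that hybrid_forward was overridden, i.e. that this is a legacy Block.
assert net.hybrid_forward.__func__ is not HybridBlock.hybrid_forward
# Fixed check: inspect.unwrap follows the __wrapped__ chain back to the
# original, so the Block is correctly detected as Gluon 2.
assert inspect.unwrap(net.hybrid_forward.__func__) is HybridBlock.hybrid_forward
```

Because `functools.wraps` records the original function in `__wrapped__`, `inspect.unwrap` can follow the chain back to it, which is exactly what the patched check below does.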
7 changes: 4 additions & 3 deletions python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']

 import copy
+import inspect
 import warnings
 import weakref
 from collections import OrderedDict, defaultdict
@@ -984,7 +985,7 @@ def _get_graph_v2(self, *args):

     def _get_graph(self, *args):
         if not self._cached_graph:
-            if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward: # Gluon 1
+            if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
                 return self._get_graph_v1(*args)
             else: # Gluon 2 based on deferred compute mode
                 return self._get_graph_v2(*args)
@@ -1277,7 +1278,7 @@ def _infer_attrs(self, infer_fn, attr, *args):

     def infer_shape(self, *args):
         """Infers shape of Parameters from inputs."""
-        if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
+        if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
             # Gluon 1 based on F: hybrid_forward is defined by user
             self._infer_attrs('infer_shape', 'shape', *args)
         else:
@@ -1388,7 +1389,7 @@ def c_callback(name, op_name, array):
         cld()._monitor_all = monitor_all

     def __call__(self, x, *args):
-        if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
+        if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
             # Gluon 1 based on F: hybrid_forward is defined by user
             return super().__call__(x, *args)
         else: # Gluon 2 based on deferred compute mode
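The same unwrapped comparison now guards all three call sites (`_get_graph`, `infer_shape`, and `__call__`). Below is a quick sketch of the `inspect.unwrap` semantics the fix relies on, under the assumption that decorators cooperate via `functools.wraps` (or otherwise set `__wrapped__`):

```python
import functools
import inspect


def base():
    return 42


def transparent(fn):
    @functools.wraps(fn)  # sets inner.__wrapped__ = fn
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner


def opaque(fn):
    def inner(*args, **kwargs):  # no functools.wraps, no __wrapped__
        return fn(*args, **kwargs)
    return inner


# unwrap follows the __wrapped__ chain, so even stacked decorators
# resolve back to the original function object
assert inspect.unwrap(transparent(transparent(base))) is base

# without __wrapped__ there is nothing to follow: unwrap returns the
# wrapper itself, so an opaque decorator would still defeat the check
wrapped = opaque(base)
assert inspect.unwrap(wrapped) is wrapped
```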
41 changes: 24 additions & 17 deletions tests/python/unittest/test_deferred_compute.py
@@ -17,6 +17,7 @@

 import functools
 import operator
+import tempfile

 import numpy as np

@@ -306,8 +307,7 @@ def test_dc_dynamic_shape():
     def f(a, *, nd):
         return [mx.nd.np.flatnonzero(a)]

-    # Skip GraphExecutor test due to https://github.com/apache/incubator-mxnet/issues/17810
-    for mode in ('imperative', 'imperativewithnondccompute'):
+    for mode in ('imperative', 'imperativewithnondccompute', 'symbolic', 'all'):
         _assert_dc(_dc_simple_setup, f, mode=mode, numpy=True)


@@ -338,10 +338,6 @@ def f(a, *, nd):


 def test_dc_simple_boolean_indexing():
-    if mx.test_utils.default_context() == mx.gpu(0) and mx.runtime.Features().is_enabled("TVM_OP"):
-        # Skip due to https://github.com/apache/incubator-mxnet/issues/17886
-        return
-
     def setup(*, nd):
         assert nd is mx.np
         x = mx.np.array([[0, 1], [1, 1], [2, 2]])
@@ -351,10 +347,6 @@ def f(a, idx, *, nd):
         assert nd is mx.np
         return [a[idx].reshape((2, 2))]

-    # Skip GraphExecutor test due to https://github.com/apache/incubator-mxnet/issues/17810
-    for mode in ('imperative', 'imperativewithnondccompute'):
-        _assert_dc(setup, f, mode=mode)
-

 def test_dc_list_indexing_error():
     def f(a, *, nd):
@@ -428,6 +420,8 @@ def _assert_dc_gluon(setup, net, setup_is_deterministic=True, numpy=True, autogr

     _all_same(ys_np, ys_hybrid_np)

+    with tempfile.TemporaryDirectory() as root:
+        net.export(root)

 def _dc_gluon_simple_setup(shape=(8, 10), *, nd):
     return [nd.ones(shape=shape, ctx=mx.context.current_context())]
@@ -452,12 +446,29 @@ def forward(self, x):
     net = MyBlock()
     net.initialize(ctx=contexts)
     _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=False, ctx=ctx)
-    with mx.util.np_array(True):
+    with mx.util.np_shape(True), mx.util.np_array(True):
         net = MyBlock()
         net.initialize(ctx=contexts)
         _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True, ctx=ctx)


+def test_dc_hybridblock_wrapped():
+    @mx.util.use_np
+    class MyBlock(mx.gluon.HybridBlock):
+        def __init__(self):
+            super().__init__()
+            self.dense = mx.gluon.nn.Dense(units=10, in_units=10)
+            self.weight = mx.gluon.Parameter('weight', shape=(10, ))
+
+        def forward(self, x):
+            assert x.shape[1] == 10  # due to in_units=10 above
+            return self.dense(x) + self.weight.data(x.context)
+
+    net = MyBlock()
+    net.initialize()
+    _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True)
+
+
 def test_dc_hybridblock_deferred_init_no_infer_shape_error():
     class MyBlock(mx.gluon.HybridBlock):
         def __init__(self):
@@ -491,17 +502,13 @@ def forward(self, x):
     net = MyBlock()
     net.initialize()
     _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=False)
-    with mx.util.np_array(True):
+    with mx.util.np_shape(True), mx.util.np_array(True):
         net = MyBlock()
         net.initialize()
         _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True)


 def test_dc_hybridblock_dynamic_shape():
-    if mx.test_utils.default_context() == mx.gpu(0) and mx.runtime.Features().is_enabled("TVM_OP"):
-        # Skip due to https://github.com/apache/incubator-mxnet/issues/17886
-        return
-
     class MyBlock(mx.gluon.HybridBlock):
         def __init__(self):
             super().__init__()
@@ -515,7 +522,7 @@ def setup(*, nd):
         x = mx.np.array([[0, 1], [1, 1], [2, 2]])
         return [x, x < 2]

-    with mx.util.np_array(True):
+    with mx.util.np_shape(True), mx.util.np_array(True):
         net = MyBlock()
         net.initialize()
         _assert_dc_gluon(setup, net, numpy=True)
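The `_assert_dc_gluon` change above, exporting every net into a `tempfile.TemporaryDirectory`, turns a silent hybridization failure into a hard test failure, since `export` requires a cached symbolic graph. Below is a usage sketch of that check in isolation, written against the stable Gluon API; the `root + '/net'` prefix is illustrative:

```python
import tempfile

import mxnet as mx

net = mx.gluon.nn.Dense(10)
net.initialize()
net.hybridize()
net(mx.nd.ones((8, 10)))  # first forward pass records the graph

# export() writes <prefix>-symbol.json and <prefix>-0000.params; it can
# only succeed if hybridization really captured a symbolic graph, so a
# silent fallback to imperative mode shows up as an export failure here.
with tempfile.TemporaryDirectory() as root:
    net.export(root + '/net')
```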
