diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 64d1823f0ce3..f448dfa6a0b5 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -324,7 +324,7 @@ def __init__(self, prefix=None, params=None): self._reg_params = {} self._cached_graph = () self._cached_op = None - self._cached_params = () + self._cached_call_params = () self._out_format = None self._in_format = None self._active = False @@ -382,18 +382,16 @@ def infer_shape(self, *args): def _build_cache(self, *args): inputs, out = self._get_graph(*args) + name2pos = {var.name: i for i, var in enumerate(inputs)} self._cached_op = ndarray.CachedOp(out) # Must not change the following dict params = self.collect_params()._params - self._cached_params = tuple(params.get(name, None) for name in out.list_inputs()) + self._cached_call_params = tuple(params[name] if name in params else name2pos[name] + for name in out.list_inputs()) assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \ "Wrong number of inputs." - name2pos = {var.name: i for i, var in enumerate(inputs)} - self._in_idx = tuple(name2pos[name] if name not in params else -1 - for name in out.list_inputs()) - def _call_cached_op(self, *args): if self._cached_op is None: self._build_cache(*args) @@ -402,15 +400,15 @@ def _call_cached_op(self, *args): assert fmt == self._in_format, "Invalid input format" try: - cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v - in enumerate(p.data() if p else None for p in self._cached_params)) + cargs = tuple(p.data() if isinstance(p, Parameter) else args[p] + for p in self._cached_call_params) except DeferredInitializationError: self.infer_shape(*args) - for i in self._cached_params: - if i is not None: + for i in self._cached_call_params: + if isinstance(i, Parameter): i._finish_deferred_init() - cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v - in enumerate(p.data() if p else None for p in self._cached_params)) + cargs = tuple(p.data() if isinstance(p, Parameter) else args[p] + for p in self._cached_call_params) out = self._cached_op(*cargs) if isinstance(out, NDArray): out = [out]