diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index def5d145f80e..64d1823f0ce3 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -324,7 +324,7 @@ def __init__(self, prefix=None, params=None): self._reg_params = {} self._cached_graph = () self._cached_op = None - self._cached_params = None + self._cached_params = () self._out_format = None self._in_format = None self._active = False @@ -384,32 +384,33 @@ def _build_cache(self, *args): inputs, out = self._get_graph(*args) self._cached_op = ndarray.CachedOp(out) - params = dict(self.collect_params().items()) - self._cached_params = [params.get(name, None) for name in out.list_inputs()] + # Must not change the following dict + params = self.collect_params()._params + self._cached_params = tuple(params.get(name, None) for name in out.list_inputs()) assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \ "Wrong number of inputs." name2pos = {var.name: i for i, var in enumerate(inputs)} - self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs()) - if name not in params] + self._in_idx = tuple(name2pos[name] if name not in params else -1 + for name in out.list_inputs()) def _call_cached_op(self, *args): if self._cached_op is None: self._build_cache(*args) + args, fmt = _flatten(args) + assert fmt == self._in_format, "Invalid input format" + try: - cargs = [i.data() if i else None for i in self._cached_params] + cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v + in enumerate(p.data() if p else None for p in self._cached_params)) except DeferredInitializationError: self.infer_shape(*args) for i in self._cached_params: if i is not None: i._finish_deferred_init() - cargs = [i.data() if i else None for i in self._cached_params] - - args, fmt = _flatten(args) - assert fmt == self._in_format, "Invalid input format" - for i, j in self._in_idx: - cargs[i] = args[j] + cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v + in enumerate(p.data() if p else None for p in self._cached_params)) out = self._cached_op(*cargs) if isinstance(out, NDArray): out = [out]