Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update args
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Oct 5, 2017
1 parent 03a6403 commit 7b54584
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def __init__(self, prefix=None, params=None):
self._reg_params = {}
self._cached_graph = ()
self._cached_op = None
self._cached_params = ()
self._cached_call_params = ()
self._out_format = None
self._in_format = None
self._active = False
Expand Down Expand Up @@ -382,18 +382,16 @@ def infer_shape(self, *args):

def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
name2pos = {var.name: i for i, var in enumerate(inputs)}
self._cached_op = ndarray.CachedOp(out)

# Must not change the following dict
params = self.collect_params()._params
self._cached_params = tuple(params.get(name, None) for name in out.list_inputs())
self._cached_call_params = tuple(params[name] if name in params else name2pos[name]
for name in out.list_inputs())
assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
"Wrong number of inputs."

name2pos = {var.name: i for i, var in enumerate(inputs)}
self._in_idx = tuple(name2pos[name] if name not in params else -1
for name in out.list_inputs())

def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)
Expand All @@ -402,15 +400,15 @@ def _call_cached_op(self, *args):
assert fmt == self._in_format, "Invalid input format"

try:
cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v
in enumerate(p.data() if p else None for p in self._cached_params))
cargs = tuple(p.data() if isinstance(p, Parameter) else args[p]
for p in self._cached_call_params)
except DeferredInitializationError:
self.infer_shape(*args)
for i in self._cached_params:
if i is not None:
for i in self._cached_call_params:
if isinstance(i, Parameter):
i._finish_deferred_init()
cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v
in enumerate(p.data() if p else None for p in self._cached_params))
cargs = tuple(p.data() if isinstance(p, Parameter) else args[p]
for p in self._cached_call_params)
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
Expand Down

0 comments on commit 7b54584

Please sign in to comment.