Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
memory save
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Oct 5, 2017
1 parent 8a4221b commit 4b05404
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def __init__(self, prefix=None, params=None):
self._reg_params = {}
self._cached_graph = ()
self._cached_op = None
self._cached_params = None
self._cached_params = ()
self._out_format = None
self._in_format = None
self._active = False
Expand Down Expand Up @@ -384,32 +384,33 @@ def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
self._cached_op = ndarray.CachedOp(out)

params = dict(self.collect_params().items())
self._cached_params = [params.get(name, None) for name in out.list_inputs()]
# Must not change the following dict
params = self.collect_params()._params
self._cached_params = tuple(params.get(name, None) for name in out.list_inputs())
assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
"Wrong number of inputs."

name2pos = {var.name: i for i, var in enumerate(inputs)}
self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
if name not in params]
self._in_idx = tuple(name2pos[name] if name not in params else -1
for name in out.list_inputs())

def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)

args, fmt = _flatten(args)
assert fmt == self._in_format, "Invalid input format"

try:
cargs = [i.data() if i else None for i in self._cached_params]
cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v
in enumerate(p.data() if p else None for p in self._cached_params))
except DeferredInitializationError:
self.infer_shape(*args)
for i in self._cached_params:
if i is not None:
i._finish_deferred_init()
cargs = [i.data() if i else None for i in self._cached_params]

args, fmt = _flatten(args)
assert fmt == self._in_format, "Invalid input format"
for i, j in self._in_idx:
cargs[i] = args[j]
cargs = tuple(args[self._in_idx[i]] if self._in_idx[i] != -1 else v for i, v
in enumerate(p.data() if p else None for p in self._cached_params))
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
Expand Down

0 comments on commit 4b05404

Please sign in to comment.