Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[1.7] Pass args fix3 (#18237)
Browse files Browse the repository at this point in the history
* fixed overwrite of args/aux variables

* fixed spacing
  • Loading branch information
samskalicky authored May 6, 2020
1 parent 295e939 commit 80baab8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,18 +1484,18 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
assert isinstance(backend, str)

if args is None or len(args) == 0:
args = []
args_ = []
args_handle = c_array(NDArrayHandle, [])
else:
args_handle, args = self._get_ndarray_inputs('args', args,
self.list_arguments(), False)
args_handle, args_ = self._get_ndarray_inputs('args', args,
self.list_arguments(), False)

if aux is None or len(aux) == 0:
aux = []
aux_ = []
aux_handle = c_array(NDArrayHandle, [])
else:
aux_handle, aux = self._get_ndarray_inputs('aux_states', aux,
self.list_auxiliary_states(), False)
aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux,
self.list_auxiliary_states(), False)
if ctx is None:
ctx = current_context()
assert isinstance(ctx, Context)
Expand All @@ -1516,9 +1516,9 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
c_str(backend),
ctypes.c_int(ctx.device_typeid),
ctypes.byref(out),
mx_uint(len(args)),
mx_uint(len(args_)),
args_handle,
mx_uint(len(aux)),
mx_uint(len(aux_)),
aux_handle,
mx_uint(len(key_list)),
c_str_array(key_list),
Expand Down

0 comments on commit 80baab8

Please sign in to comment.