[MXNET-374] handle row_sparse weight in parameter and trainer #11001

Merged on May 29, 2018 (23 commits). Changes shown are from 12 commits.
16 changes: 16 additions & 0 deletions python/mxnet/gluon/block.py
@@ -445,6 +445,14 @@ class HybridBlock(Block):
the end-to-end usage.
"""
def __init__(self, prefix=None, params=None):
# check if any parameter is row_sparse
if isinstance(params, ParameterDict):
Contributor: This check shouldn't be done here. Parameters are only added to the current block when self.params.get is called.

Author: Removed. Will the checks in param.list_data() and param.data() be sufficient?

for param in params.values():
stype = param._stype
if stype != 'default':
raise ValueError("Cannot create a HybridBlock with Parameter '%s' " \
"because its storage type is %s. Please consider " \
"using a SparseBlock instead."%(param.name, stype))
Author: PR for SparseBlock will be created separately after this one is merged.

Member: "please consider using" -> "please use"

super(HybridBlock, self).__init__(prefix=prefix, params=params)
self._cached_graph = ()
self._cached_op = None
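A minimal sketch of what the constructor guard above rejects, assuming this revision of the PR (per the review thread, the check was later removed in favour of the checks inside param.data() and param.list_data()); the prefix and shape are illustrative:

```python
import mxnet as mx
from mxnet import gluon

# Build a ParameterDict holding a non-default-storage parameter.
params = gluon.ParameterDict(prefix='demo_')
params.get('weight', shape=(100, 16), stype='row_sparse')  # 'stype' is added by this PR

# Passing it to a HybridBlock trips the storage-type check in __init__.
try:
    block = gluon.HybridBlock(params=params)
except ValueError as err:
    print(err)  # suggests using a SparseBlock instead
```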
@@ -713,6 +721,14 @@ def __init__(self, outputs, inputs, params=None):
"Input symbols must be variable, but %s is an output of operators"%str(i)
input_names.add(i.name)

# check if any symbol is row_sparse
row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
for i in out:
for j in i.get_internals():
assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
"SymbolBlock doesn't support Parameter '%s' because its storage " \
"type is 'row_sparse'." % j.name

for i in out.list_arguments():
if i not in input_names:
self.params.get(i, allow_deferred_init=True)
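A hedged sketch of the SymbolBlock guard above: a symbolic variable created with stype='row_sparse' carries the __storage_type__ attribute that the new assert checks, so constructing a SymbolBlock over it fails. Variable names are illustrative.

```python
import mxnet as mx

# 'stype' on a symbolic variable sets the __storage_type__ attribute checked above.
weight = mx.sym.var('weight', stype='row_sparse')
data = mx.sym.var('data')
out = mx.sym.dot(data, weight)

try:
    blk = mx.gluon.SymbolBlock(outputs=out, inputs=data)
except AssertionError as err:
    print(err)  # "SymbolBlock doesn't support Parameter 'weight' ..."
```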
135 changes: 120 additions & 15 deletions python/mxnet/gluon/parameter.py
@@ -81,6 +81,8 @@ class Parameter(object):
Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
init : Initializer, default None
Initializer of this parameter. Will use the global initializer by default.
stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
The storage type of the parameter.
grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
The storage type of the parameter's gradient.

@@ -99,12 +101,13 @@ class Parameter(object):
"""
def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
differentiable=True, grad_stype='default'):
differentiable=True, stype='default', grad_stype='default'):
self._var = None
self._data = None
self._grad = None
self._ctx_list = None
self._ctx_map = None
self._trainer = None
self._deferred_init = ()
self._differentiable = differentiable
self._allow_deferred_init = allow_deferred_init
@@ -116,10 +119,14 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
self.wd_mult = wd_mult
self.grad_req = grad_req
self.init = init
assert grad_stype in ['default', 'row_sparse', 'csr'], \
"grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
" but got '%s'" % (name, grad_stype)
# sparse related storage type information
valid_stypes = ['default', 'row_sparse', 'csr']
Member: might as well make it a set.

Author (@eric-haibin-lin, May 21, 2018): Only has 3 elements. I don't think this makes any real difference.

assert grad_stype in valid_stypes, "grad_stype for Parameter '%s' must be " \
"one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, grad_stype)
assert stype in valid_stypes, "stype for Parameter '%s' must be " \
"one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, stype)
self._grad_stype = grad_stype
self._stype = stype
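A small usage sketch of the two storage-type arguments, assuming this PR's API: stype controls how the weight itself is stored, grad_stype how its gradient is stored, and anything outside the valid set fails the asserts above. Names and shapes are illustrative.

```python
from mxnet import gluon

# A row_sparse weight with a row_sparse gradient.
weight = gluon.Parameter('embed_weight', shape=(10000, 128),
                         stype='row_sparse', grad_stype='row_sparse')

# An invalid storage type fails the assert in __init__.
try:
    bad = gluon.Parameter('bad', shape=(4, 4), stype='dense')
except AssertionError as err:
    print(err)
```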


def __repr__(self):
@@ -162,6 +169,15 @@ def shape(self, new_shape):

self._shape = new_shape

def _set_trainer(self, trainer):
""" Set the trainer this parameter is associated with. """
if self._trainer and self._trainer is not trainer:
raise RuntimeError(
"Failed to set the trainer for Parameter '%s' to %s because it was set to %s. " \
Member: How can a user detach a parameter's association with a trainer without exiting Python?

Author: Updated. Users can just call _set_trainer(None). I don't think this will be used by common users, hence it remains private.

"More than one trainers for a single Parameter is not supported." %(
self.name, str(trainer), str(self._trainer)))
Contributor: What does str(trainer) show? It's likely not meaningful to users.

Contributor: This is a breaking change. Suppose users want to use SGD to train 10 epochs and then switch to Adam; this would prevent that.

Author: Now only throws an exception for row_sparse params.

self._trainer = trainer
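A hedged sketch of the restriction this method enforces; per the review exchange, the final behaviour only raises for row_sparse parameters. gluon.contrib.nn.SparseEmbedding is assumed here as a convenient block that owns a row_sparse weight.

```python
import mxnet as mx
from mxnet import gluon

emb = gluon.contrib.nn.SparseEmbedding(1000, 16)   # holds a row_sparse weight
emb.initialize(ctx=mx.cpu())

trainer_sgd = gluon.Trainer(emb.collect_params(), 'sgd')
try:
    # A second Trainer over the same row_sparse parameter is rejected;
    # param._set_trainer(None) detaches the first one if a switch is really needed.
    trainer_adam = gluon.Trainer(emb.collect_params(), 'adam')
except RuntimeError as err:
    print(err)
```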

def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
if ctx is list:
@@ -194,8 +210,26 @@ def _check_and_get(self, arr_list, ctx):
"because the later does not include Parameters of " \
"nested child Blocks"%(self.name))

def _load_init(self, data, ctx):
def _get_row_sparse(self, arr_list, ctx, row_id):
""" Get row_sparse data from row_sparse parameters based on row_id. """
# get row sparse params based on row ids
if not isinstance(row_id, ndarray.NDArray):
raise TypeError("Cannot get 'row_sparse' Parameter %s with %s type. "
Contributor: Suggest: "row_id must have NDArray type, but %s is given"

"NDArray type is expected." % (self.name, type(row_id)))
if not self._trainer:
raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
"Trainer is created with it."%self.name)
Contributor: What if the user wants to train with a single device?

Author: For a single device, we will encourage the user to use normal hybrid blocks with sparse_grad=True. There's no need to use a row_sparse weight. Even if the user chooses to use a row_sparse weight, a kvstore is created for the row_sparse param and the code still works.
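A sketch of the author's recommendation for single-device training: keep the weight dense and only make the gradient sparse, assuming gluon.nn.Embedding exposes a sparse_grad flag in this MXNet version.

```python
import mxnet as mx
from mxnet import gluon, autograd

emb = gluon.nn.Embedding(input_dim=1000, output_dim=16, sparse_grad=True)
emb.initialize(ctx=mx.cpu())
trainer = gluon.Trainer(emb.collect_params(), 'sgd', {'learning_rate': 0.1})

tokens = mx.nd.array([1, 5, 7])
with autograd.record():
    loss = emb(tokens).sum()     # dense weight, row_sparse gradient
loss.backward()
trainer.step(batch_size=3)
```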

results = self._check_and_get(arr_list, ctx)

# fetch row sparse params from the trainer
self._trainer._row_sparse_pull(self, results, row_id)
return results

def _load_init(self, data, ctx, cast_stype=False):
"""(Re)initializes by loading from data."""
if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
raise RuntimeError("Cannot (Re)initialize Parameter '%s' when its Trainer " \
"already initialized the parameter on KVStore."%(self.name))
Contributor: The message is cryptic. The reason is multi-device training with update_on_kvstore=True. The error message should describe the reason and suggest a solution.

Author: Updated message.

if self.shape:
for self_dim, data_dim in zip(self.shape, data.shape):
assert self_dim == 0 or self_dim == data_dim, \
@@ -208,6 +242,14 @@ def _load_init(self, data, ctx):
"Failed loading Parameter '%s' from saved params: " \
"dtype incompatible expected %s vs saved %s"%(
self.name, str(self.dtype), str(data.dtype))
if self._stype != data.stype:
if not cast_stype:
Contributor: Why is cast_stype needed? Why not always cast?

Author: Changed to always cast.

raise RuntimeError("Failed loading Parameter '%s' from saved params: storage " \
"type incompatible expected %s vs saved %s. Set " \
"cast_stype=True to cast saved params to the same stype " \
"as '%s'."%(self.name, self._stype, data.stype, self.name))
else:
data = data.tostype(self._stype)
if isinstance(ctx, Context):
ctx = [ctx]
if self._data is None:
@@ -243,7 +285,7 @@ def _finish_deferred_init(self):
with autograd.pause():
if data is None:
data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
ctx=context.cpu())
ctx=context.cpu(), stype=self._stype)
initializer.create(default_init)(
initializer.InitDesc(self.name, {'__init__': init}), data)

@@ -271,12 +313,18 @@ def _init_grad(self):
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
stype=self._grad_stype) for i in self._data]

autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
autograd.mark_variables(self._check_and_get(self._data, list),
self._grad, self.grad_req)

def _reduce(self):
"""Reduce data from multiple context."""
block = self.list_data()
data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
if self._stype == 'default':
block = self.list_data()
data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
else:
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu())
data = self.row_sparse_data(all_row_ids)
Contributor: Is it possible to have row_sparse but update_on_kvstore=False?

Author: Currently when Gluon sees a row_sparse weight, it always creates a kvstore and sets update_on_kvstore=True.

return data

def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
@@ -380,12 +428,55 @@ def set_data(self, data):
self._deferred_init = self._deferred_init[:3] + (data,)
return

for arr in self.list_data():
for arr in self._check_and_get(self._data, list):
arr[:] = data

def row_sparse_data(self, row_id):
"""Returns a copy of the 'row_sparse' parameter on the same context as row_id's.
The copy only retains rows whose ids occur in provided row ids.
The parameter must have been initialized on this context before.

Parameters
----------
row_id: NDArray
Row ids to retain for the 'row_sparse' parameter.

Returns
-------
NDArray on row_id's context
"""
if self._stype != 'row_sparse':
raise ValueError("Cannot return a copy of Parameter %s via row_sparse_data() " \
"because its storage type is %s. Please use data() instead." \
%(self.name, self._stype))
return self._get_row_sparse(self._data, row_id.context, row_id)
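A hedged end-to-end sketch of row_sparse_data(), assuming this PR's API: the parameter must be owned by a Trainer, and only the requested rows are pulled; passing the full range of row ids retrieves the whole weight, which is what _reduce() does above. Names and shapes are illustrative.

```python
import mxnet as mx
from mxnet import gluon

weight = gluon.Parameter('weight', shape=(1000, 16), stype='row_sparse')
weight.initialize(ctx=mx.cpu())
trainer = gluon.Trainer([weight], 'sgd')       # a Trainer must own the parameter

row_id = mx.nd.array([0, 4, 7], dtype='int64')
rows = weight.row_sparse_data(row_id)          # only rows 0, 4, 7 are pulled
print(rows.stype)                              # 'row_sparse'

# Retrieve every row by passing the full range of row ids, as _reduce() does.
all_rows = weight.row_sparse_data(mx.nd.arange(0, weight.shape[0], dtype='int64'))
```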

def list_row_sparse_data(self, row_id):
"""Returns copies of the 'row_sparse' parameter on all contexts, in the same order
as creation. The copy only retains rows whose ids occur in provided row ids.
The parameter must have been initialized before.

Parameters
----------
row_id: NDArray
Row ids to retain for the 'row_sparse' parameter.

Returns
-------
list of NDArrays
"""
if self._stype != 'row_sparse':
raise ValueError("Cannot return copies of Parameter '%s' on all contexts via " \
"list_row_sparse_data() because its storage type is %s. Please " \
"use data() instead." % (self.name, self._stype))
return self._get_row_sparse(self._data, list, row_id)
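A companion sketch for the multi-context variant, assuming two device contexts (plain CPU contexts are used as stand-ins): one sparse copy is returned per context, in creation order.

```python
import mxnet as mx
from mxnet import gluon

ctxs = [mx.cpu(0), mx.cpu(1)]
weight = gluon.Parameter('weight', shape=(500, 8), stype='row_sparse')
weight.initialize(ctx=ctxs)
trainer = gluon.Trainer([weight], 'sgd')

row_id = mx.nd.array([2, 3, 9], dtype='int64')
copies = weight.list_row_sparse_data(row_id)   # one row_sparse NDArray per context
print([c.context for c in copies])
```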

def data(self, ctx=None):
"""Returns a copy of this parameter on one context. Must have been
initialized on this context before.
initialized on this context before. For sparse parameters, use
:py:meth:`Parameter.row_sparse_data` instead.

Parameters
----------
Expand All @@ -396,11 +487,25 @@ def data(self, ctx=None):
-------
NDArray on ctx
"""
if self._stype != 'default':
raise ValueError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
Contributor: These should be UserError?

Author: Maybe I should change to RuntimeError? There's UserWarning but I am not aware of UserError.

"because its storage type is %s. Please use row_sparse_data() " \
"instead." % (self.name, str(ctx), self._stype))
return self._check_and_get(self._data, ctx)

def list_data(self):
"""Returns copies of this parameter on all contexts, in the same order
as creation."""
as creation. For sparse parameters, use :py:meth:`Parameter.list_row_sparse_data`
instead.

Returns
-------
list of NDArrays
"""
if self._stype != 'default':
raise ValueError("Cannot return copies of Parameter '%s' on all contexts via " \
"list_data() because its storage type is %s. Please use " \
"row_sparse_data() instead." % (self.name, self._stype))
return self._check_and_get(self._data, list)
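A short sketch of the new guards in data() and list_data(): dense access to a row_sparse parameter is rejected with a pointer to row_sparse_data(). The parameter below is illustrative.

```python
import mxnet as mx
from mxnet import gluon

w = gluon.Parameter('w', shape=(100, 8), stype='row_sparse')
w.initialize(ctx=mx.cpu())

try:
    w.data()          # rejected: storage type is 'row_sparse'
except ValueError as err:
    print(err)
```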

def grad(self, ctx=None):
@@ -447,7 +552,7 @@ def var(self):
if self._var is None:
self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype,
lr_mult=self.lr_mult, wd_mult=self.wd_mult,
init=self.init)
init=self.init, stype=self._stype)
return self._var

def cast(self, dtype):
@@ -766,7 +871,7 @@ def save(self, filename, strip_prefix=''):
ndarray.save(filename, arg_dict)

def load(self, filename, ctx=None, allow_missing=False,
ignore_extra=False, restore_prefix=''):
ignore_extra=False, restore_prefix='', cast_stype=False):
"""Load parameters from file.

filename : str
@@ -804,4 +909,4 @@ def load(self, filename, ctx=None, allow_missing=False,
"Please make sure source and target networks have the same prefix."%(
name[lprefix:], filename, _brief_print_list(self._params.keys()))
continue
self[name]._load_init(arg_dict[name], ctx)
self[name]._load_init(arg_dict[name], ctx, cast_stype=cast_stype)
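A hedged round-trip sketch for the new cast_stype flag at this revision (per the review thread, a later commit drops the flag and always casts): saved arrays whose storage type differs from the target Parameter's stype are converted via tostype() instead of raising.

```python
import mxnet as mx
from mxnet import gluon

net = gluon.nn.Dense(4, in_units=8)
net.initialize(ctx=mx.cpu())
net.collect_params().save('dense.params')          # saved with 'default' storage

# Reload into the same network; with cast_stype=True any stype mismatch between
# the saved arrays and the target Parameters is resolved by data.tostype(...).
net.collect_params().load('dense.params', ctx=mx.cpu(), cast_stype=True)
```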