[MXNET-374] handle row_sparse weight in parameter and trainer #11001
Changes from 12 commits
File: python/mxnet/gluon/block.py
@@ -445,6 +445,14 @@ class HybridBlock(Block):
     the end-to-end usage.
     """
     def __init__(self, prefix=None, params=None):
+        # check if any parameter is row_sparse
+        if isinstance(params, ParameterDict):
+            for param in params.values():
+                stype = param._stype
+                if stype != 'default':
+                    raise ValueError("Cannot create a HybridBlock with Parameter '%s' " \
+                                     "because its storage type is %s. Please consider " \
+                                     "using a SparseBlock instead."%(param.name, stype))
Review comment: PR for the sparse block will be created separately after this one is merged.
Review comment: "please consider using" -> "please use".
         super(HybridBlock, self).__init__(prefix=prefix, params=params)
         self._cached_graph = ()
         self._cached_op = None
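A minimal sketch of the new behavior, assuming ParameterDict.get forwards the stype keyword introduced below in parameter.py (names and shapes are illustrative):

    import mxnet as mx
    from mxnet import gluon

    params = gluon.ParameterDict()
    params.get('weight', shape=(100, 16), stype='row_sparse')
    try:
        block = gluon.HybridBlock(params=params)
    except ValueError as err:
        print(err)  # storage type is row_sparse; a SparseBlock is suggested instead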
@@ -713,6 +721,14 @@ def __init__(self, outputs, inputs, params=None):
                 "Input symbols must be variable, but %s is an output of operators"%str(i)
             input_names.add(i.name)

+
+        # check if any symbol is row_sparse
+        row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
+        for i in out:
+            for j in i.get_internals():
+                assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
+                    "SymbolBlock doesn't support Parameter '%s' because its storage " \
+                    "type is 'row_sparse'." % j.name

         for i in out.list_arguments():
             if i not in input_names:
                 self.params.get(i, allow_deferred_init=True)
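For reference, the "__storage_type__" attribute inspected above is attached when a variable is created with a storage type; a small sketch (assuming the stype argument of mx.sym.var, which Parameter.var() below also uses):

    import mxnet as mx
    from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID

    w = mx.sym.var('w', stype='row_sparse')
    # the same attribute SymbolBlock now checks on every internal node
    assert w.attr('__storage_type__') == str(_STORAGE_TYPE_STR_TO_ID['row_sparse'])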
File: python/mxnet/gluon/parameter.py
@@ -81,6 +81,8 @@ class Parameter(object):
         Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
     init : Initializer, default None
         Initializer of this parameter. Will use the global initializer by default.
+    stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
+        The storage type of the parameter.
     grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
         The storage type of the parameter's gradient.
@@ -99,12 +101,13 @@
     """
     def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
                  lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
-                 differentiable=True, grad_stype='default'):
+                 differentiable=True, stype='default', grad_stype='default'):
         self._var = None
         self._data = None
         self._grad = None
         self._ctx_list = None
         self._ctx_map = None
+        self._trainer = None
         self._deferred_init = ()
         self._differentiable = differentiable
         self._allow_deferred_init = allow_deferred_init
@@ -116,10 +119,14 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
         self.wd_mult = wd_mult
         self.grad_req = grad_req
         self.init = init
-        assert grad_stype in ['default', 'row_sparse', 'csr'], \
-            "grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
-            " but got '%s'" % (name, grad_stype)
+        # sparse related storage type information
+        valid_stypes = ['default', 'row_sparse', 'csr']
Review comment: Might as well make it a set.
Reply: It only has 3 elements; I don't think this makes any real difference.
+        assert grad_stype in valid_stypes, "grad_stype for Parameter '%s' must be " \
+            "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, grad_stype)
+        assert stype in valid_stypes, "stype for Parameter '%s' must be " \
+            "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, stype)
         self._grad_stype = grad_stype
+        self._stype = stype

     def __repr__(self):
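With the stype keyword validated above, a row_sparse parameter can be declared directly; a short sketch (shape and names are illustrative):

    from mxnet.gluon import Parameter

    weight = Parameter('weight', shape=(100, 16),
                       stype='row_sparse', grad_stype='row_sparse')

    # an unknown storage type now fails at construction time
    try:
        Parameter('bad', shape=(4,), stype='sparse')
    except AssertionError as err:
        print(err)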
@@ -162,6 +169,15 @@ def shape(self, new_shape):

         self._shape = new_shape

+    def _set_trainer(self, trainer):
+        """ Set the trainer this parameter is associated with. """
+        if self._trainer and self._trainer is not trainer:
+            raise RuntimeError(
+                "Failed to set the trainer for Parameter '%s' to %s because it was set to %s. " \
Review comment: How can a user detach a parameter's association with a trainer without exiting Python?
Reply: Updated. Users can just call
"More than one trainers for a single Parameter is not supported." %( | ||
self.name, str(trainer), str(self._trainer))) | ||
Review comment: What does str(trainer) show? It's likely not meaningful to users.
Review comment: This is a breaking change.
Reply: Now it only throws an exception for a row_sparse param.
+        self._trainer = trainer

     def _check_and_get(self, arr_list, ctx):
         if arr_list is not None:
             if ctx is list:
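A sketch of the one-trainer rule enforced by _set_trainer, assuming the Trainer calls it on each parameter it manages (per the thread above, later commits raise only for row_sparse parameters):

    import mxnet as mx
    from mxnet import gluon

    params = gluon.ParameterDict()
    params.get('w', shape=(10, 4), stype='row_sparse', grad_stype='row_sparse')
    params.initialize(ctx=mx.cpu())
    trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})

    # a second Trainer over the same row_sparse parameter raises RuntimeError
    try:
        gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
    except RuntimeError as err:
        print(err)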
@@ -194,8 +210,26 @@ def _check_and_get(self, arr_list, ctx):
                 "because the later does not include Parameters of " \
                 "nested child Blocks"%(self.name))

-    def _load_init(self, data, ctx):
+    def _get_row_sparse(self, arr_list, ctx, row_id):
+        """ Get row_sparse data from row_sparse parameters based on row_id. """
+        # get row sparse params based on row ids
+        if not isinstance(row_id, ndarray.NDArray):
+            raise TypeError("Cannot get 'row_sparse' Parameter %s with %s type. "
+                            "NDArray type is expected." % (self.name, type(row_id)))

Review comment: Suggest "row_id must have NDArray type, but %s is given".
+        if not self._trainer:
+            raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
+                               "Trainer is created with it."%self.name)

Review comment: What if the user wants to train with a single device?
Reply: For a single device, we encourage the user to use normal hybrid blocks with sparse_grad=True. There's no need to use a row_sparse weight.
+        results = self._check_and_get(arr_list, ctx)
+
+        # fetch row sparse params from the trainer
+        self._trainer._row_sparse_pull(self, results, row_id)
+        return results
+
+    def _load_init(self, data, ctx, cast_stype=False):
         """(Re)initializes by loading from data."""
+        if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
+            raise RuntimeError("Cannot (Re)initialize Parameter '%s' when its Trainer " \
+                               "already initialized the parameter on KVStore."%(self.name))
Review comment: The message is cryptic. The real reason is multi-device training with update_on_kvstore=True.
Reply: Updated the message.
         if self.shape:
             for self_dim, data_dim in zip(self.shape, data.shape):
                 assert self_dim == 0 or self_dim == data_dim, \
@@ -208,6 +242,14 @@ def _load_init(self, data, ctx):
             "Failed loading Parameter '%s' from saved params: " \
             "dtype incompatible expected %s vs saved %s"%(
                 self.name, str(self.dtype), str(data.dtype))
+        if self._stype != data.stype:
+            if not cast_stype:
Review comment: Why is cast_stype needed? Why not always cast?
Reply: Changed to always cast.
raise RuntimeError("Failed loading Parameter '%s' from saved params: storage " \ | ||
"type incompatible expected %s vs saved %s. Set " \ | ||
"cast_stype=True to cast saved params to the same stype " \ | ||
"as '%s'."%(self.name, self._stype, data.stype, self.name)) | ||
else: | ||
data = data.tostype(self._stype) | ||
if isinstance(ctx, Context): | ||
ctx = [ctx] | ||
if self._data is None: | ||
|
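The cast branch above relies on NDArray.tostype; for reference, converting a dense array to a parameter's storage type looks like this:

    import mxnet as mx

    dense = mx.nd.ones((4, 3))
    rs = dense.tostype('row_sparse')  # same cast _load_init applies when cast_stype=True
    print(rs.stype)                   # 'row_sparse'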
@@ -243,7 +285,7 @@ def _finish_deferred_init(self):
         with autograd.pause():
             if data is None:
                 data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
-                                     ctx=context.cpu())
+                                     ctx=context.cpu(), stype=self._stype)
                 initializer.create(default_init)(
                     initializer.InitDesc(self.name, {'__init__': init}), data)
@@ -271,12 +313,18 @@ def _init_grad(self):
         self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
                                     stype=self._grad_stype) for i in self._data]

-        autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
+        autograd.mark_variables(self._check_and_get(self._data, list),
+                                self._grad, self.grad_req)
     def _reduce(self):
         """Reduce data from multiple context."""
-        block = self.list_data()
-        data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+        if self._stype == 'default':
+            block = self.list_data()
+            data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+        else:
+            # fetch all rows for 'row_sparse' param
+            all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu())
+            data = self.row_sparse_data(all_row_ids)
Review comment: Is it possible to have row_sparse but update_on_kvstore=False?
Reply: Currently, when Gluon sees a row_sparse weight, it always creates a kvstore and sets update_on_kvstore=True.
         return data

     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
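The else branch recovers the full weight by asking for every row id. The row-id semantics match those of mx.nd.sparse.retain: selecting all ids yields the complete content (a small illustration):

    import mxnet as mx

    w = mx.nd.ones((5, 2)).tostype('row_sparse')
    all_row_ids = mx.nd.arange(0, 5, dtype='int64')
    retained = mx.nd.sparse.retain(w, all_row_ids)
    print(retained.asnumpy())  # identical to the dense content of w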
@@ -380,12 +428,55 @@ def set_data(self, data):
             self._deferred_init = self._deferred_init[:3] + (data,)
             return

-        for arr in self.list_data():
+        for arr in self._check_and_get(self._data, list):
             arr[:] = data

+    def row_sparse_data(self, row_id):
+        """Returns a copy of the 'row_sparse' parameter on the same context as row_id's.
+        The copy only retains rows whose ids occur in provided row ids.
+        The parameter must have been initialized on this context before.
+
+        Parameters
+        ----------
+        row_id: NDArray
+            Row ids to retain for the 'row_sparse' parameter.
+
+        Returns
+        -------
+        NDArray on row_id's context
+        """
+        if self._stype != 'row_sparse':
+            raise ValueError("Cannot return a copy of Parameter %s via row_sparse_data() " \
+                             "because its storage type is %s. Please use data() instead." \
+                             %(self.name, self._stype))
+        return self._get_row_sparse(self._data, row_id.context, row_id)
+    def list_row_sparse_data(self, row_id):
+        """Returns copies of the 'row_sparse' parameter on all contexts, in the same order
+        as creation. The copy only retains rows whose ids occur in provided row ids.
+        The parameter must have been initialized before.
+
+        Parameters
+        ----------
+        row_id: NDArray
+            Row ids to retain for the 'row_sparse' parameter.
+
+        Returns
+        -------
+        list of NDArrays
+        """
+        if self._stype != 'row_sparse':
+            raise ValueError("Cannot return copies of Parameter '%s' on all contexts via " \
+                             "list_row_sparse_data() because its storage type is %s. Please " \
+                             "use data() instead." % (self.name, self._stype))
+        return self._get_row_sparse(self._data, list, row_id)
     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
-        initialized on this context before.
+        initialized on this context before. For sparse parameters, use
+        :py:meth:`Parameter.row_sparse_data` instead.

         Parameters
         ----------
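A hedged end-to-end sketch of the accessors added above; a Trainer is required because the rows are pulled through the kvstore, and all names are illustrative:

    import mxnet as mx
    from mxnet import gluon

    params = gluon.ParameterDict()
    w = params.get('w', shape=(100, 8), stype='row_sparse', grad_stype='row_sparse')
    params.initialize(ctx=mx.cpu())
    trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})

    row_id = mx.nd.array([0, 4, 7], dtype='int64')
    w_rows = w.row_sparse_data(row_id)       # rows 0, 4 and 7 on row_id's context
    w_list = w.list_row_sparse_data(row_id)  # one copy per creation context

    # data() is now reserved for dense parameters and raises here
    try:
        w.data()
    except ValueError as err:
        print(err)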
@@ -396,11 +487,25 @@ def data(self, ctx=None):
         -------
         NDArray on ctx
         """
+        if self._stype != 'default':
+            raise ValueError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
Review comment: Should these be UserError?
Reply: Maybe I should change it to RuntimeError? There's a UserWarning, but I'm not aware of a UserError.
"because its storage type is %s. Please use row_sparse_data() " \ | ||
"instead." % (self.name, str(ctx), self._stype)) | ||
return self._check_and_get(self._data, ctx) | ||
|
||
def list_data(self): | ||
"""Returns copies of this parameter on all contexts, in the same order | ||
as creation.""" | ||
as creation. For sparse parameters, use :py:meth:`Parameter.list_row_sparse_data` | ||
instead. | ||
|
||
Returns | ||
------- | ||
list of NDArrays | ||
""" | ||
if self._stype != 'default': | ||
raise ValueError("Cannot return copies of Parameter '%s' on all contexts via " \ | ||
"list_data() because its storage type is %s. Please use " \ | ||
"row_sparse_data() instead." % (self.name, self._stype)) | ||
return self._check_and_get(self._data, list) | ||
|
||
def grad(self, ctx=None): | ||
|
@@ -447,7 +552,7 @@ def var(self):
         if self._var is None:
             self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype,
                                    lr_mult=self.lr_mult, wd_mult=self.wd_mult,
-                                   init=self.init)
+                                   init=self.init, stype=self._stype)
         return self._var

     def cast(self, dtype):
@@ -766,7 +871,7 @@ def save(self, filename, strip_prefix=''):
         ndarray.save(filename, arg_dict)

     def load(self, filename, ctx=None, allow_missing=False,
-             ignore_extra=False, restore_prefix=''):
+             ignore_extra=False, restore_prefix='', cast_stype=False):
         """Load parameters from file.

         filename : str

@@ -804,4 +909,4 @@ def load(self, filename, ctx=None, allow_missing=False,
                     "Please make sure source and target networks have the same prefix."%(
                         name[lprefix:], filename, _brief_print_list(self._params.keys()))
                 continue
-            self[name]._load_init(arg_dict[name], ctx)
+            self[name]._load_init(arg_dict[name], ctx, cast_stype=cast_stype)
Review comment: This check shouldn't be done here. Parameters are only added to the current block when self.params.get is called.
Reply: Removed. Will the checks in param.list_data() and param.data() be sufficient?
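A sketch of the new flag at the call site, assuming parameters saved with a dense stype are loaded into a dict that declares 'row_sparse' storage:

    import mxnet as mx
    from mxnet import gluon

    dense = gluon.ParameterDict()
    dense.get('w', shape=(4, 2))
    dense.initialize(ctx=mx.cpu())
    dense.save('w.params')

    sparse = gluon.ParameterDict()
    sparse.get('w', shape=(4, 2), stype='row_sparse')
    # without cast_stype=True this raises the stype-mismatch RuntimeError above
    sparse.load('w.params', ctx=mx.cpu(), cast_stype=True)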