This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
dtype default to source_array.dtype for sparse ndarrays #8403
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
e03bcce
derive default dtype/ctx from input for sparse ndarrays
eric-haibin-lin 4205b61
add gpu tests
eric-haibin-lin 1251535
fix lint. add doc
eric-haibin-lin 22a880d
remove default_ctx code
eric-haibin-lin fe2e990
Merge branch 'scipy-fix' of https://github.com/eric-haibin-lin/mxnet …
eric-haibin-lin 5478e26
bug fix when passing dtype to array()
eric-haibin-lin bec84a0
update doc
eric-haibin-lin 9cc7ee9
Merge branch 'scipy-fix' of github.com:eric-haibin-lin/mxnet into sci…
eric-haibin-lin 93c13a3
remove extra line
eric-haibin-lin 24b432a
also check ctx
eric-haibin-lin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -727,6 +727,18 @@ def _prepare_src_array(source_array, dtype): | |
raise TypeError('values must be array like object') | ||
return source_array | ||
|
||
def _prepare_default_dtype(src_array, dtype): | ||
"""Prepare the value of dtype if `dtype` is None. If `src_array` is an NDArray, numpy.ndarray | ||
or scipy.sparse.csr.csr_matrix, return src_array.dtype. float32 is returned otherwise.""" | ||
if dtype is None: | ||
if isinstance(src_array, (NDArray, np.ndarray)): | ||
dtype = src_array.dtype | ||
elif spsp and isinstance(src_array, spsp.csr.csr_matrix): | ||
dtype = src_array.dtype | ||
else: | ||
dtype = mx_real_t | ||
return dtype | ||
|
||
def _check_shape(s1, s2): | ||
"""check s1 == s2 if both are not None""" | ||
if s1 and s2 and s1 != s2: | ||
|
@@ -749,12 +761,11 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None): | |
|
||
- csr_matrix(S) | ||
to construct a CSRNDArray with a sparse 2D array ``S`` | ||
- **S** (*CSRNDArray or scipy.sparse.csr_matrix*) - A sparse matrix. | ||
- **S** (*CSRNDArray or scipy.sparse.csr.csr_matrix*) - A sparse matrix. | ||
- **ctx** (*Context, optional*) - Device context \ | ||
(default is the current default context). | ||
- **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \ | ||
The default dtype is ``D.dtype`` if ``D`` is an NDArray or numpy.ndarray, \ | ||
float32 otherwise. | ||
The default dtype is ``S.dtype``. | ||
|
||
- csr_matrix((M, N)) | ||
to construct an empty CSRNDArray with shape ``(M, N)`` | ||
|
@@ -784,19 +795,20 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None): | |
- **ctx** (*Context, optional*) - Device context \ | ||
(default is the current default context). | ||
- **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \ | ||
The default dtype is float32. | ||
The default dtype is ``data.dtype`` if ``data`` is an NDArray or numpy.ndarray, \ | ||
float32 otherwise. | ||
|
||
Parameters | ||
---------- | ||
arg1: tuple of int, tuple of array_like, array_like, CSRNDArray or scipy.sparse.csr_matrix | ||
arg1: NDArray, CSRNDArray, numpy.ndarray, scipy.sparse.csr.csr_matrix, tuple of int or tuple \ | ||
of array_like | ||
The argument to help instantiate the csr matrix. See above for further details. | ||
shape : tuple of int | ||
shape : tuple of int, optional | ||
The shape of the csr matrix. | ||
ctx: Context, optional | ||
Device context (default is the current default context). | ||
dtype: str or numpy.dtype, optional | ||
The data type of the output array. The default dtype is ``values.dtype`` | ||
if `values` is an `NDArray`, `float32` otherwise. | ||
The data type of the output array. | ||
|
||
Returns | ||
------- | ||
|
@@ -839,7 +851,14 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None): | |
raise ValueError("Unexpected input type: RowSparseNDArray") | ||
else: | ||
# construct a csr matrix from a dense one | ||
dns = _array(arg1, ctx=ctx, dtype=dtype) | ||
# prepare default ctx and dtype since mx.nd.array doesn't use default values | ||
# based on source_array | ||
dtype = _prepare_default_dtype(arg1, dtype) | ||
# create dns array with provided dtype. ctx is not passed since copy across | ||
# ctx requires dtype to be the same | ||
dns = _array(arg1, dtype=dtype) | ||
if ctx is not None and dns.context != ctx: | ||
dns = dns.as_in_context(ctx) | ||
_check_shape(dns.shape, shape) | ||
return dns.tostype('csr') | ||
|
||
|
@@ -848,10 +867,9 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None, | |
"""Create a `CSRNDArray` based on data, indices and indptr""" | ||
storage_type = 'csr' | ||
# context | ||
if ctx is None: | ||
ctx = Context.default_ctx | ||
ctx = Context.default_ctx if ctx is None else ctx | ||
# types | ||
dtype = mx_real_t if dtype is None else dtype | ||
dtype = _prepare_default_dtype(data, dtype) | ||
indptr_type = _STORAGE_AUX_TYPES[storage_type][0] if indptr_type is None else indptr_type | ||
indices_type = _STORAGE_AUX_TYPES[storage_type][1] if indices_type is None else indices_type | ||
# prepare src array and types | ||
|
@@ -906,8 +924,7 @@ def row_sparse_array(arg1, shape=None, ctx=None, dtype=None): | |
- **ctx** (*Context, optional*) - Device context \ | ||
(default is the current default context). | ||
- **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \ | ||
The default dtype is ``D.dtype`` if ``D`` is an NDArray or numpy.ndarray, \ | ||
float32 otherwise. | ||
The default dtype is ``S.dtype``. | ||
|
||
- row_sparse_array((D0, D1 .. Dn)) | ||
to construct an empty RowSparseNDArray with shape ``(D0, D1, ... Dn)`` | ||
|
@@ -931,20 +948,21 @@ def row_sparse_array(arg1, shape=None, ctx=None, dtype=None): | |
stores the row index for each row slice with non-zero elements. | ||
- **shape** (*tuple of int, optional*) - The shape of the array. The default \ | ||
shape is inferred from the indices and indptr arrays. | ||
- **ctx** (*Context, optional*) - Device context \ | ||
(default is the current default context). | ||
- **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \ | ||
The default dtype is float32. | ||
|
||
Parameters | ||
---------- | ||
arg1: tuple of int, tuple of array_like, array_like or RowSparseNDArray | ||
arg1: NDArray, numpy.ndarray, RowSparseNDArray, tuple of int or tuple of array_like | ||
The argument to help instantiate the row sparse ndarray. See above for further details. | ||
shape : tuple of int | ||
shape : tuple of int, optional | ||
The shape of the row sparse ndarray. | ||
ctx : Context, optional | ||
Device context (default is the current default context). | ||
dtype : str or numpy.dtype, optional | ||
The data type of the output array. The default dtype is ``data.dtype`` | ||
if `data` is an `NDArray`, `float32` otherwise. | ||
The data type of the output array. | ||
|
||
Returns | ||
------- | ||
|
@@ -995,7 +1013,14 @@ def row_sparse_array(arg1, shape=None, ctx=None, dtype=None): | |
raise ValueError("Unexpected input type: CSRNDArray") | ||
else: | ||
# construct a csr matrix from a dense one | ||
dns = _array(arg1, ctx=ctx, dtype=dtype) | ||
# prepare default dtype since mx.nd.array doesn't use default values | ||
# based on source_array | ||
dtype = _prepare_default_dtype(arg1, dtype) | ||
# create dns array with provided dtype. ctx is not passed since copy across | ||
# ctx requires dtype to be the same | ||
dns = _array(arg1, dtype=dtype) | ||
if ctx is not None and dns.context != ctx: | ||
dns = dns.as_in_context(ctx) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not put context in _array constructor? |
||
_check_shape(dns.shape, shape) | ||
return dns.tostype('row_sparse') | ||
|
||
|
@@ -1004,10 +1029,9 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None, | |
"""Create a `RowSparseNDArray` based on data and indices""" | ||
storage_type = 'row_sparse' | ||
# context | ||
if ctx is None: | ||
ctx = Context.default_ctx | ||
ctx = Context.default_ctx if ctx is None else ctx | ||
# types | ||
dtype = mx_real_t if dtype is None else dtype | ||
dtype = _prepare_default_dtype(data, dtype) | ||
indices_type = _STORAGE_AUX_TYPES[storage_type][0] if indices_type is None else indices_type | ||
# prepare src array and types | ||
data = _prepare_src_array(data, dtype) | ||
|
@@ -1022,7 +1046,9 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None, | |
indices = _array(indices, ctx, indices_type) | ||
if shape is None: | ||
num_indices = indices.shape[0] | ||
dim0 = 0 if num_indices == 0 else indices[num_indices - 1].asscalar() + 1 | ||
if num_indices == 0: | ||
raise ValueError('invalid shape') | ||
dim0 = indices[num_indices - 1].asscalar() + 1 | ||
shape = (dim0, ) + data.shape[1:] | ||
# verify shapes | ||
if data.ndim != len(shape) or indices.ndim != 1 or np.prod(shape[1:]) == 0: | ||
|
@@ -1127,10 +1153,12 @@ def array(source_array, ctx=None, dtype=None): | |
source_array : RowSparseNDArray, CSRNDArray or scipy.sparse.csr.csr_matrix | ||
The source sparse array | ||
ctx : Context, optional | ||
Device context (default is the current default context). | ||
The default context is ``source_array.context`` if ``source_array`` is an NDArray. \ | ||
The current default context otherwise. | ||
dtype : str or numpy.dtype, optional | ||
The data type of the output array. The default dtype is ``source_array.dtype`` | ||
if `source_array` is an `NDArray`, `float32` otherwise. | ||
if `source_array` is an `NDArray`, `numpy.ndarray` or `scipy.sparse.csr.csr_matrix`, \ | ||
`float32` otherwise. | ||
|
||
Returns | ||
------- | ||
|
@@ -1148,19 +1176,29 @@ def array(source_array, ctx=None, dtype=None): | |
>>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2))) | ||
<RowSparseNDArray 3x2 @cpu(0)> | ||
""" | ||
ctx = Context.default_ctx if ctx is None else ctx | ||
if isinstance(source_array, NDArray): | ||
assert(source_array.stype != 'default'), \ | ||
"Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray" | ||
dtype = source_array.dtype if dtype is None else dtype | ||
arr = empty(source_array.stype, source_array.shape, ctx=ctx, dtype=dtype) | ||
arr[:] = source_array | ||
# prepare dtype and ctx based on source_array, if not provided | ||
dtype = _prepare_default_dtype(source_array, dtype) | ||
# if both dtype and ctx are different from source_array, we cannot copy directly | ||
if source_array.dtype != dtype and source_array.context != ctx: | ||
arr = empty(source_array.stype, source_array.shape, dtype=dtype) | ||
arr[:] = source_array | ||
arr = arr.as_in_context(ctx) | ||
else: | ||
arr = empty(source_array.stype, source_array.shape, dtype=dtype, ctx=ctx) | ||
arr[:] = source_array | ||
return arr | ||
elif spsp and isinstance(source_array, spsp.csr.csr_matrix): | ||
# TODO(haibin) implement `_sync_copy_from` with scipy csr object to reduce a copy | ||
# preprocess scipy csr to canonical form | ||
csr = source_array.sorted_indices() | ||
csr.sum_duplicates() | ||
return csr_matrix((csr.data, csr.indices, csr.indptr), shape=csr.shape, dtype=dtype) | ||
dtype = _prepare_default_dtype(source_array, dtype) | ||
return csr_matrix((csr.data, csr.indices, csr.indptr), shape=csr.shape, \ | ||
dtype=dtype, ctx=ctx) | ||
elif isinstance(source_array, (np.ndarray, np.generic)): | ||
raise ValueError("Please use mx.nd.array to create an NDArray with source_array of type ", | ||
type(source_array)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't we raise an error if spsp is not available?
Do we currently require scipy when using sparse?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scipy is an optional dependency right now for MXNet. If scipy is not there, spsp is None