Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[NumPy]Set numpy default dtype (#17283)
Browse files Browse the repository at this point in the history
* draft 1

Preliminary completion

fix rebase mistake

* add new change

* use global flag & fix part of CI error

* fix CI error

* modify docs

* rebase

* consistent with FFI

* ones/full ffi

* arange ffi

* identity ffi

* depart set_np_default_dtype from set_np to test CI

fix sanity error

fix sanity error

* test CI

* fix CI error : use InitNumpyType instead of InitType / remove full ffi

fix sanity error

fix rebase error

* try to pass CI

* set the second output's dtype of normal to be float32

* modify ffi op

* update & rebase

* test windows CI

* comment test when platform is windows

* resolve comment

* add to set_np
  • Loading branch information
JiangZhaoh authored Apr 28, 2020
1 parent df28e61 commit 5c525c9
Show file tree
Hide file tree
Showing 45 changed files with 1,045 additions and 277 deletions.
2 changes: 1 addition & 1 deletion benchmark/python/einsum/benchmark_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,5 @@ def test_np_einsum():


if __name__ == "__main__":
npx.set_np()
npx.set_np(dtype=False)
test_np_einsum()
5 changes: 4 additions & 1 deletion benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def generate_workloads():
def prepare_workloads():
pool = generate_workloads()
OpArgMngr.add_workload("zeros", (2, 2))
OpArgMngr.add_workload("full", (2, 2), 10)
OpArgMngr.add_workload("identity", 3)
OpArgMngr.add_workload("ones", (2, 2))
OpArgMngr.add_workload("einsum", "ii", pool['2x2'], optimize=False)
OpArgMngr.add_workload("unique", pool['1'], return_index=True, return_inverse=True, return_counts=True, axis=-1)
OpArgMngr.add_workload("dstack", (pool['2x1'], pool['2x1'], pool['2x1'], pool['2x1']))
Expand Down Expand Up @@ -244,7 +247,7 @@ def show_results(results):
import numpy as onp
from mxnet import np as dnp

mx.npx.set_np()
mx.npx.set_np(dtype=False)
packages = {
"onp": {
"module": onp,
Expand Down
14 changes: 14 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,20 @@ MXNET_DLL int MXIsNumpyShape(int* curr);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev);
/*!
* \brief get numpy default data type
* \param curr returns the current status
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIsNumpyDefaultDtype(bool* curr);
/*!
* \brief set numpy default data type
* \param dtype_flag false when default dtype is flaot32,
* true when default dtype is flaot64.
* \param prev returns the previous status before this set
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSetIsNumpyDefaultDtype(bool dtype_flag, bool* prev);
/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
Expand Down
26 changes: 24 additions & 2 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ namespace mxnet {
* turn off numpy shape flag globally.
* */
enum NumpyShape{Off, ThreadLocalOn, GlobalOn};
typedef NumpyShape NumpyDefaultDtype;
/*! \brief runtime functions for NDArray */
class Imperative {
public:
Expand Down Expand Up @@ -189,9 +190,11 @@ class Imperative {
* */
int is_np_shape() const {
if (is_np_shape_global_) {
return 2;
return NumpyShape::GlobalOn;
}
return is_np_shape_thread_local_ ? 1 : 0;
return is_np_shape_thread_local_ ?
NumpyShape::ThreadLocalOn :
NumpyShape::Off;
}
/*! \brief specify numpy compatibility off, thread local on or global on. */
bool set_is_np_shape(int is_np_shape) {
Expand All @@ -212,6 +215,24 @@ class Imperative {
}
return old;
}
/*! \brief return current numpy default dtype compatibility status.
* */
bool is_np_default_dtype() const {
if (is_np_default_dtype_global_) {
return true;
}
return false;
}
/*! \brief specify numpy default dtype off or global on. */
bool set_is_np_default_dtype(bool is_np_default_dtype) {
bool old = this->is_np_default_dtype();
if (is_np_default_dtype) {
is_np_default_dtype_global_ = true;
} else {
is_np_default_dtype_global_ = false;
}
return old;
}
/*! \brief to record operator, return corresponding node. */
void RecordOp(nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
Expand Down Expand Up @@ -301,6 +322,7 @@ class Imperative {
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
#endif
bool is_np_shape_global_{false};
bool is_np_default_dtype_global_{false};
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from .util import is_np_array, np_array, use_np_array, use_np
from .util import is_np_default_dtype, np_default_dtype, use_np_default_dtype
from . import base

# version info
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def __len__(self):

def _thread_worker_initializer(active_shape, active_array):
"""Initializer for ThreadPool."""
set_np(shape=active_shape, array=active_array)
set_np(shape=active_shape, array=active_array, dtype=False)


_worker_dataset = None
Expand All @@ -418,7 +418,7 @@ def _worker_initializer(dataset, active_shape, active_array):
# can be passed as argument
global _worker_dataset
_worker_dataset = dataset
set_np(shape=active_shape, array=active_array)
set_np(shape=active_shape, array=active_array, dtype=False)

def _worker_fn(samples, batchify_fn, dataset=None):
"""Function for processing data in worker process."""
Expand Down
Loading

0 comments on commit 5c525c9

Please sign in to comment.