diff --git a/dmlc-core b/dmlc-core
index 7d3c78428819..75f1950d386d 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 7d3c78428819dc84c4da8ae1f302ba6c6a235a5d
+Subproject commit 75f1950d386d033b0b64919017515d27e698962a
diff --git a/doc/user-guide/executor.md b/doc/user-guide/executor.md
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/doc/user-guide/symbol.md b/doc/user-guide/symbol.md
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4e22abfcb6cd..d43e0576fab3 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -127,12 +127,47 @@ MXNET_DLL int MXNArrayListLoad(const char* fname,
                                mx_uint *out_name_size,
                                const char*** out_names);
 /*!
- * \brief wait until all the operation with respect NArray
- * to this NArray is finished, always call this before fetching data out
+ * \brief Perform a synchronous copy from a contiguous CPU memory region.
+ *
+ * This function will call WaitToWrite before the copy is performed.
+ * This is useful for copying data from an existing memory region that is
+ * not wrapped by an NArray (so the dependency is not tracked).
+ *
+ * \param handle the NArray handle
+ * \param data the data source to copy from.
+ * \param size the memory size we want to copy from.
+ */
+MXNET_DLL int MXNArraySyncCopyFromCPU(NArrayHandle handle,
+                                      const mx_float *data,
+                                      size_t size);
+/*!
+ * \brief Perform a synchronous copy to a contiguous CPU memory region.
+ *
+ * This function will call WaitToRead before the copy is performed.
+ * This is useful for copying data out to an existing memory region that is
+ * not wrapped by an NArray (so the dependency is not tracked).
+ *
+ * \param handle the NArray handle
+ * \param data the destination to copy into.
+ * \param size the memory size we want to copy into.
+ */
+MXNET_DLL int MXNArraySyncCopyToCPU(NArrayHandle handle,
+                                    mx_float *data,
+                                    size_t size);
+/*!
+ * \brief Wait until all the pending writes with respect to this NArray are finished.
+ *        Always call this before reading data out synchronously.
+ * \param handle the NArray handle
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNArrayWaitToRead(NArrayHandle handle);
+/*!
+ * \brief Wait until all the pending reads/writes with respect to this NArray are finished.
+ *        Always call this before writing data into the NArray synchronously.
  * \param handle the NArray handle
  * \return 0 when success, -1 when failure happens
  */
-MXNET_DLL int MXNArrayWait(NArrayHandle handle);
+MXNET_DLL int MXNArrayWaitToWrite(NArrayHandle handle);
 /*!
  * \brief wait until all delayed operations in
  * the system is completed
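Reviewer note, not part of the patch: a minimal ctypes-level sketch of how the four new C entry points fit together, mirroring what the Python wrapper in python/mxnet/narray.py further down does. It assumes the Python package built from this revision is importable as `mxnet`.

```python
# Sketch only: calling the new C API directly through ctypes.
import ctypes
import numpy as np
import mxnet as mx
from mxnet.base import _LIB, mx_float, check_call

arr = mx.narray.empty((2, 3))            # uninitialized NArray (empty() is added below)
src = np.ones((2, 3), dtype=np.float32)

# Waits for pending reads/writes on `arr`, then copies `src` in synchronously.
check_call(_LIB.MXNArraySyncCopyFromCPU(
    arr.handle,
    src.ctypes.data_as(ctypes.POINTER(mx_float)),
    src.size))

# Blocks until pending writes finish, so the data can be read out safely.
check_call(_LIB.MXNArrayWaitToRead(arr.handle))
```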
diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h
index c119c8137132..ca10dad85441 100644
--- a/include/mxnet/dag_engine.h
+++ b/include/mxnet/dag_engine.h
@@ -129,11 +129,29 @@ class DAGEngine {
    */
   virtual void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) = 0;
   /*!
-   * \brief Wait for variable.
-   * \param var The variable we should wait for, this function returns when all
-   *     the operations related to var has been completed.
+   * \brief Wait to read a variable.
+   *
+   * The caller should read the content immediately in a synchronized way,
+   * before any subsequent write operations are issued.
+   * The subsequent write operations to the variable can destroy the content.
+   *
+   * \param var The variable we should wait for.
+   *     This function returns when all the write operations on this
+   *     var have been completed.
+   */
+  virtual void WaitToRead(Variable var) = 0;
+  /*!
+   * \brief Wait to write a variable.
+   *
+   * The caller should write the content immediately in a synchronized way,
+   * before any subsequent read/write operations are issued.
+   * The subsequent operations on the variable may overwrite the content.
+   *
+   * \param var The variable we should wait for.
+   *     This function returns when all the read/write operations
+   *     on var have been completed.
    */
-  virtual void WaitForVar(Variable var) = 0;
+  virtual void WaitToWrite(Variable var) = 0;
   /*!
    * \brief Wait until all the activity of dag engine finishes.
    */
diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h
index 06da4d841944..26e796d55d39 100644
--- a/include/mxnet/narray.h
+++ b/include/mxnet/narray.h
@@ -73,10 +73,21 @@ class NArray {
   inline bool is_none() const {
     return ptr_.get() == nullptr;
   }
-  /*! \brief wait until the result of the NArray is computed */
-  inline void Wait() const {
+  /*!
+   * \brief Block until all the pending write operations with respect
+   *   to the current NArray are finished, and a read can be performed.
+   */
+  inline void WaitToRead() const {
     if (is_none()) return;
-    DAGEngine::Get()->WaitForVar(ptr_->var);
+    DAGEngine::Get()->WaitToRead(ptr_->var);
+  }
+  /*!
+   * \brief Block until all the pending read/write operations with respect
+   *   to the current NArray are finished, and a write can be performed.
+   */
+  inline void WaitToWrite() const {
+    if (is_none()) return;
+    DAGEngine::Get()->WaitToWrite(ptr_->var);
   }
   /*! \return the associated DAG variable of the narray.*/
   inline DAGEngine::Variable var() const {
@@ -166,6 +177,28 @@ class NArray {
    * \return the new copy
    */
   NArray Copy(Context ctx) const;
+  /*!
+   * \brief Do a synchronous copy from a contiguous CPU memory region.
+   *
+   * This function will call WaitToWrite before the copy is performed.
+   * This is useful for copying data from an existing memory region that is
+   * not wrapped by an NArray (so the dependency is not tracked).
+   *
+   * \param data the data source to copy from.
+   * \param size the memory size we want to copy from.
+   */
+  void SyncCopyFromCPU(const real_t *data, size_t size) const;
+  /*!
+   * \brief Do a synchronous copy to a contiguous CPU memory region.
+   *
+   * This function will call WaitToRead before the copy is performed.
+   * This is useful for copying data out to an existing memory region that is
+   * not wrapped by an NArray (so the dependency is not tracked).
+   *
+   * \param data the destination to copy into.
+   * \param size the memory size we want to copy into.
+   */
+  void SyncCopyToCPU(real_t *data, size_t size) const;
   /*!
    * \brief Slice a NArray
    * \param begin begin index in first dim
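Not part of the change: since the old blanket Wait() is split into read and write variants, callers have to pick the right one. A hedged sketch at the Python level, using the wrappers added further down, of when each applies:

```python
# Sketch of the read/write wait semantics exposed by this patch.
import numpy as np
import mxnet as mx

a = mx.narray.array(np.ones((100, 100), dtype=np.float32))
b = a * 2                  # scheduled asynchronously by the engine

b.wait_to_read()           # returns once pending *writes* to b are done
print(b.asnumpy()[0, 0])   # safe to read synchronously (asnumpy also waits)

b.wait_to_write()          # additionally waits for pending *reads* of b; this is
                           # what SyncCopyFromCPU relies on before it overwrites
                           # the buffer in place
b[:] = np.zeros((100, 100), dtype=np.float32)
```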
diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py
index 76fdc5f893d0..6706bd4fecc1 100644
--- a/python/mxnet/narray.py
+++ b/python/mxnet/narray.py
@@ -5,10 +5,11 @@
 import ctypes
 import warnings
 import sys
+import numpy as np
 from .base import _LIB, string_types, numeric_types
 from .base import c_array, py_str, c_str
 from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle
-from .base import ctypes2numpy_shared, ctypes2buffer
+from .base import ctypes2buffer
 from .base import check_call
 from .context import Context
@@ -183,7 +184,9 @@ def __setitem__(self, in_slice, value):
             if value.handle is not self.handle:
                 value.copyto(self)
         elif isinstance(value, numeric_types):
-            return NArray._set_value(float(value), out=self)
+            NArray._set_value(float(value), out=self)
+        elif isinstance(value, (np.ndarray, np.generic)):
+            self._sync_copyfrom(value)
         else:
             raise TypeError('type %s not supported' % str(type(value)))
 
@@ -193,9 +196,47 @@ def __getitem__(self, in_slice):
             raise Exception("Set NArray should use empty index array[:] += value")
         return self
 
-    def wait(self):
-        """Wait until the data on current NArray is available."""
-        check_call(_LIB.MXNArrayWait(self.handle))
+    def _sync_copyfrom(self, source_array):
+        """Perform a synchronous copy from the array.
+
+        Parameters
+        ----------
+        source_array : array_like
+            The data source we would like to copy from.
+        """
+        if not isinstance(source_array, np.ndarray):
+            try:
+                source_array = np.array(source_array, dtype=np.float32)
+            except:
+                raise TypeError('array must be an array_like data, ' +
+                                'type %s is not supported' % str(type(source_array)))
+        source_array = np.ascontiguousarray(source_array, dtype=np.float32)
+
+        if source_array.shape != self.shape:
+            raise ValueError('array shape does not match the shape of NArray')
+
+        check_call(_LIB.MXNArraySyncCopyFromCPU(
+            self.handle,
+            source_array.ctypes.data_as(ctypes.POINTER(mx_float)),
+            source_array.size))
+
+    def wait_to_read(self):
+        """Block until all pending write operations on the current NArray are finished.
+
+        This function returns when all the pending writes to the current
+        NArray have finished. There can still be pending reads going on when
+        the function returns.
+        """
+        check_call(_LIB.MXNArrayWaitToRead(self.handle))
+
+    def wait_to_write(self):
+        """Block until all pending read/write operations on the current NArray are finished.
+
+        This function returns when all the pending reads and writes on the
+        current NArray have finished, so the caller can safely write to the
+        underlying memory in a synchronized way.
+        """
+        check_call(_LIB.MXNArrayWaitToWrite(self.handle))
 
     @property
     def shape(self):
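Not part of the patch: a short usage sketch of the new `__setitem__` path. Assigning a NumPy array (or anything `np.array` can convert) into an existing NArray now goes through `_sync_copyfrom`, which coerces the data to contiguous float32, checks the shape, and performs the blocking copy.

```python
import numpy as np
import mxnet as mx

arr = mx.narray.empty((3, 4))
arr[:] = np.random.uniform(-1, 1, (3, 4))    # routed through _sync_copyfrom
arr[:] = 0.5                                 # numeric scalars still use _set_value

try:
    arr[:] = np.zeros((4, 3), dtype=np.float32)
except ValueError as err:                    # shape mismatches are rejected
    print(err)
```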
""" - self.wait() - pdata = ctypes.POINTER(mx_float)() - check_call(_LIB.MXNArrayGetData(self.handle, ctypes.byref(pdata))) - return ctypes2numpy_shared(pdata, self.shape) + data = np.empty(self.shape, dtype=np.float32) + check_call(_LIB.MXNArraySyncCopyToCPU( + self.handle, + data.ctypes.data, + data.size)) + return data def copyto(self, other): """Copy the content of current array to other. @@ -271,8 +313,8 @@ def copyto(self, other): # pylint: enable= no-member -def create(shape, ctx=None): - """Create a new NArray, with specified shape. +def empty(shape, ctx=None): + """Create an empty uninitialized new NArray, with specified shape. Parameters ---------- @@ -292,6 +334,33 @@ def create(shape, ctx=None): return NArray(handle=_new_alloc_handle(shape, ctx, False)) +def array(source_array, ctx=None): + """Create a new NArray that copies content from source_array. + + Parameters + ---------- + source_array : array_like + Source data to create NArray from. + + ctx : Context, optional + The context of the NArray, default to current default context. + + Returns + ------- + out: Array + The created NArray. + """ + + if not isinstance(source_array, np.ndarray): + try: + source_array = np.array(source_array, dtype=np.float32) + except: + raise TypeError('source_array must be array like object') + arr = empty(source_array.shape, ctx) + arr[:] = source_array + return arr + + def load(fname): """Load narray from binary file. diff --git a/src/c_api.cc b/src/c_api.cc index 48d83cffc688..2e59613829bc 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -235,9 +235,31 @@ int MXNArraySaveRawBytes(NArrayHandle handle, API_END(); } -int MXNArrayWait(NArrayHandle handle) { +int MXNArraySyncCopyFromCPU(NArrayHandle handle, + const mx_float *data, + size_t size) { API_BEGIN(); - static_cast(handle)->Wait(); + static_cast(handle)->SyncCopyFromCPU(data, size); + API_END(); +} + +int MXNArraySyncCopyToCPU(NArrayHandle handle, + mx_float *data, + size_t size) { + API_BEGIN(); + static_cast(handle)->SyncCopyToCPU(data, size); + API_END(); +} + +int MXNArrayWaitToRead(NArrayHandle handle) { + API_BEGIN(); + static_cast(handle)->WaitToRead(); + API_END(); +} + +int MXNArrayWaitToWrite(NArrayHandle handle) { + API_BEGIN(); + static_cast(handle)->WaitToWrite(); API_END(); } diff --git a/src/dag_engine/naive_engine.cc b/src/dag_engine/naive_engine.cc index 6ad780c9615c..bffeb474bfa6 100644 --- a/src/dag_engine/naive_engine.cc +++ b/src/dag_engine/naive_engine.cc @@ -66,7 +66,10 @@ class NaiveEngine : public DAGEngine { this->Push(delete_fun, exec_ctx, {}, {var}); } - void WaitForVar(Variable var) override { + void WaitToRead(Variable var) override { + } + + void WaitToWrite(Variable var) override { } void WaitForAll() override { diff --git a/src/narray/narray.cc b/src/narray/narray.cc index eee59ed8ecd1..c9dda3f5a654 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -313,10 +313,10 @@ void NArray::Save(dmlc::Stream *strm) const { NArray temp; if (ctx.dev_mask != cpu::kDevMask) { temp = this->Copy(Context(cpu::kDevMask, 0)); - temp.Wait(); + temp.WaitToRead(); save_data = temp.data(); } else { - this->Wait(); + this->WaitToRead(); save_data = this->data(); } // save type flag @@ -365,6 +365,58 @@ NArray NArray::Copy(Context ctx) const { return ret; } +void NArray::SyncCopyFromCPU(const real_t *data, size_t size) const { + this->WaitToWrite(); + TShape dshape = this->shape(); + CHECK_EQ(dshape.Size(), size) + << "Memory size do not match"; + Context ctx = this->ctx(); + TBlob dst = this->data(); + TBlob 
diff --git a/src/c_api.cc b/src/c_api.cc
index 48d83cffc688..2e59613829bc 100644
--- a/src/c_api.cc
+++ b/src/c_api.cc
@@ -235,9 +235,31 @@ int MXNArraySaveRawBytes(NArrayHandle handle,
   API_END();
 }
 
-int MXNArrayWait(NArrayHandle handle) {
+int MXNArraySyncCopyFromCPU(NArrayHandle handle,
+                            const mx_float *data,
+                            size_t size) {
   API_BEGIN();
-  static_cast<NArray*>(handle)->Wait();
+  static_cast<NArray*>(handle)->SyncCopyFromCPU(data, size);
+  API_END();
+}
+
+int MXNArraySyncCopyToCPU(NArrayHandle handle,
+                          mx_float *data,
+                          size_t size) {
+  API_BEGIN();
+  static_cast<NArray*>(handle)->SyncCopyToCPU(data, size);
+  API_END();
+}
+
+int MXNArrayWaitToRead(NArrayHandle handle) {
+  API_BEGIN();
+  static_cast<NArray*>(handle)->WaitToRead();
+  API_END();
+}
+
+int MXNArrayWaitToWrite(NArrayHandle handle) {
+  API_BEGIN();
+  static_cast<NArray*>(handle)->WaitToWrite();
   API_END();
 }
diff --git a/src/dag_engine/naive_engine.cc b/src/dag_engine/naive_engine.cc
index 6ad780c9615c..bffeb474bfa6 100644
--- a/src/dag_engine/naive_engine.cc
+++ b/src/dag_engine/naive_engine.cc
@@ -66,7 +66,10 @@ class NaiveEngine : public DAGEngine {
     this->Push(delete_fun, exec_ctx, {}, {var});
   }
 
-  void WaitForVar(Variable var) override {
+  void WaitToRead(Variable var) override {
+  }
+
+  void WaitToWrite(Variable var) override {
   }
 
   void WaitForAll() override {
diff --git a/src/narray/narray.cc b/src/narray/narray.cc
index eee59ed8ecd1..c9dda3f5a654 100644
--- a/src/narray/narray.cc
+++ b/src/narray/narray.cc
@@ -313,10 +313,10 @@ void NArray::Save(dmlc::Stream *strm) const {
   NArray temp;
   if (ctx.dev_mask != cpu::kDevMask) {
     temp = this->Copy(Context(cpu::kDevMask, 0));
-    temp.Wait();
+    temp.WaitToRead();
     save_data = temp.data();
   } else {
-    this->Wait();
+    this->WaitToRead();
     save_data = this->data();
   }
   // save type flag
@@ -365,6 +365,58 @@ NArray NArray::Copy(Context ctx) const {
   return ret;
 }
 
+void NArray::SyncCopyFromCPU(const real_t *data, size_t size) const {
+  this->WaitToWrite();
+  TShape dshape = this->shape();
+  CHECK_EQ(dshape.Size(), size)
+      << "Memory size does not match";
+  Context ctx = this->ctx();
+  TBlob dst = this->data();
+  TBlob src((real_t*)data, dshape, cpu::kDevMask);  // NOLINT(*)
+
+  RunContext run_ctx;
+  if (ctx.dev_mask == cpu::kDevMask) {
+    narray::Copy<cpu, cpu>(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx);
+  } else {
+#if MXNET_USE_CUDA
+    // use empty stream to do sync copy
+    // TODO(bing, yutian) consider use a Real Stream, so it is not blocking others
+    // Maybe move to engine part
+    mshadow::Stream<gpu> zero_stream;
+    run_ctx.stream = &zero_stream;
+    narray::Copy<cpu, gpu>(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx);
+#else
+    LOG(FATAL) << "GPU is not enabled";
+#endif
+  }
+}
+
+void NArray::SyncCopyToCPU(real_t *data, size_t size) const {
+  this->WaitToRead();
+  TShape dshape = this->shape();
+  CHECK_EQ(dshape.Size(), size)
+      << "Memory size does not match";
+  Context ctx = this->ctx();
+  TBlob src = this->data();
+  TBlob dst(data, dshape, cpu::kDevMask);  // NOLINT(*)
+
+  RunContext run_ctx;
+  if (ctx.dev_mask == cpu::kDevMask) {
+    narray::Copy<cpu, cpu>(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx);
+  } else {
+#if MXNET_USE_CUDA
+    // use empty stream to do sync copy
+    // TODO(bing, yutian) consider use a Real Stream, so it is not blocking others
+    // Maybe move to engine part
+    mshadow::Stream<gpu> zero_stream;
+    run_ctx.stream = &zero_stream;
+    narray::Copy<gpu, cpu>(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx);
+#else
+    LOG(FATAL) << "GPU is not enabled";
+#endif
+  }
+}
+
 // register API function
 // those with underscore will be registered at NArray
 MXNET_REGISTER_NARRAY_FUN(_set_value).set_function(SetValueOp);
diff --git a/tests/python/test_bind.py b/tests/python/test_bind.py
index 934b66688934..8802eb87d3c2 100644
--- a/tests/python/test_bind.py
+++ b/tests/python/test_bind.py
@@ -16,12 +16,11 @@ def check_bind_with_uniform(uf, gf, dim):
     rhs = mx.symbol.Variable('rhs')
     ret = uf(lhs, rhs)
     assert ret.list_arguments() == ['lhs', 'rhs']
-    lhs_arr = mx.narray.create(shape)
-    rhs_arr = mx.narray.create(shape)
-    lhs_grad = mx.narray.create(shape)
-    rhs_grad = mx.narray.create(shape)
-    lhs_arr.numpy[:] = np.random.uniform(-10, 10, shape)
-    rhs_arr.numpy[:] = np.random.uniform(-10, 10, shape)
+    lhs_arr = mx.narray.array(np.random.uniform(-10, 10, shape))
+    rhs_arr = mx.narray.array(np.random.uniform(-10, 10, shape))
+    lhs_grad = mx.narray.empty(shape)
+    rhs_grad = mx.narray.empty(shape)
+
     executor = ret.bind(mx.Context('cpu'),
                         args=[lhs_arr, rhs_arr],
@@ -41,22 +40,21 @@ def check_bind_with_uniform(uf, gf, dim):
     executor.forward()
     exec3.forward()
     exec4.forward()
-    out2 = executor.heads()[0].numpy
-    out1 = uf(lhs_arr.numpy, rhs_arr.numpy)
-    out3 = exec3.heads()[0].numpy
-    out4 = exec4.heads()[0].numpy
+    out2 = executor.heads()[0].asnumpy()
+    out1 = uf(lhs_arr.asnumpy(), rhs_arr.asnumpy())
+    out3 = exec3.heads()[0].asnumpy()
+    out4 = exec4.heads()[0].asnumpy()
     assert reldiff(out1, out2) < 1e-6
     assert reldiff(out1, out3) < 1e-6
    assert reldiff(out1, out4) < 1e-6
     # test gradient
-    out_grad = mx.narray.create(shape)
-    out_grad.numpy[:] = np.ones(shape)
-    lhs_grad2, rhs_grad2 = gf(out_grad.numpy,
-                              lhs_arr.numpy,
-                              rhs_arr.numpy)
+    out_grad = mx.narray.array(np.ones(shape))
+    lhs_grad2, rhs_grad2 = gf(out_grad.asnumpy(),
+                              lhs_arr.asnumpy(),
+                              rhs_arr.asnumpy())
     executor.backward([out_grad])
-    assert reldiff(lhs_grad.numpy, lhs_grad2) < 1e-6
-    assert reldiff(rhs_grad.numpy, rhs_grad2) < 1e-6
+    assert reldiff(lhs_grad.asnumpy(), lhs_grad2) < 1e-6
+    assert reldiff(rhs_grad.asnumpy(), rhs_grad2) < 1e-6
 
 
 def test_bind():
@@ -79,3 +77,5 @@ def test_bind():
                                     dim)
 
 
+if __name__ == "__main__":
+    test_bind()
\ No newline at end of file
diff
--git a/tests/python/test_conv.py b/tests/python/test_conv.py index 49a7f998d04e..9ab34ce1c8ae 100644 --- a/tests/python/test_conv.py +++ b/tests/python/test_conv.py @@ -32,22 +32,22 @@ def CalAcc(out, label): data_shape = (batch_size, 1, 28, 28) arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape) -arg_narrays = [mx.narray.create(shape) for shape in arg_shapes] -grad_narrays = [mx.narray.create(shape) for shape in arg_shapes] -aux_narrays = [mx.narray.create(shape) for shape in aux_shapes] +arg_narrays = [mx.narray.empty(shape) for shape in arg_shapes] +grad_narrays = [mx.narray.empty(shape) for shape in arg_shapes] +aux_narrays = [mx.narray.empty(shape) for shape in aux_shapes] inputs = dict(zip(args_list, arg_narrays)) np.random.seed(0) # set random weight for name, narray in inputs.items(): if "weight" in name: - narray.numpy[:] = np.random.uniform(-0.07, 0.07, narray.numpy.shape) + narray[:] = np.random.uniform(-0.07, 0.07, narray.shape) if "bias" in name: - narray.numpy[:] = 0.0 + narray[:] = 0.0 if "gamma" in name: - narray.numpy[:] = 1.0 + narray[:] = 1.0 if "beta" in name: - narray.numpy[:] = 0.0 + narray[:] = 0.0 # bind executer # TODO(bing): think of a better bind interface @@ -55,7 +55,7 @@ def CalAcc(out, label): # update out_narray = executor.heads()[0] -grad_narray = mx.narray.create(out_narray.shape) +grad_narray = mx.narray.empty(out_narray.shape) epoch = 1 momentum = 0.9 @@ -90,14 +90,13 @@ def test_mnist(): train_nbatch = 0 val_nbatch = 0 for data, label in train_dataiter: - data = data.numpy - label = label.numpy.flatten() - inputs["data"].numpy[:] = data - inputs["sm_label"].numpy[:] = label + label = label.asnumpy().flatten() + inputs["data"][:] = data + inputs["sm_label"][:] = label executor.forward(is_train = True) - train_acc += CalAcc(out_narray.numpy, label) + train_acc += CalAcc(out_narray.asnumpy(), label) train_nbatch += 1 - grad_narray.numpy[:] = out_narray.numpy + grad_narray[:] = out_narray executor.backward([grad_narray]) for grad, weight in block: @@ -105,11 +104,10 @@ def test_mnist(): # evaluate for data, label in val_dataiter: - data = data.numpy - label = label.numpy.flatten() - inputs["data"].numpy[:] = data + label = label.asnumpy().flatten() + inputs["data"][:] = data executor.forward(is_train = False) - val_acc += CalAcc(out_narray.numpy, label) + val_acc += CalAcc(out_narray.asnumpy(), label) val_nbatch += 1 print("Train Acc: ", train_acc / train_nbatch) print("Valid Acc: ", val_acc / val_nbatch) diff --git a/tests/python/test_io.py b/tests/python/test_io.py index 1156782bdfef..54b538f13eba 100644 --- a/tests/python/test_io.py +++ b/tests/python/test_io.py @@ -30,14 +30,14 @@ def test_MNISTIter(): # test_reset train_dataiter.reset() train_dataiter.iter_next() - label_0 = train_dataiter.getlabel().numpy.flatten() + label_0 = train_dataiter.getlabel().asnumpy().flatten() train_dataiter.iter_next() train_dataiter.iter_next() train_dataiter.iter_next() train_dataiter.iter_next() train_dataiter.reset() train_dataiter.iter_next() - label_1 = train_dataiter.getlabel().numpy.flatten() + label_1 = train_dataiter.getlabel().asnumpy().flatten() assert(sum(label_0 - label_1) == 0) ''' @@ -102,3 +102,6 @@ def test_Cifar10Rec(): for i in range(10): assert(labelcount[i] == 1000) ''' + +if __name__ == "__main__": + test_MNISTIter() \ No newline at end of file diff --git a/tests/python/test_mlp.py b/tests/python/test_mlp.py index b1575bbad30a..85abbb9ac216 100644 --- a/tests/python/test_mlp.py +++ b/tests/python/test_mlp.py @@ -23,16 +23,16 
@@ def CalAcc(out, label): # infer shape data_shape = (batch_size, 784) arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape) -arg_narrays = [mx.narray.create(shape) for shape in arg_shapes] -grad_narrays = [mx.narray.create(shape) for shape in arg_shapes] +arg_narrays = [mx.narray.empty(shape) for shape in arg_shapes] +grad_narrays = [mx.narray.empty(shape) for shape in arg_shapes] inputs = dict(zip(args_list, arg_narrays)) np.random.seed(0) # set random weight for name, narray in inputs.items(): if "weight" in name: - narray.numpy[:, :] = np.random.uniform(-0.07, 0.07, narray.numpy.shape) + narray[:] = np.random.uniform(-0.07, 0.07, narray.shape) if "bias" in name: - narray.numpy[:] = 0.0 + narray[:] = 0.0 # bind executer # TODO(bing): think of a better bind interface @@ -40,7 +40,7 @@ def CalAcc(out, label): # update out_narray = executor.heads()[0] -grad_narray = mx.narray.create(out_narray.shape) +grad_narray = mx.narray.empty(out_narray.shape) epoch = 9 lr = 0.1 @@ -74,14 +74,13 @@ def test_mlp(): train_nbatch = 0 val_nbatch = 0 for data, label in train_dataiter: - data = data.numpy - label = label.numpy.flatten() - inputs["data"].numpy[:] = data - inputs["sm_label"].numpy[:] = label + label = label.asnumpy().flatten() + inputs["data"][:] = data + inputs["sm_label"][:] = label executor.forward() - train_acc += CalAcc(out_narray.numpy, label) + train_acc += CalAcc(out_narray.asnumpy(), label) train_nbatch += 1 - grad_narray.numpy[:] = out_narray.numpy + grad_narray[:] = out_narray executor.backward([grad_narray]) for grad, weight in block: @@ -89,11 +88,10 @@ def test_mlp(): # evaluate for data, label in val_dataiter: - data = data.numpy - label = label.numpy.flatten() - inputs["data"].numpy[:] = data + label = label.asnumpy().flatten() + inputs["data"][:] = data executor.forward() - val_acc += CalAcc(out_narray.numpy, label) + val_acc += CalAcc(out_narray.asnumpy(), label) val_nbatch += 1 acc_train = train_acc / train_nbatch acc_val = val_acc / val_nbatch diff --git a/tests/python/test_narray.py b/tests/python/test_narray.py index b6325112ba99..fd01abca9457 100644 --- a/tests/python/test_narray.py +++ b/tests/python/test_narray.py @@ -18,21 +18,19 @@ def check_with_uniform(uf, arg_shapes, dim=None): narray_arg = [] numpy_arg = [] for s in arg_shapes: - narr = mx.narray.create(s) npy = np.random.uniform(-10, 10, s) - narr.numpy[:] = npy + narr = mx.narray.array(npy) narray_arg.append(narr) numpy_arg.append(npy) out1 = uf(*narray_arg) out2 = uf(*numpy_arg) assert out1.shape == out2.shape - assert reldiff(out1.numpy, out2) < 1e-6 + assert reldiff(out1.asnumpy(), out2) < 1e-6 def random_narray(dim): shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) - data = mx.narray.create(shape) - data.numpy[:] = np.random.uniform(-10, 10, data.shape) + data= mx.narray.array(np.random.uniform(-10, 10, shape)) return data def test_narray_elementwise(): @@ -47,25 +45,24 @@ def test_narray_elementwise(): check_with_uniform(lambda x, y: x / y, 2, dim) def test_narray_copy(): - c = mx.narray.create((10,10)) - c.numpy[:] = np.random.uniform(-10, 10, c.shape) + c = mx.narray.array(np.random.uniform(-10, 10, (10, 10))) d = c.copyto(mx.Context('cpu', 0)) - assert np.sum(np.abs(c.numpy != d.numpy)) == 0.0 + assert np.sum(np.abs(c.asnumpy() != d.asnumpy())) == 0.0 def test_narray_scalar(): - c = mx.narray.create((10,10)) - d = mx.narray.create((10,10)) - c.numpy[:] = 0.5 - d.numpy[:] = 1.0 + c = mx.narray.empty((10,10)) + d = mx.narray.empty((10,10)) + c[:] = 0.5 + d[:] = 1.0 
d -= c * 2 / 3 * 6.0 c += 0.5 - assert(np.sum(c.numpy) - 100 < 1e-5) - assert(np.sum(d.numpy) + 100 < 1e-5) + assert(np.sum(c.asnumpy()) - 100 < 1e-5) + assert(np.sum(d.asnumpy()) + 100 < 1e-5) c[:] = 2 - assert(np.sum(c.numpy) - 200 < 1e-5) + assert(np.sum(c.asnumpy()) - 200 < 1e-5) d = -c + 2 - assert(np.sum(c.numpy) < 1e-5) + assert(np.sum(c.asnumpy()) < 1e-5) def test_narray_pickle(): np.random.seed(0) @@ -74,13 +71,13 @@ def test_narray_pickle(): for repeat in range(nrepeat): for dim in range(1, maxdim): a = random_narray(dim) - b = mx.narray.create(a.shape) - a.numpy[:] = np.random.uniform(-10, 10, a.shape) - b.numpy[:] = np.random.uniform(-10, 10, a.shape) + b = mx.narray.empty(a.shape) + a[:] = np.random.uniform(-10, 10, a.shape) + b[:] = np.random.uniform(-10, 10, a.shape) a = a + b data = pkl.dumps(a) a2 = pkl.loads(data) - assert np.sum(a.numpy != a2.numpy) == 0 + assert np.sum(a.asnumpy() != a2.asnumpy()) == 0 def test_narray_saveload(): @@ -96,14 +93,14 @@ def test_narray_saveload(): data2 = mx.narray.load(fname) assert len(data) == len(data2) for x, y in zip(data, data2): - assert np.sum(x.numpy != y.numpy) == 0 + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 dmap = {'narray xx %s' % i : x for i, x in enumerate(data)} mx.narray.save(fname, dmap) dmap2 = mx.narray.load(fname) assert len(dmap2) == len(dmap) for k, x in dmap.items(): y = dmap2[k] - assert np.sum(x.numpy != y.numpy) == 0 + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 os.remove(fname) if __name__ == '__main__': @@ -112,3 +109,4 @@ def test_narray_saveload(): test_narray_copy() test_narray_elementwise() test_narray_scalar() + diff --git a/tests/python/test_operator.py b/tests/python/test_operator.py index 988610d34811..5b3540392dde 100644 --- a/tests/python/test_operator.py +++ b/tests/python/test_operator.py @@ -20,24 +20,24 @@ def check_elementwise_sum_with_shape(shape, n): # forward inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] out = mx.symbol.ElementWiseSum(*inputs, name='esum') - arr = [mx.narray.create(shape) for i in range(n)] - arr_grad = [mx.narray.create(shape) for i in range(n)] + arr = [mx.narray.empty(shape) for i in range(n)] + arr_grad = [mx.narray.empty(shape) for i in range(n)] for i in range(n): - arr[i].numpy[:] = np.random.uniform(-10, 10, shape) + arr[i][:] = np.random.uniform(-10, 10, shape) exec1 = out.bind(mx.Context('cpu'), args=arr, args_grad=arr_grad) - out1 = exec1.heads()[0].numpy + out1 = exec1.heads()[0].asnumpy() exec1.forward() - out1 = exec1.heads()[0].numpy - out = sum(a.numpy for a in arr) + out1 = exec1.heads()[0].asnumpy() + out = sum(a.asnumpy() for a in arr) assert reldiff(out, out1) < 1e-6 - out_grad = mx.narray.create(shape) - out_grad.numpy[:] = np.random.uniform(-10, 10, shape) + out_grad = mx.narray.empty(shape) + out_grad[:] = np.random.uniform(-10, 10, shape) # backward exec1.backward([out_grad]) for a in arr_grad: - assert same(a.numpy, out_grad.numpy) + assert same(a.asnumpy(), out_grad.asnumpy()) def test_elementwise_sum(): @@ -58,27 +58,27 @@ def check_concat_with_shape(shapes): inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] out = mx.symbol.Concat(*inputs, name='conc') - arr = [mx.narray.create(shape) for shape in shapes] + arr = [mx.narray.empty(shape) for shape in shapes] for i in range(n): arr[i][:] = shapes[i][1] - arr_np = [np.copy(narray.numpy) for narray in arr] - arr_grad = [mx.narray.create(shape) for shape in shapes] + arr_np = [np.copy(narray.asnumpy()) for narray in arr] + arr_grad = [mx.narray.empty(shape) for shape in 
shapes] args = out.list_arguments() arg_shapes, out_shapes, aux_shapes = out.infer_shape(**dict(zip(args, shapes))) - out_grad = mx.narray.create(out_shapes[0]) + out_grad = mx.narray.empty(out_shapes[0]) exec1 = out.bind(mx.Context('cpu'), args=arr, args_grad=arr_grad) exec1.forward() out1 = exec1.heads()[0] - ret = np.concatenate([narray.numpy for narray in arr], axis=1) - assert same(out1.numpy, ret) + ret = np.concatenate([narray.asnumpy() for narray in arr], axis=1) + assert same(out1.asnumpy(), ret) # backward out1.copyto(out_grad) out_grad[:] += 1 exec1.backward([out_grad]) for grad, np_grad in zip(arr_grad, arr_np): - assert same(grad.numpy, np_grad + 1) + assert same(grad.asnumpy(), np_grad + 1) def test_concat(): n = 2 diff --git a/tests/test_simple_engine.cc b/tests/test_simple_engine.cc index c65e847bf1ca..453a13e11d4b 100644 --- a/tests/test_simple_engine.cc +++ b/tests/test_simple_engine.cc @@ -61,7 +61,7 @@ int main() { printf("============= Test #3 ==============\n"); var = engine->NewVar(); oprs.clear(); - engine->WaitForVar(var); + engine->WaitToWrite(var); engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); engine->WaitForAll(); @@ -77,7 +77,7 @@ int main() { {}, {var})); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "Operator pushed, should wait for 2 seconds."; - engine->WaitForVar(var); + engine->WaitToWrite(var); LOG(INFO) << "OK, here I am."; for (auto&& i : oprs) { engine->DeleteOperator(i); @@ -97,7 +97,7 @@ int main() { {var}, {})); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "Operator pushed, should not wait."; - engine->WaitForVar(var); + engine->WaitToWrite(var); LOG(INFO) << "OK, here I am."; engine->WaitForAll(); LOG(INFO) << "That was 2 seconds.";
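Closing note, not part of the change: the Python test updates above all follow the same idiom, so a condensed sketch may help when porting other scripts. `reldiff` below is a local stand-in for the helper the tests import from their own utilities.

```python
import numpy as np
import mxnet as mx

def reldiff(a, b):
    # local stand-in for the tests' helper: relative L1 difference
    diff = np.sum(np.abs(a - b))
    norm = np.sum(np.abs(a))
    return diff / norm if norm > 0 else diff

shape = (4, 5)
arrs = [mx.narray.empty(shape) for _ in range(3)]
for a in arrs:
    a[:] = np.random.uniform(-10, 10, shape)         # direct numpy assignment
expected = sum(a.asnumpy() for a in arrs)             # asnumpy() returns copies
total = arrs[0] + arrs[1] + arrs[2]                   # computed asynchronously
assert reldiff(total.asnumpy(), expected) < 1e-6      # asnumpy() waits, then copies
```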