Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
merge with 38f7c55
Browse files Browse the repository at this point in the history
compiles on GPU

update check alloc:

Checkpoint. Pass elem-sum gpu test

bug fix for copyfromto. sparse sgd test pass on gpu

inefficient implementation for csr copy
  • Loading branch information
eric-haibin-lin committed May 15, 2017
1 parent cde5361 commit d6e6d98
Show file tree
Hide file tree
Showing 62 changed files with 4,469 additions and 268 deletions.
6 changes: 3 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,17 @@ del /Q *.7z
// Python unittest for CPU
def python_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train"
}
}

// GPU test has two parts. 1) run unittest on GPU, 2) compare the results on
// both CPU and GPU
def python_gpu_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu"
}
}
Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ ifeq ($(DEV), 1)
endif

# CFLAGS for debug
# FIXME(haibin) temporarily turn on -DDMLC_LOG_FATAL_THROW for debug
ifeq ($(DEBUG), 1)
CFLAGS += -g -O0
CFLAGS += -g -O0 -DDMLC_LOG_FATAL_THROW=1
else
CFLAGS += -O3
endif
Expand Down
78 changes: 78 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
int delay_alloc,
int dtype,
NDArrayHandle *out);


/*!
* \brief create an empty sparse NDArray with specified shape and data type
* \param storage_type the storage type of the ndarray
* \param shape the pointer to the shape
* \param ndim the dimension of the shape
* \param dev_type device type, specify device we want to take
* \param dev_id the device id of the specific device
* \param delay_alloc whether to delay allocation until
* the narray is first mutated
* \param dtype data type of created array
* \param num_aux the number of aux data to support this ndarray
* \param aux_type data type of the aux data for the created array
* \param aux_ndims the dimension of the shapes of aux data
* \param aux_shape the shapes of aux data
* \param out the returning handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
* \param buf the head of the raw bytes
Expand Down Expand Up @@ -363,6 +395,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
NDArrayHandle *out);

/*!
* \brief get the storage type of the array
*/
MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
int *out_storage_type);

/*!
* \brief Reshape the NDArray.
* \param handle the handle to the narray
Expand Down Expand Up @@ -401,6 +440,26 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
int *out_dtype);

/*!
* \brief get the type of the ith aux data in NDArray
* \param handle the handle to the narray
* \param i the index of the aux data
* \param out_type pointer holder to get type of aux data
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
mx_uint i,
int *out_type);

// Get the ith aux data blob wrapped in an NDArray
MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
mx_uint i,
NDArrayHandle *out);

// Get the data blob wrapped in an NDArray
MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle,
NDArrayHandle *out);
/*!
* \brief get the context of the NDArray
* \param handle the handle to the narray
Expand Down Expand Up @@ -932,6 +991,25 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete);




/*!
* \brief infer storage type of unknown input types given the known one.
*/
MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_storage_type_data,
mx_uint *in_storage_type_size,
const int **in_storage_type_data,
mx_uint *out_storage_type_size,
const int **out_storage_type_data,
mx_uint *aux_storage_type_size,
const int **aux_storage_type_data,
int *complete);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
Expand Down
Loading

0 comments on commit d6e6d98

Please sign in to comment.