diff --git a/benchmark/python/sparse/cast_storage.py b/benchmark/python/sparse/cast_storage.py
new file mode 100644
index 000000000000..7ae537398c42
--- /dev/null
+++ b/benchmark/python/sparse/cast_storage.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+
+from mxnet.test_utils import *
+import time
+import argparse
+
+from mxnet.base import check_call, _LIB
+
+parser = argparse.ArgumentParser(description="Benchmark cast storage operators",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
+args = parser.parse_args()
+
+def measure_cost(repeat, f, *args, **kwargs):
+ start = time.time()
+    for _ in range(repeat):
+        f(*args, **kwargs).wait_to_read()
+ end = time.time()
+ diff = end - start
+ return diff / repeat
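+
+# A usage sketch (hypothetical shapes): average the latency of 10 casts of a
+# dense 512 x 50000 matrix to CSR:
+#   dns = mx.nd.zeros((512, 50000))
+#   avg_sec = measure_cost(10, mx.nd.cast_storage, dns, 'csr')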
+
+
+def run_cast_storage_synthetic():
+ def dense_to_sparse(m, n, density, ctx, repeat, stype):
+ set_default_context(ctx)
+ data_shape = (m, n)
+ dns_data = rand_ndarray(data_shape, stype, density).tostype('default')
+ dns_data.wait_to_read()
+
+ # do one warm up run, verify correctness
+ assert same(mx.nd.cast_storage(dns_data, stype).asnumpy(), dns_data.asnumpy())
+
+ # start benchmarking
+ cost = measure_cost(repeat, mx.nd.cast_storage, dns_data, stype)
+ results = '{:10.1f} {:>10} {:8d} {:8d} {:10.2f}'.format(density*100, str(ctx), m, n, cost*1000)
+ print(results)
+
+ check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
+
+ # params
+ # m number of rows
+ # n number of columns
+ # density density of the matrix
+ # num_repeat number of benchmark runs to average over
+ # contexts mx.cpu(), mx.gpu()
+ # note: benchmark different contexts separately; to benchmark cpu, compile without CUDA
+ # benchmarks dns_to_csr, dns_to_rsp
+ m = [ 512, 512]
+ n = [50000, 100000]
+ density = [1.00, 0.80, 0.60, 0.40, 0.20, 0.10, 0.05, 0.02, 0.01]
+ num_repeat = 10
+ contexts = [mx.gpu()]
+ benchmarks = ["dns_to_csr", "dns_to_rsp"]
+
+ # run benchmark
+ for b in benchmarks:
+ stype = ''
+ print("==================================================")
+        if b == "dns_to_csr":
+            stype = 'csr'
+            print(" cast_storage benchmark: dense to csr, size m x n ")
+        elif b == "dns_to_rsp":
+            stype = 'row_sparse'
+            print(" cast_storage benchmark: dense to rsp, size m x n ")
+        else:
+            print("invalid benchmark: %s" % b)
+ continue
+ print("==================================================")
+ headline = '{:>10} {:>10} {:>8} {:>8} {:>10}'.format('density(%)', 'context', 'm', 'n', 'time(ms)')
+ print(headline)
+ for i in range(len(n)):
+ for ctx in contexts:
+ for den in density:
+ dense_to_sparse(m[i], n[i], den, ctx, num_repeat, stype)
+ print("")
+ print("")
+
+
+if __name__ == "__main__":
+ run_cast_storage_synthetic()
diff --git a/benchmark/python/sparse/dot.py b/benchmark/python/sparse/dot.py
new file mode 100644
index 000000000000..fe322821a09f
--- /dev/null
+++ b/benchmark/python/sparse/dot.py
@@ -0,0 +1,445 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+
+import os
+import time
+import argparse
+import subprocess
+import scipy.sparse as sp
+
+import mxnet as mx
+import numpy as np
+import numpy.random as rnd
+from mxnet.test_utils import rand_ndarray, set_default_context, assert_almost_equal
+from mxnet.base import check_call, _LIB
+from util import get_data, estimate_density
+
+PARSER = argparse.ArgumentParser(description="Benchmark sparse operators",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+PARSER.add_argument('--num-omp-threads', type=int,
+ default=1, help='number of omp threads to set in MXNet')
+PARSER.add_argument('--gpu', action='store_true',
+ help="to be run on gpu")
+# TODO: Use logging later
+PARSER.add_argument('--verbose', action='store_true',
+ help="Verbose output")
+ARGS = PARSER.parse_args()
+
+# some data information
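+# 'default_index' picks, for each parameter list, the entry that is held
+# fixed while the benchmark sweeps over the other parameters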
+KDDA = {
+ 'data_mini': 'kdda.t.mini',
+ 'data_name': 'kdda.t',
+ 'data_origin_name': 'kdda.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
+ 'feature_dim': 20216830,
+ 'm': [1, 8, 32],
+ 'batch_size': [64],
+ 'default_index': {'batch_size': 0,
+ 'output_dim': 2},
+ 'num_batches': 10
+}
+
+AVAZU = {
+ 'data_mini': 'avazu-app.t.mini',
+ 'data_name': 'avazu-app.t',
+ 'data_origin_name': 'avazu-app.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
+ 'feature_dim': 1000000,
+ 'm': [1, 1000, 2000],
+ 'batch_size': [128, 256],
+ 'default_index': {'batch_size': 0,
+ 'output_dim': 1},
+ 'num_batches': 10
+}
+
+CRITEO = {
+ 'data_mini': 'criteo.t.mini',
+ 'data_name': 'criteo.t',
+ 'data_origin_name': 'criteo.t.bz2',
+ 'url' : "https://s3-us-west-2.amazonaws.com/sparse-dataset/criteo.t.bz2",
+ 'feature_dim': 8388621,
+ 'm': [1, 8, 16, 32, 64],
+ 'batch_size': [64, 128],
+ 'default_index': {'batch_size': 1,
+ 'output_dim': 3},
+ 'num_batches': 10
+}
+
+SYNTHETIC1 = {
+ 'feature_dim': [1000000],
+ 'm': [256, 1000],
+ 'density': [0.001, 0.005, 0.01, 0.02, 0.05,
+ 0.1, 0.2, 0.5, 0.65],
+ 'batch_size': [64, 128],
+ 'default_index': {'batch_size': 1,
+ 'density': 2,
+ 'output_dim': 1,
+ 'feature_dim': 0},
+ 'num_repeat': 10
+}
+
+SYNTHETIC2 = {
+ 'feature_dim': [8000000, 16000000],
+ 'm': [1, 32],
+ 'density': [0.001, 0.005, 0.01, 0.02, 0.05,
+ 0.1, 0.2, 0.5, 0.65],
+ 'batch_size': [64, 128],
+ 'default_index': {'batch_size': 1,
+ 'density': 2,
+ 'output_dim': 1,
+ 'feature_dim': 0},
+ 'num_repeat': 10
+}
+
+def measure_cost(repeat, scipy_trans_lhs, scipy_dns_lhs, func_name, *args, **kwargs):
+    """Measure the average time cost of running `func_name` over `repeat` runs.
+    If `scipy_trans_lhs` is set, the first argument is transposed inside the
+    timed region (with numpy if `scipy_dns_lhs` is set, with scipy otherwise),
+    mirroring the transpose_a option of mx.nd.dot.
+    """
+    mx.nd.waitall()
+    args_list = list(args)
+ start = time.time()
+ if scipy_trans_lhs:
+ args_list[0] = np.transpose(args_list[0]) if scipy_dns_lhs else sp.spmatrix.transpose(args_list[0])
+ for _ in range(repeat):
+ func_name(*args_list, **kwargs)
+ mx.nd.waitall()
+ end = time.time()
+ diff = end - start
+ return diff / repeat
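+
+# A usage sketch (hypothetical operands): time dot(csr.T, dns) averaged over
+# 10 runs, with the transpose handled by the operator itself:
+#   cost = measure_cost(10, False, False, mx.nd.dot, csr, dns, transpose_a=True)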
+
+
+def _get_iter(path, data_shape, batch_size):
+ data_train = mx.io.LibSVMIter(data_libsvm=path,
+ data_shape=data_shape,
+ batch_size=batch_size)
+ data_iter = iter(data_train)
+ return data_iter
+
+
+def _line_count(path):
+ return int(subprocess.check_output('wc -l {}'.format(path), shell=True).split()[0])
+
+
+def _compare_sparse_dense(data_dir, file_name, mini_file_name, feature_dim,
+ output_dim, density, batch_size, num_batches=3, num_repeat=5, transpose=False,
+ rsp=False):
+
+ def create_mini_path(mini_path, path, num_batches):
+ """Samples batches of size: batch_size, total number: num_batches
+ from the dataset files for running benchmarks"""
+ if not os.path.exists(mini_path):
+ last = _line_count(path) - num_batches * batch_size
+ last = last if last >= 1 else 1
+ start = int(rnd.uniform(1, last))
+ os.system("sed -n '%d,%dp' %r > %r"
+ %(start, start + num_batches * batch_size, path, mini_path))
+ assert os.path.exists(mini_path)
+
+
+ def run_benchmark(mini_path):
+ """Run benchmarks
+ """
+ data_shape = (feature_dim, )
+ train_iter = _get_iter(mini_path, data_shape, batch_size)
+ weight_row_dim = batch_size if transpose else feature_dim
+ weight_shape = (weight_row_dim, output_dim)
+ if not rsp:
+ weight = mx.nd.random_uniform(low=0, high=1, shape=weight_shape)
+ else:
+ weight = rand_ndarray(weight_shape, "row_sparse", density=0.05, distribution="uniform")
+ total_cost = {}
+ average_cost = {}
+ count = 0
+ total_cost["sparse"] = 0.
+ total_cost["dense"] = 0.
+ for _ in train_iter:
+ csr_data = train_iter.getdata()
+ dns_data = csr_data.tostype('default')
+ cost_sparse = measure_cost(num_repeat, False, False, mx.nd.dot, csr_data, weight, transpose_a=transpose)
+ cost_dense = measure_cost(num_repeat, False, False, mx.nd.dot, dns_data, weight, transpose_a=transpose)
+ total_cost["sparse"] += cost_sparse
+ total_cost["dense"] += cost_dense
+ count = count + 1
+ average_cost["sparse"] = total_cost["sparse"] / count
+ average_cost["dense"] = total_cost["dense"] / count
+ return (average_cost["sparse"], average_cost["dense"])
+
+
+ def print_result(average_cost_sparse, average_cost_dense):
+ """Print result of comparison between sparse and dense
+ """
+ ratio = average_cost_dense / average_cost_sparse
+ fmt = '{:15.4f} {:10d} {:10d} {:10d} {:20.2f} {:15.2f} {:15.2f} {:10} {:10}'
+ print(fmt.format(density * 100, batch_size, output_dim, feature_dim,
+ ratio, average_cost_dense*1000, average_cost_sparse*1000,
+ transpose, rsp))
+
+ mini_path = os.path.join(data_dir, mini_file_name)
+ path = os.path.join(data_dir, file_name)
+ create_mini_path(mini_path, path, num_batches)
+ average_cost_sparse, average_cost_dense = run_benchmark(mini_path)
+ print_result(average_cost_sparse, average_cost_dense)
+
+
+def test_dot_real(data_dict):
+ """Dot operator testing with real datasets"""
+ data_dir = os.path.join(os.getcwd(), 'data')
+
+ path = os.path.join(data_dir, data_dict['data_name'])
+ if not os.path.exists(path):
+ get_data(
+ data_dir,
+ data_dict['data_name'],
+ data_dict['url'],
+ data_dict['data_origin_name']
+ )
+ assert os.path.exists(path)
+
+ k = data_dict['feature_dim']
+ m = data_dict['m']
+ batch_size_list = data_dict['batch_size']
+
+ default_output_index = data_dict['default_index']['output_dim']
+ default_batch_size_index = data_dict['default_index']['batch_size']
+ density = estimate_density(path, data_dict['feature_dim'])
+ num_batches = data_dict['num_batches']
+
+ assert default_batch_size_index < len(batch_size_list)
+ assert default_output_index < len(m)
+ if ARGS.verbose:
+ print("Running Benchmarking on %r data") % data_dict['data_mini']
+ print('{:>15} {:>10} {:>10} {:>10} {:>20} {:>15} {:>15} {:>10} {:>10}'.format('density(%)',
+ 'n',
+ 'm',
+ 'k',
+ 't_dense/t_sparse',
+ 't_dense(ms)',
+ 't_sparse(ms)',
+ 'is_transpose',
+ 'rhs_rsp'))
+
+
+ for output_dim in m:
+ _compare_sparse_dense(data_dir, data_dict['data_name'], data_dict['data_mini'],
+ k, output_dim, density,
+ batch_size_list[default_batch_size_index], num_batches)
+ _compare_sparse_dense(data_dir, data_dict['data_name'], data_dict['data_mini'],
+ k, output_dim, density,
+ batch_size_list[default_batch_size_index], num_batches,
+ transpose=True)
+ _compare_sparse_dense(data_dir, data_dict['data_name'], data_dict['data_mini'],
+ k, output_dim, density,
+ batch_size_list[default_batch_size_index], num_batches, rsp=True)
+
+ for batch_size in batch_size_list:
+ _compare_sparse_dense(data_dir, data_dict['data_name'], data_dict['data_mini'],
+ k, m[default_output_index], density, batch_size, num_batches)
+ _compare_sparse_dense(data_dir, data_dict['data_name'], data_dict['data_mini'],
+ k, m[default_output_index], density, batch_size, num_batches,
+ transpose=True)
+        _compare_sparse_dense(data_dir, data_dict['data_name'], data_dict['data_mini'],
+                              k, m[default_output_index], density, batch_size, num_batches,
+                              rsp=True)
+
+
+def test_dot_synthetic(data_dict):
+ """benchmark sparse mxnet dot and scipy dot operator with matrices of given density.
+ `t_sparse` is the runtime of the invoked sparse dot operator in ms, while `t_dense` is the
+ runtime of dot(dns, dns), with the same matrices except that they are in default storage type.
+ """
+ # Benchmark MXNet and Scipys dot operator
+ def bench_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype,
+ lhs_den, rhs_den, trans_lhs, ctx, num_repeat=10, fw="mxnet", distribution="uniform"):
+ set_default_context(ctx)
+ assert fw == "mxnet" or fw == "scipy"
+ # Set funcs
+ dot_func_sparse = mx.nd.dot if fw == "mxnet" else sp.spmatrix.dot
+ dot_func_dense = mx.nd.dot if fw == "mxnet" else np.dot
+ # Create matrix instances
+ lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den, distribution=distribution)
+ # only uniform distribution supported for rhs
+ rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
+ lhs_dns = None
+ rhs_dns = None
+ dense_cost = None
+ sparse_cost = None
+
+ if fw == "mxnet":
+ lhs_dns = lhs_nd if lhs_stype == 'default' else lhs_nd.tostype('default')
+ rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.tostype('default')
+ # One warm up run, verify correctness
+ out = dot_func_sparse(lhs_nd, rhs_dns, trans_lhs)
+ out_expected = dot_func_dense(lhs_dns, rhs_dns, trans_lhs)
+ assert_almost_equal(out.asnumpy(), out_expected.asnumpy(), rtol=1e-1, atol=1e-1)
+ sparse_cost = measure_cost(num_repeat, False, False, dot_func_sparse, lhs_nd, rhs_nd, trans_lhs)
+ dense_cost = measure_cost(num_repeat, False, False, dot_func_dense, lhs_dns, rhs_dns, trans_lhs)
+ else:
+ lhs_dns = lhs_nd.asnumpy()
+ rhs_dns = rhs_nd.asnumpy()
+ lhs_nd = sp.csr_matrix(lhs_nd.asnumpy())
+ rhs_nd = rhs_nd.asnumpy()
+            # One warm up run (no correctness check for the scipy baseline)
+ lhs_nd_copy = sp.spmatrix.transpose(lhs_nd) if trans_lhs else lhs_nd
+ out = dot_func_sparse(lhs_nd_copy, rhs_dns)
+ sparse_cost = measure_cost(num_repeat, trans_lhs, False, dot_func_sparse, lhs_nd, rhs_nd)
+ dense_cost = measure_cost(num_repeat, trans_lhs, True, dot_func_dense, lhs_dns, rhs_dns)
+
+ speedup = dense_cost / sparse_cost
+ # Print results
+ m = lhs_shape[0]
+ k = lhs_shape[1]
+ n = rhs_shape[1]
+ result_pattern = '{:15.1f} {:15.1f} {:>10} {:8d} {:8d} {:8d} {:13.2f} {:13.2f} {:8.2f}'
+ results = result_pattern.format(lhs_den*100,
+ rhs_den*100,
+ str(ctx),
+ m,
+ k,
+ n,
+ sparse_cost*1000,
+ dense_cost*1000,
+ speedup)
+ print(results)
+
+ def print_benchmark_info(lhs, rhs, lhs_trans, fw):
+ trans_str = "^T" if lhs_trans else ""
+ print("========================================================")
+ print(" %s sparse dot benchmark: dot(%s, %s) = %s ") % (fw, lhs, rhs, rhs)
+ print(" (matrix multiplication: (m x k)%s * (k x n) = m x n) ") % (trans_str)
+ print("========================================================")
+ headline_pattern = '{:>15} {:>15} {:>10} {:>8} {:>8} {:>8} {:>13} {:>13} {:>8}'
+ headline = headline_pattern.format('lhs_density(%)',
+ 'rhs_density(%)',
+ 'context',
+ 'm', 'k', 'n',
+ 't_sparse(ms)',
+ 't_dense(ms)',
+ 'speedup')
+ print(headline)
+
+
+ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", rhs_density=1,
+ distribution="uniform"):
+ if lhs != "csr":
+ raise ValueError("Value other than csr for lhs not supported")
+ if rhs_density > 1 or rhs_density < 0:
+ raise ValueError("rhs_density has to be between 0 and 1")
+
+ print_benchmark_info(lhs, rhs, lhs_trans, fw)
+
+
+ lhs_stype = "csr"
+ rhs_stype = "row_sparse" if rhs == "rsp" else "default"
+
+ feature_dim_list = data_dict['feature_dim']
+ output_dim_list = data_dict['m']
+ batch_size_list = data_dict['batch_size']
+ density_list = data_dict['density']
+
+ default_output_index = data_dict['default_index']['output_dim']
+ default_batch_size_index = data_dict['default_index']['batch_size']
+ default_feature_index = data_dict['default_index']['feature_dim']
+ default_density_index = data_dict['default_index']['density']
+ num_repeat = data_dict['num_repeat']
+
+ for output_dim in output_dim_list:
+ if lhs_trans:
+ output_row_dim = batch_size_list[default_batch_size_index]
+ else:
+ output_row_dim = feature_dim_list[default_feature_index]
+ bench_dot((batch_size_list[default_batch_size_index],
+ feature_dim_list[default_feature_index]),
+ (output_row_dim, output_dim),
+ lhs_stype, rhs_stype,
+ density_list[default_density_index], rhs_density,
+ lhs_trans, ctx, num_repeat=num_repeat,
+ fw=fw, distribution=distribution)
+
+ for feature_dim in feature_dim_list:
+ if lhs_trans:
+ output_row_dim = batch_size_list[default_batch_size_index]
+ else:
+ output_row_dim = feature_dim
+ bench_dot((batch_size_list[default_batch_size_index], feature_dim),
+ (output_row_dim, output_dim_list[default_output_index]),
+ lhs_stype, rhs_stype, density_list[default_density_index], rhs_density,
+ lhs_trans, ctx, num_repeat=num_repeat, fw=fw, distribution=distribution)
+
+ for batch_size in batch_size_list:
+ if lhs_trans:
+ output_row_dim = batch_size
+ else:
+ output_row_dim = feature_dim_list[default_feature_index]
+ bench_dot((batch_size, feature_dim_list[default_feature_index]),
+ (output_row_dim,
+ output_dim_list[default_output_index]),
+ lhs_stype, rhs_stype, density_list[default_density_index],
+ rhs_density, lhs_trans, ctx, num_repeat=num_repeat,
+ fw=fw, distribution=distribution)
+
+ for density in density_list:
+ if lhs_trans:
+ output_row_dim = batch_size_list[default_batch_size_index]
+ else:
+ output_row_dim = feature_dim_list[default_feature_index]
+ bench_dot((batch_size_list[default_batch_size_index],
+ feature_dim_list[default_feature_index]),
+ (output_row_dim,
+ output_dim_list[default_output_index]),
+ lhs_stype, rhs_stype, density, rhs_density, lhs_trans, ctx,
+ num_repeat=num_repeat, fw=fw, distribution=distribution)
+
+ check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(ARGS.num_omp_threads)))
+ context = mx.gpu() if ARGS.gpu else mx.cpu()
+ # TODO(anirudh): make the data dicts to config which can be passed at runtime
+ distributions = ["uniform", "powerlaw"]
+ for distribution in distributions:
+ run_benchmark(context, lhs="csr",
+ rhs="default", lhs_trans=False,
+ fw="mxnet", rhs_density=1,
+ distribution=distribution)
+ run_benchmark(context, lhs="csr",
+ rhs="default", lhs_trans=True,
+ fw="mxnet", rhs_density=1,
+ distribution=distribution)
+ run_benchmark(context, lhs="csr",
+ rhs="rsp", lhs_trans=False,
+ fw="mxnet", rhs_density=0.05,
+ distribution=distribution)
+ if not ARGS.gpu:
+ run_benchmark(context, lhs="csr",
+ rhs="default", lhs_trans=False,
+ fw="scipy", rhs_density=1,
+ distribution=distribution)
+ run_benchmark(context, lhs="csr",
+ rhs="default", lhs_trans=True,
+ fw="scipy", rhs_density=1,
+ distribution=distribution)
+
+
+if __name__ == "__main__":
+ begin_time = time.time()
+ test_dot_real(KDDA)
+ test_dot_real(AVAZU)
+ test_dot_real(CRITEO)
+ test_dot_synthetic(SYNTHETIC1)
+ test_dot_synthetic(SYNTHETIC2)
+ total_time = time.time() - begin_time
+ print("total time is %f") % total_time
diff --git a/benchmark/python/sparse/sparse_end2end.py b/benchmark/python/sparse/sparse_end2end.py
new file mode 100644
index 000000000000..e9d8bf884713
--- /dev/null
+++ b/benchmark/python/sparse/sparse_end2end.py
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from mxnet.test_utils import *
+import time
+import argparse
+import os
+
+parser = argparse.ArgumentParser(description="Run sparse linear classification " \
+ "with distributed kvstore",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--profiler', type=int, default=0,
+ help='whether to use profiler')
+parser.add_argument('--num-epoch', type=int, default=1,
+ help='number of epochs to train')
+parser.add_argument('--batch-size', type=int, default=512,
+ help='number of examples per batch')
+parser.add_argument('--num-batch', type=int, default=99999999,
+ help='number of batches per epoch')
+parser.add_argument('--dummy-iter', type=int, default=0,
+ help='whether to use dummy iterator to exclude io cost')
+parser.add_argument('--kvstore', type=str, default='local',
+ help='what kvstore to use [local, dist_sync, etc]')
+parser.add_argument('--log-level', type=str, default='debug',
+ help='logging level [debug, info, error]')
+parser.add_argument('--dataset', type=str, default='avazu',
+ help='what test dataset to use')
+parser.add_argument('--num-gpu', type=int, default=0,
+ help='number of gpus to use. 0 means using cpu(0);'
+ 'otherwise, use gpu(0),...,gpu(num_gpu-1)')
+parser.add_argument('--output-dim', type=int, default=4,
+ help='number of columns of the forward output')
+parser.add_argument('--dummy-metric', type=int, default=0,
+ help='whether to call update_metric')
+
+
+def get_libsvm_data(data_dir, data_name, url, data_origin_name):
+    if not os.path.isdir(data_dir):
+        os.mkdir(data_dir)
+    os.chdir(data_dir)
+    if not os.path.exists(data_name):
+        import urllib
+        zippath = os.path.join(data_dir, data_origin_name)
+        urllib.urlretrieve(url, zippath)
+        os.system("bzip2 -d %r" % data_origin_name)
+    os.chdir("..")
+
+
+class DummyIter(mx.io.DataIter):
+ "A dummy iterator that always return the same batch, used for speed testing"
+ def __init__(self, real_iter):
+ super(DummyIter, self).__init__()
+ self.real_iter = real_iter
+ self.provide_data = real_iter.provide_data
+ self.provide_label = real_iter.provide_label
+ self.batch_size = real_iter.batch_size
+
+ for batch in real_iter:
+ self.the_batch = batch
+ break
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ return self.the_batch
+
+# testing dataset sources
+avazu = {
+ 'data_name': 'avazu-app.t',
+ 'data_origin_name': 'avazu-app.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
+ 'feature_dim': 1000000,
+}
+
+kdda = {
+ 'data_name': 'kdda.t',
+ 'data_origin_name': 'kdda.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
+ 'feature_dim': 20216830,
+}
+
+datasets = { 'kdda' : kdda, 'avazu' : avazu }
+
+
+def get_sym(feature_dim):
+ x = mx.symbol.Variable("data", stype='csr')
+ norm_init = mx.initializer.Normal(sigma=0.01)
+ w = mx.symbol.Variable("w", shape=(feature_dim, args.output_dim), init=norm_init, stype='row_sparse')
+ embed = mx.symbol.dot(x, w)
+ y = mx.symbol.Variable("softmax_label")
+ model = mx.symbol.SoftmaxOutput(data=embed, label=y, name="out")
+ return model
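+
+# note: the weight above is row_sparse so that only the rows a batch actually
+# touches need to be pulled from the kvstore (see row_sparse_pull below)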
+
+
+def row_sparse_pull(kv, key, data, slices, weight_array, priority):
+    # with a kvstore, pull only the rows of the weight needed by this batch
+    # to each context: the column indices (NDArray type) of the csr data
+    # are used as the row_idx of the weight row-sparse matrix
+ row_indices = data.indices
+ if len(slices) == 1:
+ kv.row_sparse_pull(key, weight_array, priority=priority, row_ids=row_indices)
+    else:  # more than one slice, i.e. multi-GPU training; retain weight rows according to the data slices
+ # TODO(junwu):
+ # the following line blocks, may need to pre-compute
+ # and cache it outside the for loop
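+        # for a slice s, indptr[s.start]:indptr[s.stop] spans exactly the csr
+        # column indices of the rows in s, i.e. the weight rows that device needs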
+ indptr = data.indptr.asnumpy()
+ row_idx_array = []
+ for s in slices:
+ row_idx_array.append(row_indices[indptr[s.start]:indptr[s.stop]])
+ kv.row_sparse_pull(key, weight_array, priority=priority, row_ids=row_idx_array)
+
+
+if __name__ == '__main__':
+
+ # arg parser
+ args = parser.parse_args()
+ num_epoch = args.num_epoch
+ num_batch = args.num_batch
+ kvstore = args.kvstore
+ profiler = args.profiler > 0
+ batch_size = args.batch_size if args.num_gpu == 0 else args.num_gpu * args.batch_size
+ dummy_iter = args.dummy_iter
+ dataset = args.dataset
+ log_level = args.log_level
+ contexts = mx.context.cpu(0) if args.num_gpu < 1\
+ else [mx.context.gpu(i) for i in range(args.num_gpu)]
+
+ # create kvstore when there are gpus
+ kv = mx.kvstore.create(kvstore) if args.num_gpu >= 1 else None
+ rank = kv.rank if kv is not None else 0
+ num_worker = kv.num_workers if kv is not None else 1
+
+ # only print log for rank 0 worker
+ import logging
+ if rank != 0:
+ log_level = logging.ERROR
+    elif log_level.lower() == 'debug':
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+ head = '%(asctime)-15s %(message)s'
+ logging.basicConfig(level=log_level, format=head)
+
+ # dataset
+ assert(dataset in datasets), "unknown dataset " + dataset
+ metadata = datasets[dataset]
+ feature_dim = metadata['feature_dim']
+    logging.debug('preparing data ... ')
+ data_dir = os.path.join(os.getcwd(), 'data')
+ path = os.path.join(data_dir, metadata['data_name'])
+ if not os.path.exists(path):
+ get_libsvm_data(data_dir, metadata['data_name'], metadata['url'],
+ metadata['data_origin_name'])
+ assert os.path.exists(path)
+
+ # data iterator
+ train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,),
+ batch_size=batch_size, num_parts=num_worker,
+ part_index=rank)
+ if dummy_iter:
+ train_data = DummyIter(train_data)
+
+ # model
+ model = get_sym(feature_dim)
+
+ # module
+ mod = mx.mod.Module(symbol=model, data_names=['data'],
+ label_names=['softmax_label'], context=contexts)
+ mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(scale=.1))
+ sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0,
+ learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker)
+ mod.init_optimizer(optimizer=sgd, kvstore=kv)
+ # use accuracy as the metric
+ metric = mx.metric.create('acc')
+
+ index = mod._exec_group.param_names.index('w')
+ # weight_array bound to executors of the contexts
+ weight_array = mod._exec_group.param_arrays[index]
+
+ mx.nd.waitall() # sync point for initialization
+ # start profiler
+ if profiler:
+ device = 'cpu'
+ if args.num_gpu > 0:
+ device = 'gpu' + str(args.num_gpu)
+ name = 'profile_' + args.dataset + '_' + device + '_nworker' + str(num_worker)\
+ + '_batchsize' + str(args.batch_size) + '_outdim' + str(args.output_dim) + '.json'
+ mx.profiler.profiler_set_config(mode='all', filename=name)
+ mx.profiler.profiler_set_state('run')
+
+ logging.debug('start training ...')
+ start = time.time()
+ data_iter = iter(train_data)
+ for epoch in range(num_epoch):
+ nbatch = 0
+ end_of_batch = False
+ data_iter.reset()
+ metric.reset()
+ next_batch = next(data_iter)
+ if kv is not None:
+ row_sparse_pull(kv, 'w', next_batch.data[0], mod._exec_group.slices, weight_array, -index)
+ while not end_of_batch:
+ nbatch += 1
+ batch = next_batch
+
+ mod.forward_backward(batch)
+ # update parameters
+ mod.update()
+
+ try:
+ # pre fetch next batch
+ next_batch = next(data_iter)
+ if nbatch == num_batch:
+ raise StopIteration
+ if kv is not None:
+ row_sparse_pull(kv, 'w', next_batch.data[0], mod._exec_group.slices, weight_array, -index)
+ except StopIteration:
+ end_of_batch = True
+ # accumulate prediction accuracy
+ if args.dummy_metric == 0:
+ mod.update_metric(metric, batch.label)
+        else:  # use waitall instead of update_metric as the sync point
+ mx.nd.waitall() # sync point for the current minibatch
+ logging.info('epoch %d, %s' % (epoch, metric.get()))
+ if epoch == 0:
+ print "num_batches = ", nbatch
+ if profiler:
+ mx.profiler.profiler_set_state('stop')
+ end = time.time()
+ time_cost = end - start
+ logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost))
diff --git a/benchmark/python/sparse/sparse_op.py b/benchmark/python/sparse/sparse_op.py
new file mode 100644
index 000000000000..0683aa84eacb
--- /dev/null
+++ b/benchmark/python/sparse/sparse_op.py
@@ -0,0 +1,245 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+
+from mxnet.test_utils import *
+import scipy.sparse as sp
+import os
+import time
+import argparse
+
+from mxnet.base import check_call, _LIB
+from util import get_data, estimate_density
+
+parser = argparse.ArgumentParser(description="Benchmark sparse operators",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
+args = parser.parse_args()
+
+# some data information
+kdda = {
+ 'data_mini': 'kdda.t.mini',
+ 'data_name': 'kdda.t',
+ 'data_origin_name': 'kdda.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
+ 'feature_dim': 20216830,
+ 'm': 200,
+ 'batch_size': [64]
+}
+
+avazu = {
+ 'data_mini': 'avazu-app.t.mini',
+ 'data_name': 'avazu-app.t',
+ 'data_origin_name': 'avazu-app.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
+ 'feature_dim': 1000000,
+ 'm': 500,
+ 'batch_size': [64, 128]
+}
+
+
+def measure_cost(repeat, f, *args, **kwargs):
+    # queue all repeat runs on the async engine first, then wait on every result
+ start = time.time()
+ results = []
+ for i in range(repeat):
+ results.append(f(*args, **kwargs))
+ for result in results:
+ result.wait_to_read()
+ end = time.time()
+ diff = end - start
+ return diff / repeat
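+
+# A usage sketch: average 5 repeats of dot with a given batch and weight:
+#   cost = measure_cost(5, mx.nd.dot, d_batch, weight)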
+
+
+def test_dot_real(data_dict):
+ def get_iter(path, data_shape, batch_size):
+ data_train = mx.io.LibSVMIter(data_libsvm=path,
+ data_shape=data_shape,
+ batch_size=batch_size)
+ data_iter = iter(data_train)
+ return data_iter
+
+ data_dir = os.path.join(os.getcwd(), 'data')
+
+ path = os.path.join(data_dir, data_dict['data_name'])
+ if not os.path.exists(path):
+ get_data(
+ data_dir,
+ data_dict['data_name'],
+ data_dict['url'],
+ data_dict['data_origin_name']
+ )
+ assert os.path.exists(path)
+
+ k = data_dict['feature_dim']
+ m = data_dict['m']
+ density = estimate_density(path, data_dict['feature_dim'])
+
+ mini_path = os.path.join(data_dir, data_dict['data_mini'])
+ if not os.path.exists(mini_path):
+ os.system("head -n 2000 %r > %r" % (path, mini_path))
+ assert os.path.exists(mini_path)
+
+ print "Running Benchmarking on %r data" % data_dict['data_mini']
+ for batch_size in data_dict['batch_size']: # iterator through different batch size of choice
+ print "batch_size is %d" % batch_size
+ # model
+ data_shape = (k, )
+ train_iter = get_iter(mini_path, data_shape, batch_size)
+ weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m))
+
+ csr_data = []
+ dns_data = []
+ num_batch = 0
+ for batch in train_iter:
+ data = train_iter.getdata()
+ csr_data.append(data)
+ dns_data.append(data.tostype('default'))
+ num_batch += 1
+ bag_of_data = [csr_data, dns_data]
+ num_repeat = 5
+ costs = []
+ for d in bag_of_data:
+ weight.wait_to_read()
+ cost = 0.
+ count = 0
+ for d_batch in d:
+ d_batch.wait_to_read()
+ cost += measure_cost(num_repeat, mx.nd.dot, d_batch, weight)
+ count += 1
+ costs.append(cost/count)
+ t_sparse = costs[0]
+ t_dense = costs[1]
+ ratio = t_dense / t_sparse
+ print('density(%)\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse')
+ fmt = "%0.4f\t\t%d\t%d\t%d\t%0.2f\t\t\t%0.4f\t%0.6f"
+ print(fmt % (density * 100, batch_size, m, k, ratio, t_dense, t_sparse))
+
+
+def test_dot_synthetic():
+ """benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density.
+ `t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost
+ of dot(dns, dns), with the same matrix except that it is in default storage type.
+ """
+ def measure_cost_forward_baseline(repeat, dot, lhs, rhs):
+ start = time.time()
+ for i in range(repeat):
+ dot(lhs, rhs)
+ end = time.time()
+ diff = end - start
+ return diff / repeat
+
+ def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
+ start = time.time()
+ for i in range(repeat):
+ dot(transpose(lhs), rhs)
+ end = time.time()
+ diff = end - start
+ return diff / repeat
+
+ def bench_dot_forward(m, k, n, density, ctx, repeat):
+ set_default_context(ctx)
+ dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx)
+ data_shape = (m, k)
+ csr_data = rand_ndarray(data_shape, 'csr', density)
+ dns_data = csr_data.tostype('default')
+ rhs_dns_np = dns.asnumpy()
+ lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy
+        lhs_dns_np = lhs_csr_sp.toarray()  # dense numpy baseline; scipy matrices have no tostype
+
+ data = [dns_data, csr_data]
+ costs = []
+ for d in data:
+ dns.wait_to_read()
+ d.wait_to_read()
+ cost = measure_cost(repeat, mx.nd.dot, d, dns)
+ costs.append(cost)
+ ratio = costs[0] / costs[1]
+
+ costs_baseline = []
+ cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np)
+ costs_baseline.append(cost)
+ cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
+ costs_baseline.append(cost)
+ ratio_baseline = costs_baseline[0] / costs_baseline[1]
+ fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
+ print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
+ ratio_baseline, costs_baseline[0], costs_baseline[1]))
+
+ def bench_dot_backward(m, k, n, density, ctx, repeat):
+ set_default_context(ctx)
+ dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx)
+ data_shape = (m, k)
+ csr_data = rand_ndarray(data_shape, 'csr', density)
+ dns_data = csr_data.tostype('default')
+ rhs_dns_np = dns.asnumpy()
+ lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())
+        lhs_dns_np = lhs_csr_sp.toarray()  # dense numpy baseline; scipy matrices have no tostype
+
+ data = [dns_data, csr_data]
+ costs = []
+ for d in data:
+ dns.wait_to_read()
+ d.wait_to_read()
+ cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True)
+ costs.append(cost)
+ ratio = costs[0] / costs[1]
+
+ costs_baseline = []
+ cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np)
+ costs_baseline.append(cost)
+ cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np)
+ costs_baseline.append(cost)
+ ratio_baseline = costs_baseline[0] / costs_baseline[1]
+ fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
+ print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
+ ratio_baseline, costs_baseline[0], costs_baseline[1]))
+
+ print("A = sparse NDArray of shape(m, k)")
+ print("B = dense NDArray of shape(k, n)")
+ print("dot_forward\tdot(csr, dns)")
+ print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
+ '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')
+
+ check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
+ # TODO(haibin) make these runtime options
+ m = 512
+ k = [50000, 100000]
+ n = [64, 128]
+ density = [1.00, 0.90, 0.70, 0.50, 0.30, 0.20, 0.10, 0.07, 0.05, 0.02, 0.01, 0.005, 0.001]
+ num_repeat = 10
+ # contexts = [mx.cpu(), mx.gpu(0)]
+ contexts = [mx.cpu()]
+ for i in range(2):
+ for ctx in contexts:
+ for den in density:
+ bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat)
+
+ print("dot_backward\tdot(csr.T, dns)")
+ print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
+ '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')
+ for i in range(2):
+ for ctx in contexts:
+ for den in density:
+ bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat)
+
+
+if __name__ == "__main__":
+ test_dot_real(avazu)
+ test_dot_real(kdda)
+ test_dot_synthetic()
diff --git a/benchmark/python/sparse/util.py b/benchmark/python/sparse/util.py
new file mode 100644
index 000000000000..947ff4a65037
--- /dev/null
+++ b/benchmark/python/sparse/util.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import random
+
+
+def get_data(data_dir, data_name, url, data_origin_name):
+    if not os.path.isdir(data_dir):
+        os.mkdir(data_dir)
+    os.chdir(data_dir)
+    if not os.path.exists(data_name):
+        import urllib
+        zippath = os.path.join(data_dir, data_origin_name)
+        urllib.urlretrieve(url, zippath)
+        os.system("bzip2 -d %r" % data_origin_name)
+ os.chdir("..")
+
+
+def estimate_density(DATA_PATH, feature_size):
+    """Estimate the density of a libsvm dataset by scanning it 10 times,
+    sampling roughly 1% of the lines on each pass."""
+    if not os.path.exists(DATA_PATH):
+        raise Exception("Data is not there!")
+    density = []
+    P = 0.01
+    for _ in range(10):
+ num_non_zero = 0
+ num_sample = 0
+ with open(DATA_PATH) as f:
+ for line in f:
+                if random.random() < P:
+ num_non_zero += len(line.split(" ")) - 1
+ num_sample += 1
+ density.append(num_non_zero * 1.0 / (feature_size * num_sample))
+ return sum(density) / len(density)
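+
+# A usage sketch (hypothetical path): the returned value approximates
+# nnz / (num_lines * feature_size) over the sampled lines:
+#   density = estimate_density('data/avazu-app.t', feature_size=1000000)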
+
diff --git a/docs/api/python/ndarray.md b/docs/api/python/ndarray.md
index 5e9f7e1a1184..3f2cef24a73a 100644
--- a/docs/api/python/ndarray.md
+++ b/docs/api/python/ndarray.md
@@ -64,9 +64,21 @@ A detailed tutorial is available at
```
In the rest of this document, we first overview the methods provided by the
-`ndarray.NDArray` class, and then list other routines provided by the
-`ndarray` package.
+`ndarray.NDArray` class and its subclasses, and then list other routines
+provided by the `ndarray` package.
+
+The `ndarray` package provides several classes:
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ NDArray
+ sparse.CSRNDArray
+ sparse.RowSparseNDArray
+```
+
+We summarize the interface for each class in the following sections.
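+
+For instance, a dense `NDArray` converts to either sparse class and back via
+`tostype` (a minimal sketch using the conversion methods listed below):
+
+```python
+>>> import mxnet as mx
+>>> dns = mx.nd.array([[0, 1], [2, 0]])
+>>> csr = dns.tostype('csr')          # a sparse.CSRNDArray
+>>> rsp = dns.tostype('row_sparse')   # a sparse.RowSparseNDArray
+>>> csr.tostype('default')            # back to a dense NDArray
+```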
## The `NDArray` class
@@ -80,6 +92,7 @@ In the rest of this document, we first overview the methods provided by the
NDArray.size
NDArray.context
NDArray.dtype
+ NDArray.stype
```
### Array conversion
@@ -94,6 +107,7 @@ In the rest of this document, we first overview the methods provided by the
NDArray.asnumpy
NDArray.asscalar
NDArray.astype
+ NDArray.tostype
```
### Array change shape
@@ -171,6 +185,35 @@ In the rest of this document, we first overview the methods provided by the
NDArray.wait_to_read
```
+## The `sparse.RowSparseNDArray` class
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ sparse.RowSparseNDArray.copyto
+ sparse.RowSparseNDArray.tostype
+ sparse.RowSparseNDArray.__setitem__
+ sparse.RowSparseNDArray.__getitem__
+ sparse.RowSparseNDArray.data
+ sparse.RowSparseNDArray.indices
+```
+
+## The `sparse.CSRNDArray` class
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ sparse.CSRNDArray.copyto
+ sparse.CSRNDArray.tostype
+ sparse.CSRNDArray.__setitem__
+ sparse.CSRNDArray.__getitem__
+ sparse.CSRNDArray.data
+ sparse.CSRNDArray.indices
+ sparse.CSRNDArray.indptr
+```
+
## Array creation routines
```eval_rst
@@ -499,8 +542,24 @@ The `contrib.ndarray` module contains many useful experimental APIs for new feat
```eval_rst
+
+.. autoclass:: mxnet.ndarray.NDArray
+ :members:
+ :special-members:
+
+.. autoclass:: mxnet.ndarray.sparse.CSRNDArray
+ :members:
+ :special-members:
+
+.. autoclass:: mxnet.ndarray.sparse.RowSparseNDArray
+ :members:
+ :special-members:
+
.. automodule:: mxnet.ndarray
:members:
+ :imported-members:
+ :special-members:
+ :exclude-members: CachedOp, BaseSparseNDArray, NDArray, CSRNDArray, RowSparseNDArray
.. automodule:: mxnet.random
:members:
diff --git a/example/sparse/get_data.py b/example/sparse/get_data.py
new file mode 100644
index 000000000000..578cf2ce5226
--- /dev/null
+++ b/example/sparse/get_data.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+import os
+
+def get_libsvm_data(data_dir, data_name, url, data_origin_name):
+ if not os.path.isdir(data_dir):
+ os.mkdir(data_dir)
+ os.chdir(data_dir)
+    if not os.path.exists(data_name):
+ import urllib
+ zippath = os.path.join(data_dir, data_origin_name)
+ urllib.urlretrieve(url, zippath)
+ os.system("bzip2 -d %r" % data_origin_name)
+ os.chdir("..")
diff --git a/example/sparse/linear_classification.py b/example/sparse/linear_classification.py
new file mode 100644
index 000000000000..567568c6eb80
--- /dev/null
+++ b/example/sparse/linear_classification.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import *
+from get_data import get_libsvm_data
+import time
+import argparse
+import os
+
+parser = argparse.ArgumentParser(description="Run sparse linear classification " \
+ "with distributed kvstore",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--profiler', type=int, default=0,
+ help='whether to use profiler')
+parser.add_argument('--num-epoch', type=int, default=1,
+ help='number of epochs to train')
+parser.add_argument('--batch-size', type=int, default=8192,
+ help='number of examples per batch')
+parser.add_argument('--num-batch', type=int, default=99999999,
+ help='number of batches per epoch')
+parser.add_argument('--dummy-iter', type=int, default=0,
+ help='whether to use dummy iterator to exclude io cost')
+parser.add_argument('--kvstore', type=str, default='dist_sync',
+ help='what kvstore to use [local, dist_sync, etc]')
+parser.add_argument('--log-level', type=str, default='DEBUG',
+ help='logging level [debug, info, error]')
+parser.add_argument('--dataset', type=str, default='avazu',
+ help='what test dataset to use')
+
+class DummyIter(mx.io.DataIter):
+ "A dummy iterator that always return the same batch, used for speed testing"
+ def __init__(self, real_iter):
+ super(DummyIter, self).__init__()
+ self.real_iter = real_iter
+ self.provide_data = real_iter.provide_data
+ self.provide_label = real_iter.provide_label
+ self.batch_size = real_iter.batch_size
+
+ for batch in real_iter:
+ self.the_batch = batch
+ break
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ return self.the_batch
+
+# testing dataset sources
+avazu = {
+ 'data_name': 'avazu-app.t',
+ 'data_origin_name': 'avazu-app.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
+ 'feature_dim': 1000000,
+}
+
+kdda = {
+ 'data_name': 'kdda.t',
+ 'data_origin_name': 'kdda.t.bz2',
+ 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
+ 'feature_dim': 20216830,
+}
+
+datasets = { 'kdda' : kdda, 'avazu' : avazu }
+
+def linear_model(feature_dim):
+ x = mx.symbol.Variable("data", stype='csr')
+ norm_init = mx.initializer.Normal(sigma=0.01)
+ weight = mx.symbol.Variable("weight", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
+ bias = mx.symbol.Variable("bias", shape=(1,), init=norm_init)
+ dot = mx.symbol.dot(x, weight)
+ pred = mx.symbol.broadcast_add(dot, bias)
+ y = mx.symbol.Variable("softmax_label")
+ model = mx.symbol.SoftmaxOutput(data=pred, label=y, name="out")
+ return model
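+
+# note: the weight is row_sparse so that only the rows indexed by a batch's
+# features need to be pulled from the kvstore inside the training loop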
+
+if __name__ == '__main__':
+ # arg parser
+ args = parser.parse_args()
+ num_epoch = args.num_epoch
+ num_batch = args.num_batch
+ kvstore = args.kvstore
+ profiler = args.profiler > 0
+ batch_size = args.batch_size
+ dummy_iter = args.dummy_iter
+ dataset = args.dataset
+ log_level = args.log_level
+
+ # create kvstore
+ kv = mx.kvstore.create(kvstore)
+ rank = kv.rank
+ num_worker = kv.num_workers
+
+ # only print log for rank 0 worker
+ import logging
+ if rank != 0:
+ log_level = logging.ERROR
+ elif log_level == 'DEBUG':
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+ head = '%(asctime)-15s %(message)s'
+ logging.basicConfig(level=log_level, format=head)
+
+ # dataset
+ assert(dataset in datasets), "unknown dataset " + dataset
+ metadata = datasets[dataset]
+ feature_dim = metadata['feature_dim']
+    logging.debug('preparing data ... ')
+ data_dir = os.path.join(os.getcwd(), 'data')
+ path = os.path.join(data_dir, metadata['data_name'])
+ if not os.path.exists(path):
+ get_libsvm_data(data_dir, metadata['data_name'], metadata['url'],
+ metadata['data_origin_name'])
+ assert os.path.exists(path)
+
+ # data iterator
+ train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,),
+ batch_size=batch_size, num_parts=num_worker,
+ part_index=rank)
+ if dummy_iter:
+ train_data = DummyIter(train_data)
+
+ # model
+ model = linear_model(feature_dim)
+
+ # module
+ mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label'])
+ mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(scale=.1))
+ sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0,
+ learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker)
+ mod.init_optimizer(optimizer=sgd, kvstore=kv)
+ # use accuracy as the metric
+ metric = mx.metric.create('Accuracy')
+
+ # start profiler
+ if profiler:
+ name = 'profile_output_' + str(num_worker) + '.json'
+ mx.profiler.profiler_set_config(mode='all', filename=name)
+ mx.profiler.profiler_set_state('run')
+
+ logging.debug('start training ...')
+ start = time.time()
+ data_iter = iter(train_data)
+ for epoch in range(num_epoch):
+ nbatch = 0
+ data_iter.reset()
+ metric.reset()
+ for batch in data_iter:
+ nbatch += 1
+ row_ids = batch.data[0].indices
+ # pull sparse weight
+ index = mod._exec_group.param_names.index('weight')
+ kv.row_sparse_pull('weight', mod._exec_group.param_arrays[index],
+ priority=-index, row_ids=[row_ids])
+ mod.forward_backward(batch)
+ # update parameters
+ mod.update()
+ # accumulate prediction accuracy
+ mod.update_metric(metric, batch.label)
+ if nbatch == num_batch:
+ break
+ logging.info('epoch %d, %s' % (epoch, metric.get()))
+ if profiler:
+ mx.profiler.profiler_set_state('stop')
+ end = time.time()
+ time_cost = end - start
+ logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost))
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 2289354e8a5e..a43f73fe45ab 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -276,6 +276,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
int delay_alloc,
int dtype,
NDArrayHandle *out);
+
+
+/*!
+ * \brief create an empty sparse NDArray with specified shape and data type
+ * \param storage_type the storage type of the ndarray
+ * \param shape the pointer to the shape
+ * \param ndim the dimension of the shape
+ * \param dev_type device type, specify device we want to take
+ * \param dev_id the device id of the specific device
+ * \param delay_alloc whether to delay allocation until
+ *        the ndarray is first mutated
+ * \param dtype data type of created array
+ * \param num_aux the number of aux data to support this ndarray
+ * \param aux_type data type of the aux data for the created array
+ * \param aux_ndims the dimension of the shapes of aux data
+ * \param aux_shape the shapes of aux data
+ * \param out the returning handle
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
+ const mx_uint *shape,
+ mx_uint ndim,
+ int dev_type,
+ int dev_id,
+ int delay_alloc,
+ int dtype,
+ mx_uint num_aux,
+ int *aux_type,
+ mx_uint *aux_ndims,
+ const mx_uint *aux_shape,
+ NDArrayHandle *out);
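+/*
+ * A usage sketch (hypothetical values): a CSR ndarray of shape (3, 4) has two
+ * aux arrays, indptr (length rows + 1) and indices (length nnz), so a caller
+ * passes num_aux = 2, one aux_type and aux_ndims entry per aux array, and the
+ * two aux shapes laid out back to back in aux_shape.
+ */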
+
/*!
* \brief create a NDArray handle that is loaded from raw bytes.
* \param buf the head of the raw bytes
@@ -350,6 +382,17 @@ MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
void *data,
size_t size);
+/*!
+ * \brief Copy src.data() to dst.data() if i == -1, else to dst.aux_data(i) if i >= 0
+ * This function blocks. Do not use it in performance critical code.
+ * \param handle_dst handle of a dst ndarray whose data/aux_data has been allocated
+ * \param handle_src handle of a src ndarray which has default storage type
+ * \param i dst data blob indicator
+ */
+MXNET_DLL int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst,
+ const NDArrayHandle handle_src,
+ const int i);
+
/*!
* \brief Wait until all the pending writes with respect NDArray are finished.
* Always call this before read data out synchronizely.
@@ -388,6 +431,7 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle *out);
+
/*!
* \brief Index the NDArray along axis 0.
* \param handle the handle to the NDArray
@@ -398,6 +442,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
NDArrayHandle *out);
+
+/*!
+ * \brief get the storage type of the array
+ * \param handle the handle to the ndarray
+ * \param out_storage_type pointer holder to get the storage type
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
+ int *out_storage_type);
+
/*!
* \brief Reshape the NDArray.
* \param handle the handle to the narray
@@ -436,6 +487,34 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
int *out_dtype);
+
+/*!
+ * \brief get the type of the ith aux data in NDArray
+ * \param handle the handle to the ndarray
+ * \param i the index of the aux data
+ * \param out_type pointer holder to get type of aux data
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
+ mx_uint i,
+ int *out_type);
+
+/*!
+ * \brief Get a deep copy of the ith aux data blob
+ * in the form of an NDArray of default storage type.
+ * This function blocks. Do not use it in performance critical code.
+ */
+MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
+ mx_uint i,
+ NDArrayHandle *out);
+
+/*!
+ * \brief Get a deep copy of the data blob
+ * in the form of an NDArray of default storage type.
+ * This function blocks. Do not use it in performance critical code.
+ */
+MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle,
+ NDArrayHandle *out);
/*!
* \brief get the context of the NDArray
* \param handle the handle to the narray
@@ -581,6 +660,28 @@ MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator,
int num_params,
const char **param_keys,
const char **param_vals);
+/*!
+ * \brief invoke an nnvm op and imperative function
+ * \param creator the op
+ * \param num_inputs number of input NDArrays
+ * \param inputs input NDArrays
+ * \param num_outputs number of output NDArrays
+ * \param outputs output NDArrays
+ * \param num_params number of keyword parameters
+ * \param param_keys keys for keyword parameters
+ * \param param_vals values for keyword parameters
+ * \param out_stypes output ndarrays' stypes
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator,
+ int num_inputs,
+ NDArrayHandle *inputs,
+ int *num_outputs,
+ NDArrayHandle **outputs,
+ int num_params,
+ const char **param_keys,
+ const char **param_vals,
+ const int **out_stypes);
/*!
* \brief set whether to record operator for autograd
* \param is_recording 1 when recording, 0 when not recording.
@@ -666,6 +767,30 @@ MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
* \brief free cached operator
*/
MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle);
+/*!
+ * \brief invoke cached operator
+ */
+MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle,
+ int num_inputs,
+ NDArrayHandle *inputs,
+ int *num_outputs,
+ NDArrayHandle **outputs);
+/*!
+ * \brief invoke a cached op
+ * \param handle the handle to the cached op
+ * \param num_inputs number of input NDArrays
+ * \param inputs input NDArrays
+ * \param num_outputs number of output NDArrays
+ * \param outputs output NDArrays
+ * \param out_stypes output ndarrays' stypes
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
+ int num_inputs,
+ NDArrayHandle *inputs,
+ int *num_outputs,
+ NDArrayHandle **outputs,
+ const int** out_stypes);
/*!
* \brief invoke cached operator
*/
@@ -1017,20 +1142,20 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
- mx_uint num_args,
- const char** keys,
- const mx_uint *arg_ind_ptr,
- const mx_uint *arg_shape_data,
- mx_uint *in_shape_size,
- const mx_uint **in_shape_ndim,
- const mx_uint ***in_shape_data,
- mx_uint *out_shape_size,
- const mx_uint **out_shape_ndim,
- const mx_uint ***out_shape_data,
- mx_uint *aux_shape_size,
- const mx_uint **aux_shape_ndim,
- const mx_uint ***aux_shape_data,
- int *complete);
+ mx_uint num_args,
+ const char** keys,
+ const mx_uint *arg_ind_ptr,
+ const mx_uint *arg_shape_data,
+ mx_uint *in_shape_size,
+ const mx_uint **in_shape_ndim,
+ const mx_uint ***in_shape_data,
+ mx_uint *out_shape_size,
+ const mx_uint **out_shape_ndim,
+ const mx_uint ***out_shape_data,
+ mx_uint *aux_shape_size,
+ const mx_uint **aux_shape_ndim,
+ const mx_uint ***aux_shape_data,
+ int *complete);
/*!
* \brief infer type of unknown input types given the known one.
@@ -1061,6 +1186,10 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete);
+
+
+
+
//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
@@ -1222,36 +1351,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
ExecutorHandle *out);
MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
- int dev_type,
- int dev_id,
- const mx_uint num_g2c_keys,
- const char** g2c_keys,
- const int* g2c_dev_types,
- const int* g2c_dev_ids,
- const mx_uint provided_grad_req_list_len,
- const char** provided_grad_req_names,
- const char** provided_grad_req_types,
- const mx_uint num_provided_arg_shapes,
- const char** provided_arg_shape_names,
- const mx_uint* provided_arg_shape_data,
- const mx_uint* provided_arg_shape_idx,
- const mx_uint num_provided_arg_dtypes,
- const char** provided_arg_dtype_names,
- const int* provided_arg_dtypes,
- const mx_uint num_shared_arg_names,
- const char** shared_arg_name_list,
- int* shared_buffer_len,
- const char** shared_buffer_name_list,
- NDArrayHandle* shared_buffer_handle_list,
- const char*** updated_shared_buffer_name_list,
- NDArrayHandle** updated_shared_buffer_handle_list,
- mx_uint* num_in_args,
- NDArrayHandle** in_args,
- NDArrayHandle** arg_grads,
- mx_uint* num_aux_states,
- NDArrayHandle** aux_states,
- ExecutorHandle shared_exec_handle,
- ExecutorHandle* out);
+ int dev_type,
+ int dev_id,
+ const mx_uint num_g2c_keys,
+ const char** g2c_keys,
+ const int* g2c_dev_types,
+ const int* g2c_dev_ids,
+ const mx_uint provided_grad_req_list_len,
+ const char** provided_grad_req_names,
+ const char** provided_grad_req_types,
+ const mx_uint num_provided_arg_shapes,
+ const char** provided_arg_shape_names,
+ const mx_uint* provided_arg_shape_data,
+ const mx_uint* provided_arg_shape_idx,
+ const mx_uint num_provided_arg_dtypes,
+ const char** provided_arg_dtype_names,
+ const int* provided_arg_dtypes,
+ const mx_uint num_provided_arg_stypes,
+ const char** provided_arg_stype_names,
+ const int* provided_arg_stypes,
+ const mx_uint num_shared_arg_names,
+ const char** shared_arg_name_list,
+ int* shared_buffer_len,
+ const char** shared_buffer_name_list,
+ NDArrayHandle* shared_buffer_handle_list,
+ const char*** updated_shared_buffer_name_list,
+ NDArrayHandle** updated_shared_buffer_handle_list,
+ mx_uint* num_in_args,
+ NDArrayHandle** in_args,
+ NDArrayHandle** arg_grads,
+ mx_uint* num_aux_states,
+ NDArrayHandle** aux_states,
+ ExecutorHandle shared_exec_handle,
+ ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
@@ -1468,6 +1600,26 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
const char** keys,
NDArrayHandle* vals,
int priority);
+
+/*!
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
+ * The NDArray pulled back will be in row_sparse storage with only the specified
+ * row_ids present (other rows are zeros).
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \param row_ids the list of row_id NDArrays
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle,
+ mx_uint num,
+ const char** keys,
+ NDArrayHandle* vals,
+ const NDArrayHandle* row_ids,
+ int priority);
+
/*!
* \brief user-defined updater for the kvstore
* It's this updater's responsibility to delete \a recv and \a local
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index a74d3b07b5be..85d34778dd8c 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -133,6 +133,7 @@ class Executor {
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
+ const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
diff --git a/include/mxnet/graph_attr_types.h b/include/mxnet/graph_attr_types.h
new file mode 100644
index 000000000000..3aba0119d8ca
--- /dev/null
+++ b/include/mxnet/graph_attr_types.h
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_attr_types.h
+ * \brief Data structures that can appear in graph attributes.
+ */
+#ifndef MXNET_GRAPH_ATTR_TYPES_H_
+#define MXNET_GRAPH_ATTR_TYPES_H_
+
+#include <vector>
+
+namespace mxnet {
+
+/*!
+ * \brief The result holder of storage type of each NodeEntry in the graph.
+ * \note Stored under graph.attrs["storage_type"], provided by Pass "InferStorageType"
+ *
+ * \code
+ * Graph g = ApplyPass(src_graph, "InferStorageType");
+ * const StorageTypeVector& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ * // get storage type by entry id
+ * int entry_type = stypes[g.indexed_graph().entry_id(my_entry)];
+ * \endcode
+ *
+ * \sa FInferStorageType
+ */
+using StorageTypeVector = std::vector<int>;
+
+} // namespace mxnet
+
+#endif // MXNET_GRAPH_ATTR_TYPES_H_
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index d2924ecea1b5..9ea63b4cec79 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -25,6 +25,7 @@
#define MXNET_KVSTORE_H_
#include <dmlc/io.h>
#include <vector>
+#include <utility>
#include <unordered_map>
#include <string>
#include <functional>
@@ -173,6 +174,29 @@ class KVStore {
const std::vector<NDArray*>& values,
int priority = 0) = 0;
+ /*!
+ * \brief pull a list of key-value pairs from the store.
+ * The NDArray pulled back will be in row_sparse storage with only the
+ * specified row_ids present (other rows are zeros).
+ * \param keys the list of keys
+ * \param val_rowids the list of (buffer, row_id) pairs
+ * \param priority the priority of the action.
+ */
+ virtual void PullRowSparse(const std::vector<int>& keys,
+ const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+ const int priority = 0) = 0;
+
+ /*!
+ * \brief pull a list of key-value pairs from the store, where each key is a string.
+ * The NDArray pulled back will be in row_sparse storage with only the
+ * specified row_ids present (other rows are zeros).
+ * \param str_keys the list of keys in string format
+ * \param val_rowids the list of (buffer, row_id) pairs
+ * \param priority the priority of the action.
+ */
+ virtual void PullRowSparse(const std::vector<std::string>& str_keys,
+ const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+ const int priority = 0) = 0;
/**
* \brief the prototype of user-defined updater
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index d7dff4098b27..754bc28e7bed 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -47,7 +47,6 @@
namespace mxnet {
-// forward declaration
namespace autograd {
class AGNode;
@@ -71,6 +70,23 @@ class AGNodeEntry {
class AutogradRuntime;
} // namespace autograd
+// enum for storage types
+namespace csr {
+enum CSRAuxType {kIndPtr, kIdx};
+}
+
+namespace rowsparse {
+enum RowSparseAuxType {kIdx};
+}
+
+enum NDArrayStorageType {
+ kUndefinedStorage = -1, // undefined storage
+ kDefaultStorage, // dense
+ kRowSparseStorage, // row sparse
+ kCSRStorage, // csr
+};
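These ids line up with the string names used by the Python frontend ('default', 'row_sparse', 'csr'; see _STORAGE_TYPE_ID_TO_STR later in this diff). A short sketch with the Python APIs this PR adds (zeros with an stype argument, the stype property, tostype):

    import mxnet as mx

    dns = mx.nd.zeros((3, 4))                      # kDefaultStorage
    rsp = mx.nd.zeros((3, 4), stype='row_sparse')  # kRowSparseStorage
    csr = mx.nd.zeros((3, 4), stype='csr')         # kCSRStorage
    print(dns.stype, rsp.stype, csr.stype)         # default row_sparse csr

    # converting between storage types preserves values
    x = mx.nd.array([[0, 1], [2, 0]])
    assert (x.tostype('csr').tostype('default').asnumpy() == x.asnumpy()).all()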
+
+
/*!
* \brief ndarray interface
*/
@@ -91,10 +107,55 @@ class NDArray {
*/
NDArray(const TShape &shape, Context ctx,
bool delay_alloc = false, int dtype = mshadow::default_type_flag)
- : ptr_(std::make_shared<Chunk>(shape.Size(), ctx, delay_alloc, dtype)),
+ : ptr_(std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
+#endif
+ }
+ /*! \brief constructor for NDArray with storage type
+ */
+ NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx,
+ bool delay_alloc = true, int dtype = mshadow::default_type_flag,
+ std::vector<int> aux_types = {}, std::vector<TShape> aux_shapes = {},
+ TShape storage_shape = TShape(mshadow::Shape1(0)))
+ : shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) {
+ // Assign default aux types if not given
+ if (aux_types.size() == 0) {
+ if (stype == kRowSparseStorage) {
+ aux_types = {mshadow::kInt64};
+ } else if (stype == kCSRStorage) {
+ aux_types = {mshadow::kInt64, mshadow::kInt64};
+ } else {
+ LOG(FATAL) << "Unknown storage type " << stype;
+ }
+ }
+ // Assign default shapes if not given
+ // unknown shapes are initialized as {0} such that Size() would return 0
+ if (aux_shapes.size() == 0) {
+ if (stype == kRowSparseStorage) {
+ aux_shapes = {TShape(mshadow::Shape1(0))};
+ } else if (stype == kCSRStorage) {
+ // aux shapes for indptr and indices
+ aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))};
+ } else {
+ LOG(FATAL) << "Unknown storage type " << stype;
+ }
+ }
+ if (storage_shape.Size() == 0) {
+ if (stype == kRowSparseStorage) {
+ storage_shape = shape;
+ storage_shape[0] = aux_shapes[rowsparse::kIdx][0];
+ } else if (stype == kCSRStorage) {
+ storage_shape = aux_shapes[csr::kIdx];
+ } else {
+ LOG(FATAL) << "Unknown storage type " << stype;
+ }
+ }
+ ptr_ = std::make_shared<Chunk>(stype, storage_shape, ctx, delay_alloc,
+ dtype, aux_types, aux_shapes);
+#if MKL_EXPERIMENTAL == 1
+ Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}
/*!
@@ -111,17 +172,86 @@ class NDArray {
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}
+
/*!
- * \return the shape of current NDArray
+ * \brief constructing a static NDArray of non-default storage that shares data with TBlob
+ * Use with caution: allocate ONLY ONE NDArray for each TBlob,
+ * make sure the memory region is available throughout the life of the NDArray
+ * \param stype the storage type of NDArray
+ * \param shape the shape of NDArray
+ * \param data the memory content of static data
+ * \param aux_data the memory content of static aux data
+ * \param dev_id the device id this tensor sits at
+ */
+ NDArray(const NDArrayStorageType stype, const TShape &shape,
+ const TBlob &data, const std::vector<TBlob> &aux_data, int dev_id)
+ : ptr_(std::make_shared(stype, data, aux_data, dev_id)), shape_(shape),
+ dtype_(data.type_flag_), entry_({nullptr, 0, 0}) {
+#if MKL_EXPERIMENTAL == 1
+ Mkl_mem_ = std::make_shared<MKLMemHolder>();
+#endif
+ }
+
+
+ /*!
+ * \return the shape of current NDArray.
*/
inline const TShape& shape() const {
return shape_;
}
+ /*!
+ * \return the shape of underlying chunk which stores the NDArray data/value.
+ * It is only intended for non-default storage. For row-sparse storage, it is the shape of
+ * the tensor which stores the non-zero values.
+ */
+ inline const TShape &storage_shape() const {
+ CHECK(ptr_ != nullptr);
+ CHECK_NE(storage_type(), kDefaultStorage)
+ << "storage_shape() is not intended for kDefaultStorage.";
+ return ptr_->storage_shape;
+ }
+
+ /*!
+ * \brief get the shape of aux_data(index)
+ * \param index the index of the aux data
+ * \return the shape of aux data at given index
+ */
+ inline const TShape& aux_shape(size_t index) const {
+ CHECK_NE(storage_type(), kDefaultStorage)
+ << "aux_shape() is not intended for kDefaultStorage.";
+ return ptr_->aux_shapes[index];
+ }
+
+ /*! \return the shapes of all aux data */
+ const std::vector<TShape>& aux_shapes() const {
+ CHECK_NE(storage_type(), kDefaultStorage)
+ << "aux_shapes() is not intended for kDefaultStorage.";
+ return ptr_->aux_shapes;
+ }
+
+ /*! \return the dtypes of all aux data */
+ const std::vector<int>& aux_types() const {
+ CHECK_NE(storage_type(), kDefaultStorage)
+ << "aux_types() is not intended for kDefaultStorage.";
+ return ptr_->aux_types;
+ }
+
+ /*!
+ * \brief For a sparse operation on a csr matrix for example,
+ * the size of the column index array
+ * is an estimated value in the beginning for allocating enough capacity
+ * for the final result. After the operation is done, the exact size of
+ * the shape is known and needs to be reset using this function.
+ */
+ inline void set_aux_shape(size_t index, const TShape& shape) const {
+ ptr_->set_aux_shape(index, shape);
+ }
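To see why the final aux shape is only known after the fact, consider casting dense to csr: the number of non-zeros, and hence the length of the indices array, is a result of the kernel. A sketch using the public API from this PR:

    import mxnet as mx

    x = mx.nd.array([[0, 7, 0],
                     [0, 0, 0],
                     [5, 0, 9]])
    y = mx.nd.cast_storage(x, 'csr')
    # 3 non-zeros: capacity is allocated up front, then set_aux_shape is
    # called internally so aux_shape(csr::kIdx) reports the exact size (3,)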
+
/*!
* \return the data TBlob
*/
inline const TBlob& data() const {
- CheckAndAlloc();
+ if (storage_type() == kDefaultStorage) CheckAndAlloc();
SetTBlob();
return tblob_;
}
@@ -129,6 +259,26 @@ class NDArray {
* \return the gradient ndarray.
*/
NDArray grad() const;
+
+ /*!
+ * \return the aux TBlob
+ */
+ inline TBlob aux_data(size_t i) const {
+ auto stype = storage_type();
+ TBlob res;
+ auto shape = aux_shape(i);
+ auto type = aux_type(i);
+ MSHADOW_TYPE_SWITCH(type, DType, {
+ auto dptr = static_cast<DType*>(ptr_->aux_handles[i].dptr);
+ CHECK(stype == kRowSparseStorage || stype == kCSRStorage)
+ << "Unexpected storage type: " << stype;
+ res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type);
+ });
+#if MKL_EXPERIMENTAL == 1
+ res.Mkl_mem_ = Mkl_mem_;
+#endif
+ return res;
+ }
/*!
* \return the context of NDArray, this function is only valid when the NDArray is not empty
*/
@@ -141,6 +291,15 @@ class NDArray {
inline int dtype() const {
return dtype_;
}
+ inline int aux_type(size_t i) const {
+ CHECK(!is_none());
+ return ptr_->aux_types[i];
+ }
+
+ inline NDArrayStorageType storage_type() const {
+ if (is_none()) return kUndefinedStorage;
+ return ptr_->storage_type;
+ }
/*! \return whether this ndarray is not initialized */
inline bool is_none() const {
return ptr_.get() == nullptr;
@@ -149,6 +308,27 @@ class NDArray {
bool fresh_out_grad() const;
/*! \return updated grad state in entry_ */
void set_fresh_out_grad(bool state) const;
+ // returns true if a sparse ndarray's aux_data and storage are initialized
+ inline bool storage_initialized() const {
+ if (is_none()) return false;
+ auto stype = storage_type();
+ CHECK_NE(stype, kDefaultStorage)
+ << "storage_initialized() is not intended for kDefaultStorage.";
+ if (stype == kRowSparseStorage) {
+ CHECK_EQ(aux_shape(rowsparse::kIdx)[0], storage_shape()[0])
+ << "inconsistent storage shape " << storage_shape()
+ << " vs. aux shape " << aux_shape(rowsparse::kIdx);
+ return aux_shape(0).Size() != 0;
+ } else if (stype == kCSRStorage) {
+ CHECK_EQ(aux_shape(csr::kIdx)[0], storage_shape()[0])
+ << "inconsistent storage shape " << storage_shape()
+ << " vs. aux shape " << aux_shape(csr::kIdx);
+ return aux_shape(0).Size() != 0;
+ } else {
+ LOG(FATAL) << "Unknown storage type";
+ }
+ return true;
+ }
/*!
* \brief Block until all the pending write operations with respect
* to current NDArray are finished, and read can be performed.
@@ -179,6 +359,12 @@ class NDArray {
* \param strm the output stream
*/
void Save(dmlc::Stream *strm) const;
+ /*!
+ * \brief load ndarrays before supporting sparse ndarrays
+ * \param strm the output stream
+ * \param magic the magic number used for version control
+ */
+ bool LegacyLoad(dmlc::Stream *strm, const uint32_t magic);
/*!
* \brief load the content from binary stream
* \param strm the input stream
@@ -269,6 +455,12 @@ class NDArray {
* \param size the size of the source array, in sizeof(DType) not raw bytes.
*/
void SyncCopyFromCPU(const void *data, size_t size) const;
+
+ /*!
+ * \brief Copy from src.data()/aux_data(i) to this->data()/aux_data(j)
+ */
+ void SyncCopyFromNDArray(const NDArray &src, int i = -1, int j = -1);
+
/*!
* \brief Do a synchronized copy to a contiguous CPU memory region.
*
@@ -282,17 +474,31 @@ class NDArray {
void SyncCopyToCPU(void *data, size_t size) const;
/*!
* \brief Slice a NDArray
- * \param begin begin index in first dim
- * \param end end index in first dim
+ * \param begin begin index in first dim (inclusive)
+ * \param end end index in first dim (exclusive)
* \return sliced NDArray
*/
NDArray Slice(index_t begin, index_t end) const;
+
/*!
* \brief Index a NDArray
* \param idx the index
* \return idx-th sub array NDArray
*/
NDArray At(index_t idx) const;
+
+ /*!
+ * \brief Generate a deep copy of aux_data(i) returned as
+ * a default storage type NDArray
+ */
+ NDArray aux_ndarray(size_t i) const;
+
+ /*!
+ * \brief Generate a deep copy of data() returned as a
+ * default storage type NDArray
+ */
+ NDArray data_ndarray() const;
+
/*!
* \brief Create a NDArray that shares memory with current one
* The new array must have smaller memory size than the current array.
@@ -301,6 +507,8 @@ class NDArray {
* \return NDArray in new shape and type.
*/
inline NDArray AsArray(const TShape &shape, int dtype) const {
+ CHECK_EQ(storage_type(), kDefaultStorage)
+ << "AsArray is intended only for kDefaultStorage.";
CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_),
shape.Size() * mshadow::mshadow_sizeof(dtype))
<< "NDArray.AsArray: target memory size is bigger";
@@ -342,8 +550,45 @@ class NDArray {
* This is an internal function used by system that normal user should not use
*/
inline void CheckAndAlloc() const {
+ CHECK_EQ(storage_type(), kDefaultStorage);
ptr_->CheckAndAlloc();
}
+
+ /*!
+ * \brief Allocate the space if the allocation has been delayed
+ * or the requested size is bigger than the available one.
+ * This function can only be called by ndarray of default
+ * storage type and effectively changes the ndarray's shape_.
+ * Note: This function is named as this to avoid overload conflict
+ * with CheckAndAlloc(const std::vector<TShape> &aux_shapes), since
+ * TShape tmp = some_shape is equivalent to TShape tmp = {some_shape}.
+ */
+ void ReshapeAndAlloc(const TShape& shape) {
+ CHECK_EQ(storage_type(), kDefaultStorage);
+ CHECK(!is_none());
+ shape_ = shape;
+ ptr_->CheckAndAlloc(shape.Size() * mshadow::mshadow_sizeof(dtype_));
+ }
+
+ /*!
+ * \brief Alloc memory for non-default storage
+ * aux_shape is only known at run time
+ */
+ inline void CheckAndAlloc(const std::vector<TShape> &aux_shapes) const {
+ CHECK_NE(storage_type(), kDefaultStorage)
+ << "CheckAndAlloc(aux_shapes) is not intended for kDefaultStorage";
+ ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_);
+ }
+ inline void CheckAndAllocData(const TShape &storage_shape) const {
+ CHECK_NE(storage_type(), kDefaultStorage)
+ << "CheckAndAllocData is not intended for kDefaultStorage";
+ ptr_->CheckAndAllocData(storage_shape, dtype_);
+ }
+ inline void CheckAndAllocAuxData(size_t i, const TShape &aux_shape) const {
+ CHECK_NE(storage_type(), kDefaultStorage)
+ << "CheckAndAllocAuxData is not intended for kDefaultStorage";
+ ptr_->CheckAndAllocAuxData(i, aux_shape);
+ }
/*!
* \brief Save list of ndarrays into the Stream.
* \param fo The stream of output.
@@ -366,44 +611,138 @@ class NDArray {
private:
friend class autograd::AutogradRuntime;
/*! \brief the real data chunk that backs NDArray */
+ // shandle is used to store the actual values in the NDArray
+ // aux_handles store the aux data(such as indices) if it's needed by non-default storage.
struct Chunk {
- /*! \brief storage handlefrom storage engine */
+ /*! \brief storage handle from storage engine.
+ for non-default storage, shandle stores the data(value) array.
+ */
Storage::Handle shandle;
+ /*! \brief storage handles for aux data (e.g index)
+ for row_sparse, aux_handles[0] = indices
+ for csr, aux_handles[0] = indptr, aux_handles[1] = indices
+ */
+ std::vector<Storage::Handle> aux_handles;
/*! \brief variable from engine */
Engine::VarHandle var;
/*!
* \brief if this is true, this means the data do not come
* from Storage, and do not need to be freed
*/
bool static_data;
- /*! \brief whether allocation is delayed */
+ /*! \brief whether data allocation is delayed. This doesn't indicate whether aux data
+ allocation is delayed. */
bool delay_alloc;
+ // the type of the storage. The storage_type is never kUndefinedStorage once the chunk
+ // is constructed.
+ NDArrayStorageType storage_type = kDefaultStorage;
+ /*! \brief types of aux data */
+ std::vector<int> aux_types;
+ // context of data
+ Context ctx;
+ // The shape of the chunk data.
+ // This might not be the same shape as the NDArray, since the storage may be sparse.
+ // The default value for storage_shape is {0} when an empty non-default NDArray is created.
+ TShape storage_shape;
+ // The shape of aux data. The default value for the shape depends on the type of storage.
+ // If aux_shapes[i].Size() is zero, aux data i is empty.
+ std::vector<TShape> aux_shapes;
+
/*! \brief default constructor */
- Chunk() : static_data(true), delay_alloc(false) {
- var = Engine::Get()->NewVariable();
+ Chunk() : static_data(true), delay_alloc(false) {}
+
+ /*! \brief construct a new chunk */
+ Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype)
+ : static_data(false), delay_alloc(true), ctx(ctx_) {
+ auto size = shape.Size();
+ storage_shape = shape;
+ var = Engine::Get()->NewVariable();
+ shandle.size = size * mshadow::mshadow_sizeof(dtype);
+ shandle.ctx = ctx_;
+ if (!delay_alloc_) this->CheckAndAlloc();
}
- /*! \brief construct from static data */
+
Chunk(const TBlob &data, int dev_id)
- : static_data(true),
- delay_alloc(false) {
+ : static_data(true), delay_alloc(false) {
+ CHECK(storage_type == kDefaultStorage);
var = Engine::Get()->NewVariable();
if (data.dev_mask() == cpu::kDevMask) {
- shandle.ctx = Context::CPU();
+ ctx = Context::CPU();
} else {
CHECK_EQ(data.dev_mask(), gpu::kDevMask);
- shandle.ctx = Context::GPU(dev_id);
+ ctx = Context::GPU(dev_id);
}
+ // init shandle
+ shandle.ctx = ctx;
shandle.dptr = data.dptr_;
shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_);
+ storage_shape = data.shape_;
}
- /*! \brief construct a new chunk */
- Chunk(uint64_t size, Context ctx, bool delay_alloc_, int dtype)
- : static_data(false), delay_alloc(true) {
+ // Constructor for a non-default storage chunk
+ Chunk(NDArrayStorageType storage_type_, const TShape &storage_shape_, Context ctx_,
+ bool delay_alloc_, int dtype, const std::vector<int> &aux_types_,
+ const std::vector<TShape> &aux_shapes_)
+ : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_),
+ aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_),
+ aux_shapes(aux_shapes_) {
+ shandle.ctx = ctx;
var = Engine::Get()->NewVariable();
- shandle.size = size * mshadow::mshadow_sizeof(dtype);
+ // aux_handles always reflect the correct number of aux data
+ for (size_t i = 0; i < aux_shapes.size(); i++) {
+ CheckAndAllocAuxData(i, aux_shapes[i]);
+ // this line is needed in case when aux_shapes[i].Size() = 0
+ // aux_handles[i] will not be updated and take only default value.
+ aux_handles[i].ctx = ctx;
+ }
+ if (!delay_alloc) {
+ CheckAndAllocData(storage_shape, dtype);
+ }
+ }
+
+ Chunk(const NDArrayStorageType storage_type_, const TBlob &data,
+ const std::vector<TBlob> &aux_data, int dev_id)
+ : static_data(true), delay_alloc(false), storage_type(storage_type_) {
+ using namespace mshadow;
+ CHECK_NE(storage_type, kDefaultStorage);
+ // init var
+ var = Engine::Get()->NewVariable();
+ // init ctx
+ if (data.dev_mask() == cpu::kDevMask) {
+ ctx = Context::CPU();
+ } else {
+ CHECK_EQ(data.dev_mask(), gpu::kDevMask);
+ ctx = Context::GPU(dev_id);
+ }
+ // init shandle
shandle.ctx = ctx;
- if (!delay_alloc_) this->CheckAndAlloc();
+ shandle.dptr = data.dptr_;
+ shandle.size = data.shape_.Size() * mshadow_sizeof(data.type_flag_);
+ storage_shape = data.shape_;
+ // init aux handles
+ for (const auto &aux : aux_data) {
+ Storage::Handle aux_handle;
+ aux_handle.ctx = ctx;
+ aux_handle.dptr = aux.dptr_;
+ aux_handle.size = aux.shape_.Size() * mshadow_sizeof(aux.type_flag_);
+ aux_handles.push_back(aux_handle);
+ aux_types.emplace_back(aux.type_flag_);
+ aux_shapes.emplace_back(aux.shape_);
+ }
+ }
+
+ /*! \brief set the shape for ith aux data, and update storage shape if necessary */
+ inline void set_aux_shape(const size_t i, const TShape& shape) {
+ aux_shapes[i] = shape;
+ if (storage_shape.ndim() > 0) {
+ if (storage_type == kRowSparseStorage && i == rowsparse::kIdx) {
+ storage_shape[0] = shape[0];
+ } else if (storage_type == kCSRStorage && i == csr::kIdx) {
+ storage_shape[0] = shape[0];
+ }
+ }
}
+
/*! \brief check if delay alloc is on, do alloc if not yet done */
inline void CheckAndAlloc(void) {
if (delay_alloc) {
@@ -411,22 +750,113 @@ class NDArray {
delay_alloc = false;
}
}
- /*! \brief destructor */
- ~Chunk() {
- if (static_data || delay_alloc) {
- Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var);
+
+ /*! \brief Check and alloc memory for a dense ndarray */
+ // size is the number of bytes
+ void CheckAndAlloc(uint64_t dbytes) {
+ CHECK_EQ(kDefaultStorage, storage_type)
+ << "CheckAndAlloc(dbytes) is not intended for kDefaultStorage";
+ if (delay_alloc) {
+ shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+ delay_alloc = false;
+ } else if (shandle.size < dbytes) {
+ // free storage if necessary and alloc again
+ if (shandle.size > 0) Storage::Get()->Free(shandle);
+ // init storage
+ shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+ }
+ }
+
+ inline void CheckAndAlloc(const TShape &shape, const std::vector<TShape> &aux_shapes,
+ int dtype) {
+ // calculate size, perform allocation
+ if (kRowSparseStorage == storage_type) {
+ // For row sparse, aux_shape indicates the number of rows to allocate
+ auto aux_shape = aux_shapes[rowsparse::kIdx];
+ CheckAndAllocAuxData(rowsparse::kIdx, aux_shape);
+ TShape storage_shape(shape);
+ storage_shape[0] = aux_shape[0];
+ CheckAndAllocData(storage_shape, dtype);
+ } else if (kCSRStorage == storage_type) {
+ CheckAndAllocAuxData(csr::kIndPtr, aux_shapes[csr::kIndPtr]);
+ CheckAndAllocAuxData(csr::kIdx, aux_shapes[csr::kIdx]);
+ CheckAndAllocData(aux_shapes[csr::kIdx], dtype);
} else {
- Storage::Handle h = this->shandle;
- Engine::Get()->DeleteVariable([h](RunContext s) {
- Storage::Get()->Free(h);
- }, shandle.ctx, var);
+ LOG(FATAL) << "Storage type " << storage_type << " not implemented for CheckAndAlloc";
+ }
+ }
+ // create storage handle for data based on shape and dtype, assuming ctx is set
+ // storage shape is also updated
+ // if data is already allocated, try reuse the storage. Otherwise, free the current one
+ // and allocate new storage
+ inline void CheckAndAllocData(const TShape &shape, int dtype) {
+ CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data";
+ auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
+ if (shandle.size < dbytes) {
+ // free storage if necessary and alloc again
+ if (shandle.size > 0) Storage::Get()->Free(shandle);
+ // init storage
+ shandle = Storage::Get()->Alloc(dbytes, ctx);
}
+ // init shape
+ storage_shape = shape;
+ // delay_alloc is only set when data storage handle is present
+ delay_alloc = false;
+ }
+ // create storage handle for aux data based on shape
+ // this function assumes ctx, aux shapes and aux types are set
+ // aux shape is also updated
+ // if aux data is already allocated, try reuse the storage. Otherwise, free the current one
+ // and allocate new storage
+ inline void CheckAndAllocAuxData(size_t i, const TShape &shape) {
+ CHECK_EQ(shape.ndim(), 1) << "shape must be 1D in CheckAndAllocAuxData";
+ CHECK_NE(storage_type, kUndefinedStorage)
+ << "storage type cannot be kUndefinedStorage in CheckAndAllocAuxData";
+ CHECK_NE(storage_type, kDefaultStorage)
+ << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData";
+ if (aux_handles.size() <= i) {
+ aux_handles.resize(i + 1);
+ }
+ size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]);
+ if (aux_handles[i].size < aux_bytes) {
+ // free storage if necessary and alloc again
+ if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]);
+ // init aux storage
+ aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx);
+ }
+ // init shape
+ set_aux_shape(i, shape);
+ }
+ /*! \brief destructor */
+ ~Chunk() {
+ bool skip_free = static_data || delay_alloc;
+ Storage::Handle h = this->shandle;
+ std::vector<Storage::Handle> aux_h = this->aux_handles;
+ Engine::Get()->DeleteVariable([h, aux_h, skip_free](RunContext s) {
+ if (skip_free == false) {
+ Storage::Get()->Free(h);
+ for (size_t i = 0; i < aux_h.size(); i++) {
+ if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]);
+ }
+ }
+ }, shandle.ctx, var);
}
- };
+ }; // struct Chunk
void SetTBlob() const {
- tblob_.dptr_ = static_cast<char*>(ptr_->shandle.dptr) + byte_offset_;
- tblob_.shape_ = shape_;
+ CHECK(ptr_ != nullptr);
+ TShape shape = shape_;
+ char *dptr = static_cast<char*>(ptr_->shandle.dptr);
+ auto stype = storage_type();
+ if (stype == kDefaultStorage) {
+ dptr += byte_offset_;
+ } else if (stype == kCSRStorage || stype == kRowSparseStorage) {
+ shape = storage_shape();
+ } else {
+ LOG(FATAL) << "unknown storage type " << stype;
+ }
+ tblob_.dptr_ = dptr;
+ tblob_.shape_ = shape;
tblob_.type_flag_ = dtype_;
tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id);
#if MKL_EXPERIMENTAL == 1
@@ -438,7 +868,7 @@ class NDArray {
std::shared_ptr<MKLMemHolder> Mkl_mem_;
#endif
/*! \brief internal data of NDArray */
- std::shared_ptr<Chunk> ptr_;
+ std::shared_ptr<Chunk> ptr_{nullptr};
/*! \brief shape of current NDArray */
TShape shape_;
/*! \brief byte offset in chunk */
@@ -455,7 +885,12 @@ class NDArray {
* this situation.
*/
mutable TBlob tblob_;
-};
+}; // class NDArray
+
+/*!
+ * \return the number of aux data used for given storage type
+ */
+size_t num_aux_data(NDArrayStorageType stype);
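num_aux_data reflects the layout documented on Chunk below: row_sparse carries one aux array (row indices) and csr carries two (indptr and indices). A NumPy sketch of what those arrays hold for a small matrix (standard CSR and row-sparse layout; variable names are illustrative):

    import numpy as np

    dense = np.array([[0, 7, 0],
                      [0, 0, 0],
                      [5, 0, 9]])
    # csr: non-zero values in row-major order, their column indices,
    # and per-row offsets into those two arrays
    values  = np.array([7, 5, 9])     # stored in shandle, storage_shape (3,)
    indices = np.array([1, 0, 2])     # aux_handles[csr::kIdx]
    indptr  = np.array([0, 1, 1, 3])  # aux_handles[csr::kIndPtr]
    # row_sparse: the retained rows plus their row ids
    row_ids  = np.array([0, 2])       # aux_handles[rowsparse::kIdx]
    row_data = dense[row_ids]         # stored in shandle, storage_shape (2, 3)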
/*!
* \brief issue a copy operation from one NDArray to another
@@ -470,7 +905,6 @@ class NDArray {
*/
void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);
-
/*!
* \brief Perform elementwise sum over each data from source, store result into out.
* \param source the ndarray we want to sum
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 1bcae0d29348..f559a921c522 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -25,7 +25,6 @@
#ifndef MXNET_OP_ATTR_TYPES_H_
#define MXNET_OP_ATTR_TYPES_H_
-
#include <mshadow/tensor.h>
#include <nnvm/op_attr_types.h>
@@ -226,6 +225,23 @@ using FCompute = std::function<void (const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs)>;
+/*!
+ * \brief Register an NDArray compute function for simple stateless forward only operator
+ *
+ * \note Register under "FComputeEx<cpu>" and "FComputeEx<gpu>"
+ * Dispatched only when operators process non-default storage inputs or outputs
+ */
+using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs)>;
+
+using FInferStorageType = std::function<bool (const nnvm::NodeAttrs& attrs,
+                                              const Context& ctx,
+                                              std::vector<int>* in_attrs,
+                                              std::vector<int>* out_attrs)>;
+
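A hedged Python sketch of the contract an FInferStorageType pass implements: given the storage types known so far (ints, -1 for undefined), fill in the unknowns consistently and report success. The real registrations are C++ functions attached per operator; this models the common case where all inputs and outputs share one storage type:

    def elemwise_same_stype(in_attrs, out_attrs, default_stype=0):
        # pick the first known storage type, falling back to dense (0)
        known = [s for s in in_attrs + out_attrs if s != -1]
        stype = known[0] if known else default_stype
        if any(s not in (-1, stype) for s in in_attrs + out_attrs):
            return False  # conflicting known types: inference fails
        in_attrs[:] = [stype] * len(in_attrs)
        out_attrs[:] = [stype] * len(out_attrs)
        return True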
} // namespace mxnet
#endif // MXNET_OP_ATTR_TYPES_H_
diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h
index bfb42de8771a..7e3af8eeca81 100644
--- a/include/mxnet/storage.h
+++ b/include/mxnet/storage.h
@@ -41,11 +41,11 @@ class Storage {
/*!
* \brief Pointer to the data.
*/
- void* dptr;
+ void* dptr{nullptr};
/*!
* \brief Size of the storage.
*/
- size_t size;
+ size_t size{0};
/*!
* \brief Context information about device and ID.
*/
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i
index fd1a471bcf16..b4c1336de624 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -1203,6 +1203,12 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** in, // provided_arg_dtype_names,
const int* in, // provided_arg_dtypes,
+
+//--------------- sparse related variables, ignored for now
+ const mx_uint num_provided_arg_stypes,
+ const char** provided_arg_stype_names,
+ const int* provided_arg_stypes,
+//---------------
const mx_uint num_shared_arg_names,
const char** in, // shared_arg_name_list,
//------------
diff --git a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
index 640215fd7792..5d2fbd6880a1 100644
--- a/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
+++ b/perl-package/AI-MXNetCAPI/mxnet_typemaps.i
@@ -820,6 +820,17 @@
}
}
+%typemap(in,numinputs=0) (const mx_uint num_provided_arg_stypes, const char** provided_arg_stype_names,
+ const int* provided_arg_stypes)
+ (mx_uint temp1, char* temp2, int temp3)
+{
+ $2 = &temp2;
+ $3 = &temp3;
+ $1 = 0;
+ *$2 = NULL;
+ *$3 = 0;
+}
+
%typemap(in,numinputs=0) (mx_uint* num_aux_states,
NDArrayHandle** aux_states)
(mx_uint temp1,
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 3c3ce76a9284..72dc2b2fec8d 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -26,6 +26,7 @@
from . import base
from . import contrib
from . import ndarray
+from . import ndarray as nd
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
@@ -34,8 +35,6 @@
from . import io
from . import recordio
from . import operator
-# use mx.nd as short for mx.ndarray
-from . import ndarray as nd
# use mx.rnd as short for mx.random
from . import random as rnd
from . import random
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index 5a50f80498ec..c2e6fce40de8 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -32,10 +32,19 @@
from ..ndarray_doc import _build_doc
+_STORAGE_TYPE_ID_TO_STR = {
+ -1 : 'undefined',
+ 0 : 'default',
+ 1 : 'row_sparse',
+ 2 : 'csr',
+}
+
+
class NDArrayBase(object):
"""Base data structure for ndarray"""
__slots__ = ["handle", "writable"]
# pylint: disable= no-member
+
def __init__(self, handle, writable=True):
"""initialize a new NDArray
@@ -78,7 +87,11 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
output_vars = ctypes.POINTER(NDArrayHandle)()
num_output = ctypes.c_int(0)
- check_call(_LIB.MXImperativeInvoke(
+ # return output stypes to avoid the c_api call for checking
+ # a handle's stype in _ndarray_cls
+ out_stypes = ctypes.POINTER(ctypes.c_int)()
+
+ check_call(_LIB.MXImperativeInvokeEx(
ctypes.c_void_p(handle),
ctypes.c_int(len(ndargs)),
c_array(NDArrayHandle, [arr.handle for arr in ndargs]),
@@ -86,14 +99,17 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
ctypes.byref(output_vars),
ctypes.c_int(len(keys)),
c_array(ctypes.c_char_p, [c_str(key) for key in keys]),
- c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals])))
+ c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals]),
+ ctypes.byref(out_stypes)))
if original_output is not None:
return original_output
if num_output.value == 1:
- return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
+ return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
+ stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
else:
- return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
+ return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
+ stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
for i in range(num_output.value)]
@@ -128,17 +144,24 @@ def __call__(self, *args, **kwargs):
"CachedOp.__call__ got unexpected keyword argument(s): " + \
', '.join(kwargs.keys()))
- check_call(_LIB.MXInvokeCachedOp(
+ # return output stypes to avoid the c_api call for checking
+ # a handle's stype in _ndarray_cls
+ out_stypes = ctypes.POINTER(ctypes.c_int)()
+
+ check_call(_LIB.MXInvokeCachedOpEx(
self.handle,
ctypes.c_int(len(args)),
c_array(NDArrayHandle, [arr.handle for arr in args]),
ctypes.byref(num_output),
- ctypes.byref(output_vars)))
+ ctypes.byref(output_vars),
+ ctypes.byref(out_stypes)))
if original_output is not None:
return original_output
if num_output.value == 1:
- return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
+ return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
+ stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
else:
- return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
+ return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
+ stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
for i in range(num_output.value)]
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index aad0580e7d07..d446355da0b5 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -72,6 +72,20 @@ def __str__(self):
msg += ' is not implemented for Symbol and only available in NDArray.'
return msg
+class NotSupportedForSparseNDArray(MXNetError):
+ def __init__(self, function, alias, *args):
+ super(NotSupportedForSparseNDArray, self).__init__()
+ self.function = function.__name__
+ self.alias = alias
+ self.args = [str(type(a)) for a in args]
+ def __str__(self):
+ msg = 'Function {}'.format(self.function)
+ if self.alias:
+ msg += ' (namely operator "{}")'.format(self.alias)
+ if self.args:
+ msg += ' with arguments ({})'.format(', '.join(self.args))
+ msg += ' is not supported for SparseNDArray and only available in NDArray.'
+ return msg
class MXCallbackList(ctypes.Structure):
"""Structure that holds Callback information. Passed to CustomOpProp."""
diff --git a/python/mxnet/contrib/autograd.py b/python/mxnet/contrib/autograd.py
index c7fb6e17803a..2d2500e7a217 100644
--- a/python/mxnet/contrib/autograd.py
+++ b/python/mxnet/contrib/autograd.py
@@ -24,6 +24,7 @@
import functools
from ..base import _LIB, check_call, string_types
from ..base import mx_uint, NDArrayHandle, c_array
+# pylint: disable= unused-import
from ..ndarray import NDArray, zeros_like
from ..symbol import _GRAD_REQ_MAP
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index baff834bb33a..5cc94a5e80ac 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -27,6 +27,7 @@
from .base import mx_uint, NDArrayHandle, ExecutorHandle
from .base import check_call, c_array, py_str
from .ndarray import NDArray
+from .ndarray import _ndarray_cls
from . import ndarray as nd
# those functions are not used here, we just import them to keep backward compatibility
@@ -105,7 +106,9 @@ def _get_outputs(self):
handles = ctypes.POINTER(NDArrayHandle)()
check_call(_LIB.MXExecutorOutputs(self.handle,
ctypes.byref(out_size), ctypes.byref(handles)))
- return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)]
+ num_output = out_size.value
+ outputs = [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(num_output)]
+ return outputs
def forward(self, is_train=False, **kwargs):
"""Calculate the outputs specified by the bound symbol.
diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index 8ac1aebe72dd..f67b05de5de3 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -27,7 +27,7 @@
from ..base import numeric_types
from .. import ndarray as nd
-from .._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder
+from ..ndarray._internal import _cvcopyMakeBorder as copyMakeBorder
from .. import io
from .image import RandomOrderAug, ColorJitterAug, LightingAug, ColorNormalizeAug
from .image import ResizeAug, ForceResizeAug, CastAug, HueJitterAug, RandomGrayAug
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index 2e40019971ac..d99db214222c 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -34,9 +34,9 @@
from ..base import numeric_types
from .. import ndarray as nd
-from .. import _ndarray_internal as _internal
-from .._ndarray_internal import _cvimresize as imresize
-from .._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder
+from ..ndarray import _internal
+from ..ndarray._internal import _cvimresize as imresize
+from ..ndarray._internal import _cvcopyMakeBorder as copyMakeBorder
from .. import io
from .. import recordio
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 0404e34ea36c..4e69a8a801cb 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -34,6 +34,7 @@
from .base import mx_real_t
from .base import check_call, build_param_doc as _build_param_doc
from .ndarray import NDArray
+from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate
@@ -801,12 +802,12 @@ def iter_next(self):
def getdata(self):
hdl = NDArrayHandle()
check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl)))
- return NDArray(hdl, False)
+ return _ndarray_cls(hdl, False)
def getlabel(self):
hdl = NDArrayHandle()
check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
- return NDArray(hdl, False)
+ return _ndarray_cls(hdl, False)
def getindex(self):
index_size = ctypes.c_uint64(0)
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index fd0091182aea..2af70e36e60a 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -22,6 +22,7 @@
import ctypes
import pickle
from .ndarray import NDArray
+from .ndarray import _ndarray_cls
from .base import _LIB
from .base import check_call, c_array, c_str, string_types, mx_uint, py_str
from .base import NDArrayHandle, KVStoreHandle
@@ -53,8 +54,8 @@ def _updater_wrapper(updater):
"""A wrapper for the user-defined handle."""
def updater_handle(key, lhs_handle, rhs_handle, _):
""" ctypes function """
- lhs = NDArray(NDArrayHandle(lhs_handle))
- rhs = NDArray(NDArrayHandle(rhs_handle))
+ lhs = _ndarray_cls(NDArrayHandle(lhs_handle))
+ rhs = _ndarray_cls(NDArrayHandle(rhs_handle))
updater(key, lhs, rhs)
return updater_handle
@@ -186,6 +187,8 @@ def pull(self, key, out=None, priority=0):
The returned values are guaranteed to be the latest values in the store.
+ For row_sparse values, please use `row_sparse_pull` instead.
+
Parameters
----------
key : int or list of int
@@ -236,6 +239,66 @@ def pull(self, key, out=None, priority=0):
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))
+ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
+ """ Pulls a single row_sparse value or a sequence of row_sparse values from the store
+ with specified row_ids.
+
+ `row_sparse_pull` is executed asynchronously after all previous
+ `push`/`pull`/`row_sparse_pull` calls for the same input key(s) are finished.
+
+ The returned values are guaranteed to be the latest values in the store.
+
+ Parameters
+ ----------
+ key : str or list of str
+ Keys.
+
+ out : NDArray or list of NDArray or list of list of NDArray
+ Values corresponding to the keys. The stype is expected to be row_sparse.
+
+ priority : int, optional
+ The priority of the pull operation.
+ Higher priority pull operations are likely to be executed before
+ other pull actions.
+
+ row_ids : NDArray or list of NDArray
+ The row_ids to pull for each value. Each row_id is a 1D NDArray \
+ whose values don't have to be unique or sorted.
+
+ Examples
+ --------
+ >>> shape = (3, 3)
+ >>> kv.init('3', mx.nd.ones(shape).tostype('row_sparse'))
+ >>> a = mx.nd.zeros(shape, stype='row_sparse')
+ >>> row_ids = mx.nd.array([0, 2], dtype='int64')
+ >>> kv.row_sparse_pull('3', out=a, row_ids=row_ids)
+ >>> print a.asnumpy()
+ [[ 1. 1. 1.]
+ [ 0. 0. 0.]
+ [ 1. 1. 1.]]
+ >>> duplicate_row_ids = mx.nd.array([2, 2], dtype='int64')
+ >>> kv.row_sparse_pull('3', out=a, row_ids=duplicate_row_ids)
+ >>> print a.asnumpy()
+ [[ 0. 0. 0.]
+ [ 0. 0. 0.]
+ [ 1. 1. 1.]]
+ >>> unsorted_row_ids = mx.nd.array([1, 0], dtype='int64')
+ >>> kv.row_sparse_pull('3', out=a, row_ids=unsorted_row_ids)
+ >>> print a.asnumpy()
+ [[ 1. 1. 1.]
+ [ 1. 1. 1.]
+ [ 0. 0. 0.]]
+ """
+ assert(out is not None)
+ assert(row_ids is not None)
+ ckeys, cvals = _ctype_key_value(key, out)
+ _, crow_ids = _ctype_key_value(key, row_ids)
+ assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values"
+
+ check_call(_LIB.MXKVStorePullRowSparse(
+ self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+
+
def set_optimizer(self, optimizer):
""" Registers an optimizer with the kvstore.
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 01b3fa50e18f..2444ca0dc59e 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -93,8 +93,7 @@ def _create_kvstore(kvstore, num_device, arg_params):
return (kv, update_on_kvstore)
-def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
- update_on_kvstore):
+def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore):
"""Initialize kvstore"""
for idx, param_on_devs in enumerate(param_arrays):
name = param_names[idx]
@@ -118,10 +117,11 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None, param_names=None):
"""Perform update of param_arrays from grad_arrays not on kvstore."""
- for index, pair in enumerate(zip(param_arrays, grad_arrays)):
+ for i, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
+ index = i
if kvstore:
name = param_names[index]
# push gradient, priority is negative index
@@ -131,7 +131,7 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
for k, p in enumerate(zip(arg_list, grad_list)):
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
- # use a better solution latter
+ # use a better solution later
w, g = p
updater(index*num_device+k, g, w)
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index 3123462f9c7c..bae166e3ffd8 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -957,7 +957,8 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer_params=(('learning_rate', 0.01),), force_init=False):
- """Installs and initializes optimizers.
+ """Installs and initializes optimizers, as well as initialize kvstore for
+ distributed training
Parameters
----------
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 058edd57eb3d..d55b2117ebd3 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -25,7 +25,6 @@
import warnings
from .. import context as ctx
-from .. import ndarray as nd
from .. import optimizer as opt
from .executor_group import DataParallelExecutorGroup
@@ -33,6 +32,7 @@
from ..model import load_checkpoint
from ..initializer import Uniform, InitDesc
from ..io import DataDesc
+from ..ndarray import zeros
from .base_module import BaseModule, _check_input_names, _parse_data_desc
@@ -427,13 +427,13 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
else:
assert self._arg_params is None and self._aux_params is None
param_arrays = [
- nd.zeros(x[0].shape, dtype=x[0].dtype)
+ zeros(shape=x[0].shape, dtype=x[0].dtype, stype=x[0].stype)
for x in self._exec_group.param_arrays
]
self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}
aux_arrays = [
- nd.zeros(x[0].shape, dtype=x[0].dtype)
+ zeros(x[0].shape, dtype=x[0].dtype)
for x in self._exec_group.aux_arrays
]
self._aux_params = {name:arr for name, arr in zip(self._aux_names, aux_arrays)}
@@ -441,7 +441,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
if shared_module is not None and shared_module.optimizer_initialized:
self.borrow_optimizer(shared_module)
-
def reshape(self, data_shapes, label_shapes=None):
"""Reshapes the module for new input shapes.
@@ -483,6 +482,7 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
if self._params_dirty:
self._sync_params_from_devices()
+
(kvstore, update_on_kvstore) = \
_create_kvstore(kvstore, len(self._context), self._arg_params)
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
new file mode 100644
index 000000000000..63220787a43c
--- /dev/null
+++ b/python/mxnet/ndarray/__init__.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""NDArray API of MXNet."""
+
+from . import _internal, sparse, op
+from .op import CachedOp
+# pylint: disable=wildcard-import, redefined-builtin
+from .ndarray import *
+from .utils import load, save, zeros, empty, array
+from .sparse import _ndarray_cls
diff --git a/python/mxnet/_ndarray_internal.py b/python/mxnet/ndarray/_internal.py
similarity index 100%
rename from python/mxnet/_ndarray_internal.py
rename to python/mxnet/ndarray/_internal.py
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray/ndarray.py
similarity index 87%
rename from python/mxnet/ndarray.py
rename to python/mxnet/ndarray/ndarray.py
index 42f0ff5e87cf..20ca2262f0cd 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -21,6 +21,7 @@
"""NDArray API of MXNet."""
from __future__ import absolute_import
from __future__ import division
+
try:
from __builtin__ import slice as py_slice
except ImportError:
@@ -28,40 +29,25 @@
import ctypes
import warnings
-
-import os as _os
-import sys as _sys
-
import operator
import numpy as np
-from .base import _LIB, string_types, numeric_types, integer_types
-from .base import c_array, py_str, c_str, mx_real_t, _Null # pylint: disable=unused-import
-from .base import mx_uint, NDArrayHandle, check_call, OpHandle
-from .base import ctypes2buffer
-from .context import Context
-from . import _ndarray_internal as _internal
-from .ndarray_doc import _build_doc
-
-
-# Use different version of SymbolBase
-# When possible, use cython to speedup part of computation.
-# pylint: disable=unused-import
-try:
- if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
- from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class
- from ._ctypes.ndarray import CachedOp, _imperative_invoke
- elif _sys.version_info >= (3, 0):
- from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
- from ._cy3.ndarray import CachedOp, _imperative_invoke
- else:
- from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
- from ._cy2.ndarray import CachedOp, _imperative_invoke
-except ImportError:
- if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
- raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
- from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
- from ._ctypes.ndarray import CachedOp, _imperative_invoke
-# pylint: enable=unused-import
+from ..base import _LIB, numeric_types, integer_types
+from ..base import c_array, mx_real_t
+from ..base import mx_uint, NDArrayHandle, check_call
+from ..base import ctypes2buffer
+from ..context import Context
+from . import _internal
+from .op import NDArrayBase, _STORAGE_TYPE_ID_TO_STR
+from . import broadcast_add, broadcast_mul, transpose, broadcast_not_equal, broadcast_power
+from . import broadcast_sub, broadcast_div, broadcast_to, broadcast_equal, cast_storage
+from . import broadcast_greater, broadcast_greater_equal, broadcast_lesser, broadcast_lesser_equal
+from . import zeros_like, slice
+
+__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
+ "ones", "add", "arange", "divide", "equal", "full", "greater", "greater_equal",
+ "imdecode", "lesser", "lesser_equal", "maximum", "minimum", "moveaxis",
+ "multiply", "negative", "not_equal", "onehot_encode", "power", "subtract",
+ "true_divide", "waitall", "_new_empty_handle"]
# pylint: disable= no-member
_DTYPE_NP_TO_MX = {
@@ -74,7 +60,6 @@
np.int8 : 5,
np.int64 : 6,
}
-
_DTYPE_MX_TO_NP = {
-1 : None,
0 : np.float32,
@@ -85,7 +70,12 @@
5 : np.int8,
6 : np.int64,
}
-
+_STORAGE_TYPE_STR_TO_ID = {
+ 'undefined' : -1,
+ 'default' : 0,
+ 'row_sparse' : 1,
+ 'csr' : 2,
+}
_GRAD_REQ_MAP = {
'null': 0,
'write': 1,
@@ -93,6 +83,7 @@
}
# pylint: enable= no-member
+
def _new_empty_handle():
"""Returns a new empty handle.
@@ -107,6 +98,7 @@ def _new_empty_handle():
check_call(_LIB.MXNDArrayCreateNone(ctypes.byref(hdl)))
return hdl
+
def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
"""Return a new handle with specified shape and context.
@@ -128,6 +120,7 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
ctypes.byref(hdl)))
return hdl
+
def waitall():
"""Wait for all async operations to finish in MXNet.
@@ -135,6 +128,13 @@ def waitall():
"""
check_call(_LIB.MXNDArrayWaitAll())
+
+def _storage_type(handle):
+ storage_type = ctypes.c_int(0)
+ check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(storage_type)))
+ return _STORAGE_TYPE_ID_TO_STR[storage_type.value]
+
+
class NDArray(NDArrayBase):
"""An array object representing a multidimensional, homogeneous array of
fixed-size items.
@@ -144,6 +144,7 @@ class NDArray(NDArrayBase):
# make numpy functions return NDArray instead of numpy object array
__array_priority__ = 1000.0
# pylint: disable= no-member, undefined-variable
+
def __repr__(self):
"""Returns a string representation of the array."""
shape_info = 'x'.join(['%d' % x for x in self.shape])
@@ -151,6 +152,9 @@ def __repr__(self):
self.__class__.__name__,
shape_info, self.context)
+ def __reduce__(self):
+ return NDArray, (None,), self.__getstate__()
+
def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
return add(self, other)
@@ -742,7 +746,6 @@ def wait_to_read(self):
"""
check_call(_LIB.MXNDArrayWaitToRead(self.handle))
-
@property
def ndim(self):
"""Returns the number of dimensions of this array
@@ -777,6 +780,7 @@ def shape(self):
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
return tuple(pdata[:ndim.value])
+
@property
def size(self):
"""Number of elements in the array.
@@ -841,6 +845,12 @@ def dtype(self):
self.handle, ctypes.byref(mx_dtype)))
return _DTYPE_MX_TO_NP[mx_dtype.value]
+ @property
+ def stype(self):
+ """Storage-type of the array.
+ """
+ return _storage_type(self.handle)
+
@property
# pylint: disable= invalid-name, undefined-variable
def T(self):
@@ -964,7 +974,7 @@ def copyto(self, other):
Returns
-------
- NDArray
+ NDArray, CSRNDArray, RowSparseNDArray
The copied array. If ``other`` is an ``NDArray``, then the return value
and ``other`` will point to the same ``NDArray``.
@@ -1101,6 +1111,20 @@ def backward(self, out_grad=None, retain_graph=False, train_mode=True):
ctypes.c_int(retain_graph),
ctypes.c_int(train_mode)))
+ def tostype(self, stype):
+ """Return a copy of the array with chosen storage type.
+
+ See Also
+ ----------
+ :meth:`mxnet.ndarray.cast_storage`.
+
+ Returns
+ -------
+ NDArray, CSRNDArray or RowSparseNDArray
+ A copy of the array with the chosen storage type.
+ """
+ return cast_storage(self, stype=stype)
+
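Quick usage check for tostype (equivalent to calling cast_storage with the same stype argument):

    import mxnet as mx

    x = mx.nd.ones((2, 3))
    c = x.tostype('csr')
    print(c.stype)            # 'csr'
    d = c.tostype('default')  # back to dense
    assert (d.asnumpy() == x.asnumpy()).all()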
def onehot_encode(indices, out):
"""One-hot encoding indices into matrix out.
@@ -1113,74 +1137,7 @@ def onehot_encode(indices, out):
# pylint: enable= no-member, protected-access
-def empty(shape, ctx=None, dtype=mx_real_t):
- """Returns a new array of given shape and type, without initializing entries.
-
- Parameters
- ----------
- shape : int or tuple of int
- The shape of the empty array.
- ctx : Context, optional
- An optional device context (default is the current default context).
- dtype : str or numpy.dtype, optional
- An optional value type (default is `float32`).
-
- Returns
- -------
- NDArray
- A created array.
-
- Examples
- --------
- >>> mx.nd.empty(1)
-
- >>> mx.nd.empty((1,2), mx.gpu(0))
-
- >>> mx.nd.empty((1,2), mx.gpu(0), 'float16')
-
- """
- if isinstance(shape, integer_types):
- shape = (shape, )
- if ctx is None:
- ctx = Context.default_ctx
- return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))
-
-def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs):
- """Returns a new array filled with all zeros, with the given shape and type.
-
- Parameters
- ----------
- shape : int or tuple of int
- The shape of the empty array.
- ctx : Context, optional
- An optional device context (default is the current default context).
- dtype : str or numpy.dtype, optional
- An optional value type (default is `float32`).
- out : NDArray, optional
- The output NDArray (default is `None`).
-
- Returns
- -------
- NDArray
- A created array
-
- Examples
- --------
- >>> mx.nd.zeros(1).asnumpy()
- array([ 0.], dtype=float32)
- >>> mx.nd.zeros((1,2), mx.gpu(0))
-
- >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy()
- array([[ 0., 0.]], dtype=float16)
- """
- # pylint: disable= unused-argument
- if ctx is None:
- ctx = Context.default_ctx
- # pylint: disable= no-member, protected-access
- return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
- # pylint: enable= no-member, protected-access
-
-def ones(shape, ctx=None, dtype=mx_real_t, **kwargs):
+def ones(shape, ctx=None, dtype=None, **kwargs):
"""Returns a new array filled with all ones, with the given shape and type.
Parameters
@@ -1212,10 +1169,12 @@ def ones(shape, ctx=None, dtype=mx_real_t, **kwargs):
# pylint: disable= unused-argument
if ctx is None:
ctx = Context.default_ctx
+ dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
# pylint: enable= no-member, protected-access
+
def full(shape, val, ctx=None, dtype=mx_real_t, out=None):
"""Returns a new array of given shape and type, filled with the given value `val`.
@@ -1269,18 +1228,6 @@ def array(source_array, ctx=None, dtype=None):
-------
NDArray
An `NDArray` with the same contents as the `source_array`.
-
- Examples
- --------
- >>> import numpy as np
- >>> mx.nd.array([1, 2, 3])
-
- >>> mx.nd.array([[1, 2], [3, 4]])
-
- >>> mx.nd.array(np.zeros((3, 2)))
-
- >>> mx.nd.array(np.zeros((3, 2)), mx.gpu(0))
-
"""
if isinstance(source_array, NDArray):
dtype = source_array.dtype if dtype is None else dtype
@@ -1382,6 +1329,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments
+
#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None):
""" Helper function for element-wise operation.
@@ -1430,6 +1378,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None):
raise TypeError('type %s not supported' % str(type(rhs)))
#pylint: enable= too-many-arguments, no-member, protected-access
+
def add(lhs, rhs):
"""Returns element-wise sum of the input arrays with broadcasting.
@@ -1491,6 +1440,7 @@ def add(lhs, rhs):
None)
# pylint: enable= no-member, protected-access
+
def subtract(lhs, rhs):
"""Returns element-wise difference of the input arrays with broadcasting.
@@ -1552,6 +1502,7 @@ def subtract(lhs, rhs):
_internal._rminus_scalar)
# pylint: enable= no-member, protected-access
+
def multiply(lhs, rhs):
"""Returns element-wise product of the input arrays with broadcasting.
@@ -1612,6 +1563,7 @@ def multiply(lhs, rhs):
None)
# pylint: enable= no-member, protected-access
+
def divide(lhs, rhs):
"""Returns element-wise division of the input arrays with broadcasting.
@@ -1668,6 +1620,7 @@ def divide(lhs, rhs):
_internal._rdiv_scalar)
# pylint: enable= no-member, protected-access
+
def modulo(lhs, rhs):
"""Returns element-wise modulo of the input arrays with broadcasting.
@@ -1724,6 +1677,7 @@ def modulo(lhs, rhs):
_internal._rmod_scalar)
# pylint: enable= no-member, protected-access
+
def power(base, exp):
"""Returns result of first array elements raised to powers from second array, element-wise
with broadcasting.
@@ -1785,6 +1739,7 @@ def power(base, exp):
_internal._rpower_scalar)
# pylint: enable= no-member, protected-access
+
def maximum(lhs, rhs):
"""Returns element-wise maximum of the input arrays with broadcasting.
@@ -1841,6 +1796,7 @@ def maximum(lhs, rhs):
None)
# pylint: enable= no-member, protected-access
+
def minimum(lhs, rhs):
"""Returns element-wise minimum of the input arrays with broadcasting.
@@ -1897,6 +1853,7 @@ def minimum(lhs, rhs):
None)
# pylint: enable= no-member, protected-access
+
def equal(lhs, rhs):
"""Returns the result of element-wise **equal to** (==) comparison operation with
broadcasting.
@@ -1960,6 +1917,7 @@ def equal(lhs, rhs):
None)
# pylint: enable= no-member, protected-access
+
def not_equal(lhs, rhs):
"""Returns the result of element-wise **not equal to** (!=) comparison operation
with broadcasting.
@@ -2026,6 +1984,7 @@ def not_equal(lhs, rhs):
None)
# pylint: enable= no-member, protected-access
+
def greater(lhs, rhs):
"""Returns the result of element-wise **greater than** (>) comparison operation
with broadcasting.
@@ -2089,6 +2048,7 @@ def greater(lhs, rhs):
_internal._lesser_scalar)
# pylint: enable= no-member, protected-access
+
def greater_equal(lhs, rhs):
"""Returns the result of element-wise **greater than or equal to** (>=) comparison
operation with broadcasting.
@@ -2152,6 +2112,7 @@ def greater_equal(lhs, rhs):
_internal._lesser_equal_scalar)
# pylint: enable= no-member, protected-access
+
def lesser(lhs, rhs):
"""Returns the result of element-wise **lesser than** (<) comparison operation
with broadcasting.
@@ -2279,12 +2240,14 @@ def lesser_equal(lhs, rhs):
_internal._greater_equal_scalar)
# pylint: enable= no-member, protected-access
+
def true_divide(lhs, rhs):
"""This function is similar to :meth:`divide`.
"""
return divide(lhs, rhs)
+
def negative(arr):
"""Numerical negative, element-wise.
@@ -2310,95 +2273,6 @@ def negative(arr):
return multiply(arr, -1.0)
-def load(fname):
- """Loads an array from file.
-
- See more details in ``save``.
-
- Parameters
- ----------
- fname : str
- The filename.
-
- Returns
- -------
- list of NDArray or dict of str to NDArray
- Loaded data.
- """
- if not isinstance(fname, string_types):
- raise TypeError('fname required to be a string')
- out_size = mx_uint()
- out_name_size = mx_uint()
- handles = ctypes.POINTER(NDArrayHandle)()
- names = ctypes.POINTER(ctypes.c_char_p)()
- check_call(_LIB.MXNDArrayLoad(c_str(fname),
- ctypes.byref(out_size),
- ctypes.byref(handles),
- ctypes.byref(out_name_size),
- ctypes.byref(names)))
- if out_name_size.value == 0:
- return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)]
- else:
- assert out_name_size.value == out_size.value
- return dict(
- (py_str(names[i]), NDArray(NDArrayHandle(handles[i]))) for i in range(out_size.value))
-
-
-def save(fname, data):
- """Saves a list of arrays or a dict of str->array to file.
-
- Examples of filenames:
-
- - ``/path/to/file``
- - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports)
- - ``hdfs://path/to/file`` (if compiled with HDFS supports)
-
- Parameters
- ----------
- fname : str
- The filename.
- data : ``NDArray``, list of ``NDArray` or dict of str to ``NDArray``
- The data to save.
-
- Examples
- --------
- >>> x = mx.nd.zeros((2,3))
- >>> y = mx.nd.ones((1,4))
- >>> mx.nd.save('my_list', [x,y])
- >>> mx.nd.save('my_dict', {'x':x, 'y':y})
- >>> mx.nd.load('my_list')
- [, ]
- >>> mx.nd.load('my_dict')
- {'y': , 'x': }
- """
- if isinstance(data, NDArray):
- data = [data]
- handles = []
- if isinstance(data, dict):
- keys = []
- for key, val in data.items():
- if not isinstance(key, string_types):
- raise TypeError('save only accept dict str->NDArray or list of NDArray')
- if not isinstance(val, NDArray):
- raise TypeError('save only accept dict str->NDArray or list of NDArray')
- keys.append(c_str(key))
- handles.append(val.handle)
- keys = c_array(ctypes.c_char_p, keys)
- elif isinstance(data, list):
- for val in data:
- if not isinstance(val, NDArray):
- raise TypeError('save only accept dict str->NDArray or list of NDArray')
- handles.append(val.handle)
- keys = None
- else:
- raise ValueError("data needs to either be a NDArray, dict of str, NDArray pairs "
- "or a list of NDarrays.")
- check_call(_LIB.MXNDArraySave(c_str(fname),
- mx_uint(len(handles)),
- c_array(NDArrayHandle, handles),
- keys))
-
-
def concatenate(arrays, axis=0, always_copy=True):
"""DEPRECATED, use ``concat`` instead
@@ -2455,6 +2329,7 @@ def concatenate(arrays, axis=0, always_copy=True):
return ret
+
def imdecode(str_img, clip_rect=(0, 0, 0, 0), out=None, index=0, channels=3, mean=None):
"""DEPRECATED, use mx.img instead
@@ -2497,159 +2372,65 @@ def imdecode(str_img, clip_rect=(0, 0, 0, 0), out=None, index=0, channels=3, mea
out=out)
-# pylint: disable=too-many-locals, invalid-name
-def _make_ndarray_function(handle, name):
- """Create a NDArray function from the FunctionHandle."""
- real_name = ctypes.c_char_p()
- desc = ctypes.c_char_p()
- num_args = mx_uint()
- arg_names = ctypes.POINTER(ctypes.c_char_p)()
- arg_types = ctypes.POINTER(ctypes.c_char_p)()
- arg_descs = ctypes.POINTER(ctypes.c_char_p)()
- key_var_num_args = ctypes.c_char_p()
- ret_type = ctypes.c_char_p()
-
- check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
- handle, ctypes.byref(real_name), ctypes.byref(desc),
- ctypes.byref(num_args),
- ctypes.byref(arg_names),
- ctypes.byref(arg_types),
- ctypes.byref(arg_descs),
- ctypes.byref(key_var_num_args),
- ctypes.byref(ret_type)))
- narg = int(num_args.value)
- arg_names = [py_str(arg_names[i]) for i in range(narg)]
- arg_types = [py_str(arg_types[i]) for i in range(narg)]
- func_name = name
- key_var_num_args = py_str(key_var_num_args.value)
- ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
- doc_str = _build_doc(func_name,
- py_str(desc.value),
- arg_names,
- arg_types,
- [py_str(arg_descs[i]) for i in range(narg)],
- key_var_num_args,
- ret_type)
-
- dtype_name = None
- arr_name = None
- ndsignature = []
- signature = []
- ndarg_names = []
- kwarg_names = []
- for i in range(narg):
- name, atype = arg_names[i], arg_types[i]
- if name == 'dtype':
- dtype_name = name
- signature.append('%s=_Null'%name)
- elif atype.startswith('NDArray') or atype.startswith('Symbol'):
- assert not arr_name, \
- "Op can only have one argument with variable " \
- "size and it must be the last argument."
- if atype.endswith('[]'):
- ndsignature.append('*%s'%name)
- arr_name = name
- else:
- ndsignature.append('%s=None'%name)
- ndarg_names.append(name)
- else:
- signature.append('%s=_Null'%name)
- kwarg_names.append(name)
- signature.append('out=None')
- signature.append('name=None')
- signature.append('**kwargs')
- signature = ndsignature + signature
-
- code = []
- if arr_name:
- code.append("""
-def %s(*%s, **kwargs):"""%(func_name, arr_name))
- code.append("""
- ndargs = []
- for i in {}:
- assert isinstance(i, NDArrayBase), \\
- "Positional arguments must have NDArray type, " \\
- "but got %s"%str(i)
- ndargs.append(i)""".format(arr_name))
- if dtype_name is not None:
- code.append("""
- if '%s' in kwargs:
- kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
- dtype_name, dtype_name, dtype_name))
- code.append("""
- _ = kwargs.pop('name', None)
- out = kwargs.pop('out', None)
- keys = list(kwargs.keys())
- vals = list(kwargs.values())""")
- else:
- code.append("""
-def %s(%s):
- ndargs = []
- keys = list(kwargs.keys())
- vals = list(kwargs.values())"""%(func_name, ', '.join(signature)))
- # NDArray args
- for name in ndarg_names: # pylint: disable=redefined-argument-from-local
- code.append("""
- if {name} is not None:
- assert isinstance({name}, NDArrayBase), \\
- "Argument {name} must have NDArray type, but got %s"%str({name})
- ndargs.append({name})""".format(name=name))
- # kwargs
- for name in kwarg_names: # pylint: disable=redefined-argument-from-local
- code.append("""
- if %s is not _Null:
- keys.append('%s')
- vals.append(%s)"""%(name, name, name))
- # dtype
- if dtype_name is not None:
- code.append("""
- if %s is not _Null:
- keys.append('%s')
- vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
-
- code.append("""
- return _imperative_invoke(%d, ndargs, keys, vals, out)"""%(
- handle.value))
-
- local = {}
- exec(''.join(code), None, local) # pylint: disable=exec-used
- ndarray_function = local[func_name]
- ndarray_function.__name__ = func_name
- ndarray_function.__doc__ = doc_str
- ndarray_function.__module__ = 'mxnet.ndarray'
- return ndarray_function
-
-
-# pylint: enable=too-many-locals, invalid-name
-def _init_ndarray_module(ndarray_class, root_namespace):
- """List and add all the ndarray functions to current module."""
- _set_ndarray_class(ndarray_class)
- plist = ctypes.POINTER(ctypes.c_char_p)()
- size = ctypes.c_uint()
-
- check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
- ctypes.byref(plist)))
- op_names = []
- for i in range(size.value):
- op_names.append(py_str(plist[i]))
-
- module_obj = _sys.modules["%s.ndarray" % root_namespace]
- module_internal = _sys.modules["%s._ndarray_internal" % root_namespace]
- module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace]
- for name in op_names:
- hdl = OpHandle()
- check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
- function = _make_ndarray_function(hdl, name)
- if function.__name__.startswith('_contrib_'):
- function.__name__ = function.__name__[9:]
- function.__module__ = 'mxnet.contrib.ndarray'
- setattr(module_contrib, function.__name__, function)
- elif function.__name__.startswith('_'):
- setattr(module_internal, function.__name__, function)
- else:
- setattr(module_obj, function.__name__, function)
+def zeros(shape, ctx=None, dtype=None, **kwargs):
+ """Returns a new array filled with all zeros, with the given shape and type.
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the empty array.
+ ctx : Context, optional
+ An optional device context (default is the current default context).
+ dtype : str or numpy.dtype, optional
+ An optional value type (default is `float32`).
+ out : NDArray, optional
+ The output NDArray (default is `None`).
+
+ Returns
+ -------
+ NDArray
+ A created array
+
+ Examples
+ --------
+ >>> mx.nd.zeros(1).asnumpy()
+ array([ 0.], dtype=float32)
+ >>> mx.nd.zeros((1,2), mx.gpu(0))
+ <NDArray 1x2 @gpu(0)>
+ >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy()
+ array([[ 0., 0.]], dtype=float16)
+ """
+ # pylint: disable= unused-argument
+ if ctx is None:
+ ctx = Context.default_ctx
+ dtype = mx_real_t if dtype is None else dtype
+ # pylint: disable= no-member, protected-access
+ return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+ # pylint: enable= no-member, protected-access
+
+
+def empty(shape, ctx=None, dtype=None):
+ """Returns a new array of given shape and type, without initializing entries.
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the empty array.
+ ctx : Context, optional
+ An optional device context (default is the current default context).
+ dtype : str or numpy.dtype, optional
+ An optional value type (default is `float32`).
-_init_ndarray_module(NDArray, "mxnet")
+ Returns
+ -------
+ NDArray
+ A created array.
-# from .base import add_fileline_to_docstring
-# add_fileline_to_docstring(__name__)
+ """
+ if isinstance(shape, int):
+ shape = (shape, )
+ if ctx is None:
+ ctx = Context.default_ctx
+ if dtype is None:
+ dtype = mx_real_t
+ return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))
diff --git a/python/mxnet/ndarray/op.py b/python/mxnet/ndarray/op.py
new file mode 100644
index 000000000000..e4a1ab0df48b
--- /dev/null
+++ b/python/mxnet/ndarray/op.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Register backend ops in mxnet.ndarray namespace"""
+
+import sys as _sys
+import os as _os
+import ctypes
+import numpy as np # pylint: disable=unused-import
+
+from ..ndarray_doc import _build_doc
+
+# Use a different version of NDArrayBase depending on the build.
+# When possible, use cython to speed up part of the computation.
+# pylint: disable=unused-import
+try:
+ if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
+ from .._ctypes.ndarray import NDArrayBase, _STORAGE_TYPE_ID_TO_STR
+ from .._ctypes.ndarray import CachedOp, _imperative_invoke
+ elif _sys.version_info >= (3, 0):
+ from .._cy3.ndarray import NDArrayBase, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR
+ from .._cy3.ndarray import CachedOp
+ else:
+ from .._cy2.ndarray import NDArrayBase, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR
+ from .._cy2.ndarray import CachedOp
+except ImportError:
+ if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
+ raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
+ from .._ctypes.ndarray import NDArrayBase, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR
+ from .._ctypes.ndarray import CachedOp
+
+from ..base import mx_uint, check_call, _LIB, py_str, OpHandle, c_str, _Null
+# pylint: enable=unused-import
+
+
+# pylint: disable=too-many-locals, invalid-name
+def _make_ndarray_function(handle, name):
+ """Create a NDArray function from the FunctionHandle."""
+ real_name = ctypes.c_char_p()
+ desc = ctypes.c_char_p()
+ num_args = mx_uint()
+ arg_names = ctypes.POINTER(ctypes.c_char_p)()
+ arg_types = ctypes.POINTER(ctypes.c_char_p)()
+ arg_descs = ctypes.POINTER(ctypes.c_char_p)()
+ key_var_num_args = ctypes.c_char_p()
+ ret_type = ctypes.c_char_p()
+
+ check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
+ handle, ctypes.byref(real_name), ctypes.byref(desc),
+ ctypes.byref(num_args),
+ ctypes.byref(arg_names),
+ ctypes.byref(arg_types),
+ ctypes.byref(arg_descs),
+ ctypes.byref(key_var_num_args),
+ ctypes.byref(ret_type)))
+ narg = int(num_args.value)
+ arg_names = [py_str(arg_names[i]) for i in range(narg)]
+ arg_types = [py_str(arg_types[i]) for i in range(narg)]
+ func_name = name
+ key_var_num_args = py_str(key_var_num_args.value)
+ ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
+ doc_str = _build_doc(func_name,
+ py_str(desc.value),
+ arg_names,
+ arg_types,
+ [py_str(arg_descs[i]) for i in range(narg)],
+ key_var_num_args,
+ ret_type)
+
+ dtype_name = None
+ arr_name = None
+ ndsignature = []
+ signature = []
+ ndarg_names = []
+ kwarg_names = []
+ for i in range(narg):
+ name, atype = arg_names[i], arg_types[i]
+ if name == 'dtype':
+ dtype_name = name
+ signature.append('%s=_Null'%name)
+ elif atype.startswith('NDArray') or atype.startswith('Symbol'):
+ assert not arr_name, \
+ "Op can only have one argument with variable " \
+ "size and it must be the last argument."
+ if atype.endswith('[]'):
+ ndsignature.append('*%s'%name)
+ arr_name = name
+ else:
+ ndsignature.append('%s=None'%name)
+ ndarg_names.append(name)
+ else:
+ signature.append('%s=_Null'%name)
+ kwarg_names.append(name)
+ signature.append('out=None')
+ signature.append('name=None')
+ signature.append('**kwargs')
+ signature = ndsignature + signature
+
+ code = []
+ if arr_name:
+ code.append("""
+def %s(*%s, **kwargs):"""%(func_name, arr_name))
+ code.append("""
+ ndargs = []
+ for i in {}:
+ assert isinstance(i, NDArrayBase), \\
+ "Positional arguments must have NDArray type, " \\
+ "but got %s"%str(i)
+ ndargs.append(i)""".format(arr_name))
+ if dtype_name is not None:
+ code.append("""
+ if '%s' in kwargs:
+ kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
+ dtype_name, dtype_name, dtype_name))
+ code.append("""
+ _ = kwargs.pop('name', None)
+ out = kwargs.pop('out', None)
+ keys = list(kwargs.keys())
+ vals = list(kwargs.values())""")
+ else:
+ code.append("""
+def %s(%s):
+ ndargs = []
+ keys = list(kwargs.keys())
+ vals = list(kwargs.values())"""%(func_name, ', '.join(signature)))
+ # NDArray args
+ for name in ndarg_names: # pylint: disable=redefined-argument-from-local
+ code.append("""
+ if {name} is not None:
+ assert isinstance({name}, NDArrayBase), \\
+ "Argument {name} must have NDArray type, but got %s"%str({name})
+ ndargs.append({name})""".format(name=name))
+ # kwargs
+ for name in kwarg_names: # pylint: disable=redefined-argument-from-local
+ code.append("""
+ if %s is not _Null:
+ keys.append('%s')
+ vals.append(%s)"""%(name, name, name))
+ # dtype
+ if dtype_name is not None:
+ code.append("""
+ if %s is not _Null:
+ keys.append('%s')
+ vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
+
+ code.append("""
+ return _imperative_invoke(%d, ndargs, keys, vals, out)"""%(
+ handle.value))
+
+ local = {}
+ exec(''.join(code), None, local) # pylint: disable=exec-used
+ ndarray_function = local[func_name]
+ ndarray_function.__name__ = func_name
+ ndarray_function.__doc__ = doc_str
+ ndarray_function.__module__ = 'mxnet.ndarray'
+ return ndarray_function
+
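For orientation, the code generation above produces wrappers of roughly the following shape for an op taking one NDArray argument and a `dtype` parameter. This is a hand-written illustration, not literal generated source; `some_op` is hypothetical and `op_handle` stands for the handle value baked in via `%d`:

    def some_op(data=None, dtype=_Null, out=None, name=None, **kwargs):
        ndargs = []
        keys = list(kwargs.keys())
        vals = list(kwargs.values())
        if data is not None:
            assert isinstance(data, NDArrayBase), \
                "Argument data must have NDArray type, but got %s" % str(data)
            ndargs.append(data)
        if dtype is not _Null:
            keys.append('dtype')
            vals.append(np.dtype(dtype).name)
        return _imperative_invoke(op_handle, ndargs, keys, vals, out)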
+
+# pylint: enable=too-many-locals, invalid-name
+def _init_ndarray_module(root_namespace):
+ """List and add all the ndarray functions to current module."""
+ plist = ctypes.POINTER(ctypes.c_char_p)()
+ size = ctypes.c_uint()
+
+ check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
+ ctypes.byref(plist)))
+ op_names = []
+ for i in range(size.value):
+ op_names.append(py_str(plist[i]))
+
+ module_obj = _sys.modules["%s.ndarray" % root_namespace]
+ module_sparse = _sys.modules["%s.ndarray.sparse" % root_namespace]
+ module_internal = _sys.modules["%s.ndarray._internal" % root_namespace]
+ module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace]
+ for name in op_names:
+ hdl = OpHandle()
+ check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+ function = _make_ndarray_function(hdl, name)
+ if function.__name__.startswith('_contrib_'):
+ function.__name__ = function.__name__[9:]
+ function.__module__ = 'mxnet.contrib.ndarray'
+ setattr(module_contrib, function.__name__, function)
+ elif function.__name__.startswith('_'):
+ setattr(module_internal, function.__name__, function)
+ else:
+ setattr(module_obj, function.__name__, function)
+
+ # register sparse ops under mxnet.ndarray.sparse
+ if function.__name__.startswith('_sparse_'):
+ function.__name__ = function.__name__[8:]
+ function.__module__ = 'mxnet.ndarray.sparse'
+ setattr(module_sparse, function.__name__, function)
+
+# register backend operators in mx.nd
+_init_ndarray_module("mxnet")
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
new file mode 100644
index 000000000000..97e43f5ebe79
--- /dev/null
+++ b/python/mxnet/ndarray/sparse.py
@@ -0,0 +1,923 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Sparse NDArray API of MXNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+try:
+ from __builtin__ import slice as py_slice
+except ImportError:
+ from builtins import slice as py_slice
+
+import ctypes
+import warnings
+
+import os as _os
+import sys as _sys
+
+# import operator
+import numpy as np
+from ..base import NotSupportedForSparseNDArray
+from ..base import _LIB, numeric_types
+from ..base import c_array, mx_real_t
+from ..base import mx_uint, NDArrayHandle, check_call
+from ..context import Context
+from . import _internal
+from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
+from .ndarray import _STORAGE_TYPE_STR_TO_ID
+from .ndarray import NDArray, _storage_type
+from .ndarray import zeros as _zeros_ndarray
+from .ndarray import array as _array
+from . import cast_storage
+from . import slice as nd_slice
+
+# Use a different version of _set_ndarray_class depending on the build.
+# When possible, use cython to speed up part of the computation.
+# pylint: disable=unused-import
+try:
+ if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
+ from .._ctypes.ndarray import _set_ndarray_class
+ elif _sys.version_info >= (3, 0):
+ from .._cy3.ndarray import _set_ndarray_class
+ else:
+ from .._cy2.ndarray import _set_ndarray_class
+except ImportError:
+ if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
+ raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
+ from .._ctypes.ndarray import _set_ndarray_class
+# pylint: enable=unused-import
+
+
+__all__ = ["_ndarray_cls", "csr_matrix", "row_sparse_array",
+ "BaseSparseNDArray", "CSRNDArray", "RowSparseNDArray"]
+
+
+_STORAGE_AUX_TYPES = {
+ 'row_sparse': [np.int64],
+ 'csr': [np.int64, np.int64]
+}
+
+
+def _new_alloc_handle(stype, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None):
+ """Return a new handle with specified storage type, shape, dtype and context.
+
+ An empty handle is only used to hold results.
+
+ Returns
+ -------
+ handle
+ A new empty ndarray handle
+ """
+ hdl = NDArrayHandle()
+ aux_type_ids = [int(_DTYPE_NP_TO_MX[np.dtype(aux_t).type]) for aux_t in aux_types]
+ aux_shapes = [(0,) for aux_t in aux_types] if aux_shapes is None else aux_shapes
+ aux_shape_lens = [len(aux_shape) for aux_shape in aux_shapes]
+ aux_shapes = sum(aux_shapes, ())
+ num_aux = mx_uint(len(aux_types))
+ check_call(_LIB.MXNDArrayCreateSparseEx(
+ ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[stype])),
+ c_array(mx_uint, shape),
+ mx_uint(len(shape)),
+ ctypes.c_int(ctx.device_typeid),
+ ctypes.c_int(ctx.device_id),
+ ctypes.c_int(int(delay_alloc)),
+ ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
+ num_aux,
+ c_array(ctypes.c_int, aux_type_ids),
+ c_array(mx_uint, aux_shape_lens),
+ c_array(mx_uint, aux_shapes),
+ ctypes.byref(hdl)))
+ return hdl
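The aux-shape handling above flattens the per-aux shapes into a single buffer plus a parallel list of lengths before crossing the C API. A small illustration with made-up csr aux shapes (indptr of length 4, indices of length 3):

    >>> aux_shapes = [(4,), (3,)]
    >>> [len(s) for s in aux_shapes]    # aux_shape_lens
    [1, 1]
    >>> sum(aux_shapes, ())             # flattened aux_shapes buffer
    (4, 3)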
+
+
+class BaseSparseNDArray(NDArray):
+ """The base class of an NDArray stored in a sparse storage format.
+
+ See CSRNDArray and RowSparseNDArray for more details.
+ """
+
+ def __iadd__(self, other):
+ raise NotImplementedError()
+
+ def __isub__(self, other):
+ raise NotImplementedError()
+
+ def __imul__(self, other):
+ raise NotImplementedError()
+
+ def __idiv__(self, other):
+ raise NotImplementedError()
+
+ def __itruediv__(self, other):
+ raise NotImplementedError()
+
+ def _sync_copyfrom(self, source_array):
+ raise NotImplementedError()
+
+ def _at(self, idx):
+ raise NotSupportedForSparseNDArray(self._at, '[idx]', idx)
+
+ def _slice(self, start, stop):
+ raise NotSupportedForSparseNDArray(self._slice, None, start, stop)
+
+ def reshape(self, shape):
+ raise NotSupportedForSparseNDArray(self.reshape, None, shape)
+
+ def _aux_type(self, i):
+ """Data-type of the array's ith aux data.
+
+ Returns
+ -------
+ numpy.dtype
+ This BaseSparseNDArray's aux data type.
+ """
+ aux_type = ctypes.c_int()
+ check_call(_LIB.MXNDArrayGetAuxType(self.handle, i, ctypes.byref(aux_type)))
+ return _DTYPE_MX_TO_NP[aux_type.value]
+
+ @property
+ def _num_aux(self):
+ """The number of aux data used to help store the sparse ndarray.
+ """
+ return len(_STORAGE_AUX_TYPES[self.stype])
+
+ @property
+ def _aux_types(self):
+ """The data types of the aux data for the BaseSparseNDArray.
+ """
+ aux_types = []
+ num_aux = self._num_aux
+ for i in range(num_aux):
+ aux_types.append(self._aux_type(i))
+ return aux_types
+
+ def asnumpy(self):
+ """Return a dense ``numpy.ndarray`` object with value copied from this array
+ """
+ return self.tostype('default').asnumpy()
+
+ def astype(self, dtype):
+ """Returns a copy of the array after casting to a specified type.
+ Parameters
+ ----------
+ dtype : numpy.dtype or str
+ The type of the returned array.
+ Examples
+ --------
+ >>> x = mx.nd.zeros('row_sparse', (2,3), dtype='float32')
+ >>> y = x.astype('int32')
+ >>> y.dtype
+ <type 'numpy.int32'>
+ """
+ res = zeros(shape=self.shape, ctx=self.context,
+ dtype=dtype, stype=self.stype)
+ self.copyto(res)
+ return res
+
+ def copyto(self, other):
+ """Copies the value of this array to another array.
+
+ Parameters
+ ----------
+ other : NDArray or CSRNDArray or RowSparseNDArray or Context
+ The destination array or context.
+
+ Returns
+ -------
+ NDArray or CSRNDArray or RowSparseNDArray
+ The copied array.
+ """
+ if isinstance(other, NDArray):
+ if other.handle is self.handle:
+ warnings.warn('You are attempting to copy an array to itself', RuntimeWarning)
+ return
+ return _internal._copyto(self, out=other)
+ elif isinstance(other, Context):
+ hret = _ndarray_cls(_new_alloc_handle(self.stype, self.shape, other,
+ True, self.dtype, self._aux_types))
+ return _internal._copyto(self, out=hret)
+ else:
+ raise TypeError('copyto does not support type ' + str(type(other)))
+
+ def _data(self):
+ """A deep copy NDArray of the data array associated with the BaseSparseNDArray.
+
+ This function blocks. Do not use it in performance critical code.
+ """
+ self.wait_to_read()
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl)))
+ return NDArray(hdl)
+
+ def _aux_data(self, i):
+ """ Get a deep copy NDArray of the i-th aux data array associated with the
+ BaseSparseNDArray.
+
+ This function blocks. Do not use it in performance critical code.
+ """
+ self.wait_to_read()
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl)))
+ return NDArray(hdl)
+
+
+# pylint: disable=abstract-method
+class CSRNDArray(BaseSparseNDArray):
+ """A sparse representation of 2D NDArray in the standard CSR format.
+
+ A CSRNDArray represents an NDArray as three separate arrays: `data`,
+ `indptr` and `indices`. It uses the standard CSR representation where the column indices for
+ row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored
+ in data[indptr[i]:indptr[i+1]].
+
+ The column indices for a given row are expected to be sorted in ascending order.
+ Duplicate column entries for the same row are not allowed.
+
+ Example
+ -------
+ >>> a = mx.nd.array([[0, 1, 0], [2, 0, 0], [0, 0, 0], [0, 0, 3]])
+ >>> a = a.tostype('csr')
+ >>> a.indices.asnumpy()
+ array([1, 0, 2])
+ >>> a.indptr.asnumpy()
+ array([0, 1, 2, 2, 3])
+ >>> a.data.asnumpy()
+ array([ 1., 2., 3.], dtype=float32)
+ """
+
+ def __reduce__(self):
+ return CSRNDArray, (None,), super(CSRNDArray, self).__getstate__()
+
+ def __iadd__(self, other):
+ (self + other).copyto(self)
+ return self
+
+ def __isub__(self, other):
+ (self - other).copyto(self)
+ return self
+
+ def __imul__(self, other):
+ (self * other).copyto(self)
+ return self
+
+ def __idiv__(self, other):
+ (self / other).copyto(self)
+ return self
+
+ def __itruediv__(self, other):
+ (self / other).copyto(self)
+ return self
+
+ def __getitem__(self, key):
+ """x.__getitem__(i) <=> x[i]
+
+ Returns a sliced view of this array.
+
+ Parameters
+ ----------
+ key : slice
+ Indexing key.
+
+ Examples
+ --------
+ >>> indptr = np.array([0, 2, 3, 6])
+ >>> indices = np.array([0, 2, 2, 0, 1, 2])
+ >>> data = np.array([1, 2, 3, 4, 5, 6])
+ >>> a = mx.nd.csr_matrix(data, indptr, indices, (3, 3))
+ >>> a.asnumpy()
+ array([[1, 0, 2],
+ [0, 0, 3],
+ [4, 5, 6]])
+ >>> a[1:2].asnumpy()
+ array([[0, 0, 3]], dtype=float32)
+ """
+ if isinstance(key, int):
+ raise ValueError("__getitem__ with int key is not implemented for CSRNDArray")
+ if isinstance(key, py_slice):
+ if key.step is not None:
+ raise ValueError('CSRNDArray only supports continuous slicing on axis 0')
+ if key.start is not None or key.stop is not None:
+ begin = key.start if key.start else 0
+ end = key.stop if key.stop else self.shape[0]
+ return nd_slice(self, begin=begin, end=end)
+ else:
+ return self
+ if isinstance(key, tuple):
+ raise ValueError('Multi-dimension indexing is not supported')
+
+ def __setitem__(self, key, value):
+ """x.__setitem__(i, y) <=> x[i]=y
+
+ Set self[key] to value. Only slice key [:] is supported.
+
+ Parameters
+ ----------
+ key : slice
+ The indexing key.
+ value : NDArray or CSRNDArray or numpy.ndarray
+ The value to set.
+
+ Examples
+ --------
+ >>> src = mx.nd.zeros((3,3), stype='csr')
+ >>> src.asnumpy()
+ array([[ 0., 0., 0.],
+ [ 0., 0., 0.],
+ [ 0., 0., 0.]], dtype=float32)
+ >>> # assign CSRNDArray with same storage type
+ >>> x = mx.nd.ones((3,3)).tostype('csr')
+ >>> x[:] = src
+ >>> x.asnumpy()
+ array([[ 0., 0., 0.],
+ [ 0., 0., 0.],
+ [ 0., 0., 0.]], dtype=float32)
+ >>> # assign NDArray to CSRNDArray
+ >>> x[:] = mx.nd.ones((3,3)) * 2
+ >>> x.asnumpy()
+ array([[ 2., 2., 2.],
+ [ 2., 2., 2.],
+ [ 2., 2., 2.]], dtype=float32)
+ """
+ if not self.writable:
+ raise ValueError('Failed to assign to a readonly CSRNDArray')
+ if isinstance(key, py_slice):
+ if key.step is not None or key.start is not None or key.stop is not None:
+ raise ValueError('Assignment with slice for CSRNDArray is not ' \
+ 'implemented yet.')
+ if isinstance(value, NDArray):
+ # avoid copying to itself
+ if value.handle is not self.handle:
+ value.copyto(self)
+ elif isinstance(value, numeric_types):
+ raise ValueError("Assigning numeric types to CSRNDArray is " \
+ "not implemented yet.")
+ elif isinstance(value, (np.ndarray, np.generic)):
+ # TODO(haibin/anisub) check scipy.sparse and use _sync_copy_from to
+ # avoid the temporary copy
+ warnings.warn('Assigning non-NDArray object to CSRNDArray is not efficient',
+ RuntimeWarning)
+ tmp = _array(value)
+ tmp.copyto(self)
+ else:
+ raise TypeError('type %s not supported' % str(type(value)))
+ else:
+ assert(isinstance(key, (int, tuple)))
+ raise TypeError('CSRNDArray only supports [:] for assignment')
+
+ @property
+ def indices(self):
+ """A deep copy NDArray of the indices array of the CSRNDArray.
+ This generates a deep copy of the column indices of the current `csr` matrix.
+
+ Returns
+ -------
+ NDArray
+ This CSRNDArray's indices array.
+ """
+ return self._aux_data(1)
+
+ @property
+ def indptr(self):
+ """A deep copy NDArray of the indptr array of the CSRNDArray.
+ This generates a deep copy of the `indptr` of the current `csr` matrix.
+
+ Returns
+ -------
+ NDArray
+ This CSRNDArray's indptr array.
+ """
+ return self._aux_data(0)
+
+ @property
+ def data(self):
+ """A deep copy NDArray of the data array of the CSRNDArray.
+ This generates a deep copy of the `data` of the current `csr` matrix.
+
+ Returns
+ -------
+ NDArray
+ This CSRNDArray's data array.
+ """
+ return self._data()
+
+ def tostype(self, stype):
+ """Return a copy of the array with chosen storage type.
+
+ Returns
+ -------
+ NDArray or CSRNDArray
+ A copy of the array with the chosen storage type.
+ """
+ if stype == 'row_sparse':
+ raise ValueError("cast_storage from csr to row_sparse is not supported")
+ return cast_storage(self, stype=stype)
+
+ def copyto(self, other):
+ """Copies the value of this array to another array.
+
+ If ``other`` is a ``NDArray`` or ``CSRNDArray`` object, then ``other.shape`` and
+ ``self.shape`` should be the same. This function copies the value from
+ ``self`` to ``other``.
+
+ If ``other`` is a context, a new ``CSRNDArray`` will be first created on
+ the target context, and the value of ``self`` is copied.
+
+ Parameters
+ ----------
+ other : NDArray or CSRNDArray or Context
+ The destination array or context.
+
+ Returns
+ -------
+ NDArray or CSRNDArray
+ The copied array. If ``other`` is an ``NDArray`` or ``CSRNDArray``, then the return
+ value and ``other`` will point to the same ``NDArray`` or ``CSRNDArray``.
+ """
+ if isinstance(other, Context):
+ return super(CSRNDArray, self).copyto(other)
+ elif isinstance(other, NDArray):
+ stype = other.stype
+ if stype == 'default' or stype == 'csr':
+ return super(CSRNDArray, self).copyto(other)
+ else:
+ raise TypeError('copyto does not support destination NDArray stype ' + str(stype))
+ else:
+ raise TypeError('copyto does not support type ' + str(type(other)))
+
+
+# pylint: disable=abstract-method
+class RowSparseNDArray(BaseSparseNDArray):
+ """A sparse representation of a set of NDArray row slices at given indices.
+
+ A RowSparseNDArray represents a multidimensional NDArray using two separate arrays: `data` and
+ `indices`.
+
+ - data: an NDArray of any dtype with shape [D0, D1, ..., Dn].
+ - indices: a 1-D int64 NDArray with shape [D0].
+
+ The `indices` array stores the indices of the row slices with non-zero entries,
+ while the values themselves are stored in `data`. The dense NDArray ``dense``
+ represented by a RowSparseNDArray ``rsp`` satisfies
+
+ ``dense[rsp.indices[i], :, :, :, ...] = rsp.data[i, :, :, :, ...]``
+
+ >>> dense.asnumpy()
+ array([[ 1., 2., 3.],
+ [ 0., 0., 0.],
+ [ 4., 0., 5.],
+ [ 0., 0., 0.],
+ [ 0., 0., 0.]], dtype=float32)
+ >>> rsp = dense.tostype('row_sparse')
+ >>> rsp.indices.asnumpy()
+ array([0, 2], dtype=int64)
+ >>> rsp.data.asnumpy()
+ array([[ 1., 2., 3.],
+ [ 4., 0., 5.]], dtype=float32)
+
+ A RowSparseNDArray is typically used to represent non-zero row-slices of a large NDArray
+ of shape [LARGE0, D1, .. , Dn] where LARGE0 >> D0 and most row slices are zeros.
+
+ The indices are expected to be sorted in ascending order.
+
+ RowSparseNDArray is used principally in the definition of gradients for operations
+ that have sparse gradients (e.g. sparse dot and sparse embedding).
+ """
+ def __reduce__(self):
+ return RowSparseNDArray, (None,), super(RowSparseNDArray, self).__getstate__()
+
+ def __iadd__(self, other):
+ (self + other).copyto(self)
+ return self
+
+ def __isub__(self, other):
+ (self - other).copyto(self)
+ return self
+
+ def __imul__(self, other):
+ (self * other).copyto(self)
+ return self
+
+ def __idiv__(self, other):
+ (self / other).copyto(self)
+ return self
+
+ def __itruediv__(self, other):
+ (self / other).copyto(self)
+ return self
+
+ def __getitem__(self, key):
+ """x.__getitem__(i) <=> x[i]
+
+ Returns a sliced view of this array.
+
+ Parameters
+ ----------
+ key : slice
+ Indexing key.
+
+ Examples
+ --------
+ >>> x = mx.nd.zeros((2, 3), stype='row_sparse')
+ >>> x[:].asnumpy()
+ array([[ 0., 0., 0.],
+ [ 0., 0., 0.]], dtype=float32)
+ """
+ if isinstance(key, int):
+ raise Exception("__getitem__ with int key is not implemented for RowSparseNDArray yet")
+ if isinstance(key, py_slice):
+ if key.step is not None or key.start is not None or key.stop is not None:
+ raise ValueError('RowSparseNDArray only supports [:] for __getitem__')
+ else:
+ return self
+ if isinstance(key, tuple):
+ raise ValueError('Multi-dimension indexing is not supported')
+
+ def __setitem__(self, key, value):
+ """x.__setitem__(i, y) <=> x[i]=y
+
+ Set self[key] to value. Only slice key [:] is supported.
+
+ Parameters
+ ----------
+ key : slice
+ The indexing key.
+ value : NDArray or numpy.ndarray
+ The value to set.
+
+ Examples
+ --------
+ >>> src = mx.nd.row_sparse_array([[1, 0, 2], [4, 5, 6]], [0, 2], (3,3))
+ >>> src.asnumpy()
+ array([[ 1., 0., 2.],
+ [ 0., 0., 0.],
+ [ 4., 5., 6.]], dtype=float32)
+ >>> # assign RowSparseNDArray with same storage type
+ >>> x = mx.nd.zeros((3,3), stype='row_sparse')
+ >>> x[:] = src
+ >>> x.asnumpy()
+ array([[ 1., 0., 2.],
+ [ 0., 0., 0.],
+ [ 4., 5., 6.]], dtype=float32)
+ >>> # assign NDArray to RowSparseNDArray
+ >>> x[:] = mx.nd.ones((3,3))
+ >>> x.asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ """
+ if not self.writable:
+ raise ValueError('Failed to assign to a readonly RowSparseNDArray')
+ if isinstance(key, py_slice):
+ if key.step is not None or key.start is not None or key.stop is not None:
+ raise ValueError('Assignment with slice for RowSparseNDArray ' \
+ 'is not implemented yet.')
+ if isinstance(value, NDArray):
+ # avoid copying to itself
+ if value.handle is not self.handle:
+ value.copyto(self)
+ elif isinstance(value, numeric_types):
+ raise ValueError("Assigning numeric types to RowSparseNDArray " \
+ "is not implemented yet.")
+ elif isinstance(value, (np.ndarray, np.generic)):
+ warnings.warn('Assigning non-NDArray object to RowSparseNDArray is not efficient',
+ RuntimeWarning)
+ tmp = _array(value)
+ tmp.copyto(self)
+ else:
+ raise TypeError('type %s not supported' % str(type(value)))
+ else:
+ assert(isinstance(key, (int, tuple)))
+ raise TypeError('RowSparseNDArray only supports [:] for assignment')
+
+ @property
+ def indices(self):
+ """A deep copy NDArray of the indices array of the RowSparseNDArray.
+ This generates a deep copy of the row indices of the current `row_sparse` matrix.
+
+ Returns
+ -------
+ NDArray
+ This RowSparseNDArray's indices array.
+ """
+ return self._aux_data(0)
+
+ @property
+ def data(self):
+ """A deep copy NDArray of the data array of the RowSparseNDArray.
+ This generates a deep copy of the `data` of the current `row_sparse` matrix.
+
+ Returns
+ -------
+ NDArray
+ This RowSparseNDArray's data array.
+ """
+ return self._data()
+
+ def tostype(self, stype):
+ """Return a copy of the array with chosen storage type.
+
+ Returns
+ -------
+ NDArray or RowSparseNDArray
+ A copy of the array with the chosen storage type.
+ """
+ if stype == 'csr':
+ raise ValueError("cast_storage from row_sparse to csr is not supported")
+ return cast_storage(self, stype=stype)
+
+ def copyto(self, other):
+ """Copies the value of this array to another array.
+
+ If ``other`` is a ``NDArray`` or ``RowSparseNDArray`` object, then ``other.shape``
+ and ``self.shape`` should be the same. This function copies the value from
+ ``self`` to ``other``.
+
+ If ``other`` is a context, a new ``RowSparseNDArray`` will be first created on
+ the target context, and the value of ``self`` is copied.
+
+ Parameters
+ ----------
+ other : NDArray or RowSparseNDArray or Context
+ The destination array or context.
+
+ Returns
+ -------
+ NDArray or RowSparseNDArray
+ The copied array. If ``other`` is an ``NDArray`` or ``RowSparseNDArray``, then the
+ return value and ``other`` will point to the same ``NDArray`` or ``RowSparseNDArray``.
+ """
+ if isinstance(other, Context):
+ return super(RowSparseNDArray, self).copyto(other)
+ elif isinstance(other, NDArray):
+ stype = other.stype
+ if stype == 'default' or stype == 'row_sparse':
+ return super(RowSparseNDArray, self).copyto(other)
+ else:
+ raise TypeError('copyto does not support destination NDArray stype ' + str(stype))
+ else:
+ raise TypeError('copyto does not support type ' + str(type(other)))
+
+
+def _prepare_src_array(src, dtype, default_dtype):
+ """Prepare `src` and its dtype so that they can be used to construct NDArray.
+ `src` is converted to a `np.ndarray` if it's neither an `NDArray` nor an `np.ndarray`.
+ """
+ if isinstance(src, NDArray):
+ dtype = src.dtype if dtype is None else dtype
+ else:
+ dtype = default_dtype if dtype is None else dtype
+ if not isinstance(src, np.ndarray):
+ try:
+ src = np.array(src, dtype=dtype)
+ except Exception:
+ raise TypeError('values must be an array-like object')
+ return src, dtype
+
+
+def csr_matrix(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None,
+ indices_type=None):
+ """Creates a 2D array with compressed sparse row(CSR) format.
+
+ Parameters
+ ----------
+ data: array_like
+ An object exposing the array interface, with shape [nnz], where nnz is the number of
+ non-zero entries.
+ indptr: array_like
+ An object exposing the array interface, with shape [num_rows + 1]. The first element
+ in indptr should always be zero.
+ indices: array_like
+ An object exposing the array interface, with shape [nnz].
+ ctx: Context, optional
+ Device context (default is the current default context).
+ dtype: str or numpy.dtype, optional
+ The data type of the output array. The default dtype is ``values.dtype``
+ if `values` is an `NDArray`, `float32` otherwise.
+ indptr_type: str or numpy.dtype, optional
+ The data type of the indptr array. The default dtype is ``indptr.dtype``
+ if `indptr` is an `NDArray`, `int64` otherwise.
+ indices_type: str or numpy.dtype, optional
+ The data type of the indices array. The default dtype is ``indices.dtype``
+ if `indices` is an `NDArray`, `int64` otherwise.
+
+ Returns
+ -------
+ CSRNDArray
+ A `CSRNDArray` with the `csr` storage representation.
+
+ Example
+ -------
+ >>> import mxnet as mx
+ >>> a = mx.nd.csr_matrix([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3))
+ >>> a.asnumpy()
+ array([[ 0., 1., 0.],
+ [ 2., 0., 0.],
+ [ 0., 0., 0.],
+ [ 0., 0., 3.]], dtype=float32)
+ """
+ storage_type = 'csr'
+ # context
+ if ctx is None:
+ ctx = Context.default_ctx
+ # prepare src array and types
+ data, dtype = _prepare_src_array(data, dtype, mx_real_t)
+ indptr, indptr_type = _prepare_src_array(indptr, indptr_type,
+ _STORAGE_AUX_TYPES[storage_type][0])
+ indices, indices_type = _prepare_src_array(indices, indices_type,
+ _STORAGE_AUX_TYPES[storage_type][1])
+ # verify types
+ assert('int64' in str(indptr_type)), "expected int64 for indptr"
+ assert('int64' in str(indices_type)), "expected int64 for indices"
+ # verify shapes
+ aux_shapes = [indptr.shape, indices.shape]
+ assert(data.ndim == 1)
+ assert(indptr.ndim == 1)
+ assert(indices.ndim == 1)
+ assert(len(shape) == 2)
+ result = CSRNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype,
+ [indptr_type, indices_type], aux_shapes))
+ # TODO(junwu): Convert data, indptr, and indices to mxnet NDArrays
+ # if they are not for now. In the future, we should provide a c-api
+ # to accept np.ndarray types to copy from to result.data and aux_data
+ if not isinstance(data, NDArray):
+ data = _array(data, ctx, dtype)
+ if not isinstance(indptr, NDArray):
+ indptr = _array(indptr, ctx, indptr_type)
+ if not isinstance(indices, NDArray):
+ indices = _array(indices, ctx, indices_type)
+ check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, data.handle, ctypes.c_int(-1)))
+ check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, indptr.handle, ctypes.c_int(0)))
+ check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, indices.handle, ctypes.c_int(1)))
+ return result
+
+
+def row_sparse_array(data, indices, shape, ctx=None, dtype=None, indices_type=None):
+ """Creates a multidimensional row sparse array with a set of tensor slices at given indices.
+
+ Parameters
+ ----------
+ data: array_like
+ An object exposing the array interface, with shape [D0, D1, .. DK], where D0 is
+ the number of rows with non-zeros entries.
+ indices: array_like
+ An object exposing the array interface, with shape [D0].
+ ctx : Context, optional
+ Device context (default is the current default context).
+ dtype : str or numpy.dtype, optional
+ The data type of the output array. The default dtype is ``data.dtype``
+ if `data` is an `NDArray`, `float32` otherwise.
+ indices_type: str or numpy.dtype, optional
+ The data type of the indices array. The default dtype is ``indices.dtype``
+ if `indices` is an `NDArray`, `int64` otherwise.
+
+ Returns
+ -------
+ RowSparseNDArray
+ A `RowSparseNDArray` with the `row_sparse` storage representation.
+
+ Example
+ -------
+ >>> a = mx.nd.row_sparse_array([[1, 2], [3, 4]], [1, 4], (6, 2))
+ >>> a.asnumpy()
+ array([[ 0., 0.],
+ [ 1., 2.],
+ [ 0., 0.],
+ [ 0., 0.],
+ [ 3., 4.],
+ [ 0., 0.]], dtype=float32)
+ """
+ storage_type = 'row_sparse'
+ # context
+ if ctx is None:
+ ctx = Context.default_ctx
+ # prepare src array and types
+ data, dtype = _prepare_src_array(data, dtype, mx_real_t)
+ indices, indices_type = _prepare_src_array(indices, indices_type,
+ _STORAGE_AUX_TYPES[storage_type][0])
+ # verify types
+ assert('int64' in str(indices_type)), "expected int64 for indices"
+ # verify shapes
+ assert(data.ndim == len(shape))
+ assert(indices.ndim == 1)
+ result = RowSparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype,
+ [indices_type], [indices.shape]))
+
+ # TODO(junwu): Convert data, indptr, and indices to mxnet NDArrays
+ # if they are not for now. In the future, we should provide a c-api
+ # to accept np.ndarray types to copy from to result.data and aux_data
+ if not isinstance(data, NDArray):
+ data = _array(data, ctx, dtype)
+ if not isinstance(indices, NDArray):
+ indices = _array(indices, ctx, indices_type)
+ check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, data.handle, ctypes.c_int(-1)))
+ check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, indices.handle, ctypes.c_int(0)))
+ return result
+
+
+def _ndarray_cls(handle, writable=True, stype=None):
+ if stype is None:
+ stype = _storage_type(handle)
+ if stype == 'default':
+ return NDArray(handle, writable=writable)
+ elif stype == 'csr':
+ return CSRNDArray(handle, writable=writable)
+ elif stype == 'row_sparse':
+ return RowSparseNDArray(handle, writable=writable)
+ else:
+ raise Exception("unknown storage type")
+
+
+_set_ndarray_class(_ndarray_cls)
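With `_ndarray_cls` registered as the creation callback, arrays coming back from backend operators are wrapped in the subclass matching their storage type. A sketch of the expected dispatch (illustrative):

    >>> import mxnet as mx
    >>> type(mx.nd.zeros((2, 2))).__name__
    'NDArray'
    >>> type(mx.nd.zeros((2, 2), stype='csr')).__name__
    'CSRNDArray'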
+
+
+def zeros(stype, shape, ctx=None, dtype=None, aux_types=None, **kwargs):
+ """Return a new array of given shape and type, filled with zeros.
+
+ Parameters
+ ----------
+ stype: string
+ The storage type of the empty array, such as 'row_sparse', 'csr', etc
+ shape : int or tuple of int
+ The shape of the empty array
+ ctx : Context, optional
+ An optional device context (default is the current default context)
+ dtype : str or numpy.dtype, optional
+ An optional value type (default is `float32`)
+ aux_types: list of numpy.dtype, optional
+ An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+ (default values depend on the storage type)
+
+ Returns
+ -------
+ RowSparseNDArray or CSRNDArray
+ A created array
+
+ Examples
+ --------
+ >>> mx.nd.sparse.zeros('csr', (1,2))
+ <CSRNDArray 1x2 @cpu(0)>
+ >>> mx.nd.sparse.zeros('row_sparse', (1,2), mx.cpu(), 'float16').asnumpy()
+ array([[ 0., 0.]], dtype=float16)
+ """
+ if stype == 'default':
+ return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs)
+ if ctx is None:
+ ctx = Context.default_ctx
+ dtype = mx_real_t if dtype is None else dtype
+ if aux_types is None:
+ if stype == 'row_sparse' or stype == 'csr':
+ aux_types = _STORAGE_AUX_TYPES[stype]
+ else:
+ raise Exception("unknown storage type")
+ assert(len(aux_types) == len(_STORAGE_AUX_TYPES[stype]))
+ out = _ndarray_cls(_new_alloc_handle(stype, shape, ctx, True, dtype, aux_types))
+ return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs)
+
+
+def empty(stype, shape, ctx=None, dtype=None, aux_types=None):
+ """Returns a new array of given shape and type, without initializing entries.
+ """
+ if isinstance(shape, int):
+ shape = (shape, )
+ if ctx is None:
+ ctx = Context.default_ctx
+ if dtype is None:
+ dtype = mx_real_t
+ assert(stype is not None)
+ if stype == 'csr' or stype == 'row_sparse':
+ return zeros(stype, shape, ctx=ctx, dtype=dtype, aux_types=aux_types)
+ else:
+ raise Exception("unknown stype : " + str(stype))
+
+
+def array(source_array, ctx=None, dtype=None, aux_types=None):
+ """Creates a sparse array from any object exposing the array interface.
+ """
+ if isinstance(source_array, NDArray):
+ assert(source_array.stype != 'default'), \
+ "Please use `cast_storage` to create BaseSparseNDArray from an NDArray"
+ dtype = source_array.dtype if dtype is None else dtype
+ aux_types = source_array._aux_types if aux_types is None else aux_types
+ else:
+ # TODO(haibin/anisub) support creation from scipy object when `_sync_copy_from` is ready
+ raise NotImplementedError('creating BaseSparseNDArray from ' \
+ 'a non-NDArray object is not implemented.')
+ arr = empty(source_array.stype, source_array.shape, ctx, dtype, aux_types)
+ arr[:] = source_array
+ return arr
diff --git a/python/mxnet/ndarray/utils.py b/python/mxnet/ndarray/utils.py
new file mode 100644
index 000000000000..a0dd83692b87
--- /dev/null
+++ b/python/mxnet/ndarray/utils.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Utility functions for NDArray and BaseSparseNDArray."""
+import ctypes
+
+from ..base import _LIB, check_call, py_str, c_str, string_types, mx_uint, NDArrayHandle, c_array
+from .ndarray import NDArray
+from .ndarray import array as _array
+from .ndarray import empty as _empty_ndarray
+from .ndarray import zeros as _zeros_ndarray
+from .sparse import zeros as _zeros_sparse_ndarray
+from .sparse import empty as _empty_sparse_ndarray
+from .sparse import array as _sparse_array
+from .sparse import _ndarray_cls
+
+
+def zeros(shape, ctx=None, dtype=None, stype=None, aux_types=None, **kwargs):
+ """Return a new array of given shape and type, filled with zeros.
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the empty array
+ ctx : Context, optional
+ An optional device context (default is the current default context)
+ dtype : str or numpy.dtype, optional
+ An optional value type (default is `float32`)
+ stype: string, optional
+ The storage type of the empty array, such as 'row_sparse', 'csr', etc.
+ aux_types: list of numpy.dtype, optional
+ An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+ (default values depend on the storage type)
+
+ Returns
+ -------
+ NDArray, CSRNDArray or RowSparseNDArray
+ A created array
+
+ Examples
+ --------
+ >>> mx.nd.zeros((1,2), mx.cpu(), stype='csr')
+ <CSRNDArray 1x2 @cpu(0)>
+ >>> mx.nd.zeros((1,2), mx.cpu(), 'float16', stype='row_sparse').asnumpy()
+ array([[ 0., 0.]], dtype=float16)
+ """
+
+ if stype is None or stype == 'default':
+ return _zeros_ndarray(shape, ctx, dtype, **kwargs)
+ else:
+ return _zeros_sparse_ndarray(stype, shape, ctx, dtype, aux_types, **kwargs)
+
+
+def empty(shape, ctx=None, dtype=None, stype=None, aux_types=None):
+ """Returns a new array of given shape and type, without initializing entries.
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the empty array.
+ ctx : Context, optional
+ An optional device context (default is the current default context).
+ dtype : str or numpy.dtype, optional
+ An optional value type (default is `float32`).
+ stype : str, optional
+ An optional storage type (default is `default`).
+ aux_types: list of numpy.dtype, optional
+ An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+ (default values depend on the storage type)
+
+ Returns
+ -------
+ NDArray, CSRNDArray or RowSparseNDArray
+ A created array.
+
+ Examples
+ --------
+ >>> mx.nd.empty(1)
+ <NDArray 1 @cpu(0)>
+ >>> mx.nd.empty((1,2), mx.gpu(0))
+ <NDArray 1x2 @gpu(0)>
+ >>> mx.nd.empty((1,2), mx.gpu(0), 'float16')
+ <NDArray 1x2 @gpu(0)>
+ >>> mx.nd.empty((1,2), stype='csr')
+ <CSRNDArray 1x2 @cpu(0)>
+ """
+ if stype is None or stype == 'default':
+ return _empty_ndarray(shape, ctx, dtype)
+ else:
+ return _empty_sparse_ndarray(stype, shape, ctx, dtype, aux_types)
+
+
+def array(source_array, ctx=None, dtype=None, aux_types=None):
+ """Creates an array from any object exposing the array interface.
+
+ Parameters
+ ----------
+ source_array : array_like
+ An object exposing the array interface, an object whose `__array__`
+ method returns an array, or any (nested) sequence.
+ ctx : Context, optional
+ Device context (default is the current default context).
+ dtype : str or numpy.dtype, optional
+ The data type of the output array. The default dtype is ``source_array.dtype``
+ if `source_array` is an `NDArray`, `float32` otherwise.
+ aux_types: list of numpy.dtype, optional
+ An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
+ (default values depend on the storage type)
+
+ Returns
+ -------
+ NDArray, RowSparseNDArray or CSRNDArray
+ An array with the same contents as the `source_array`.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> mx.nd.array([1, 2, 3])
+ <NDArray 3 @cpu(0)>
+ >>> mx.nd.array([[1, 2], [3, 4]])
+ <NDArray 2x2 @cpu(0)>
+ >>> mx.nd.array(np.zeros((3, 2)))
+ <NDArray 3x2 @cpu(0)>
+ >>> mx.nd.array(np.zeros((3, 2)), mx.gpu(0))
+ <NDArray 3x2 @gpu(0)>
+ >>> mx.nd.array(mx.nd.zeros((3, 2), stype='row_sparse'))
+ <RowSparseNDArray 3x2 @cpu(0)>
+ """
+ # TODO(haibin/anisub) Check if input is scipy.sparse object with `scipy.sparse.issparse`
+ if isinstance(source_array, NDArray) and source_array.stype != 'default':
+ return _sparse_array(source_array, ctx=ctx, dtype=dtype, aux_types=aux_types)
+ else:
+ return _array(source_array, ctx=ctx, dtype=dtype)
+
+
+def load(fname):
+ """Loads an array from file.
+
+ See more details in ``save``.
+
+ Parameters
+ ----------
+ fname : str
+ The filename.
+
+ Returns
+ -------
+ list of NDArray, RowSparseNDArray or CSRNDArray, or \
+ dict of str to NDArray, RowSparseNDArray or CSRNDArray
+ Loaded data.
+ """
+ if not isinstance(fname, string_types):
+ raise TypeError('fname required to be a string')
+ out_size = mx_uint()
+ out_name_size = mx_uint()
+ handles = ctypes.POINTER(NDArrayHandle)()
+ names = ctypes.POINTER(ctypes.c_char_p)()
+ check_call(_LIB.MXNDArrayLoad(c_str(fname),
+ ctypes.byref(out_size),
+ ctypes.byref(handles),
+ ctypes.byref(out_name_size),
+ ctypes.byref(names)))
+ if out_name_size.value == 0:
+ return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
+ else:
+ assert out_name_size.value == out_size.value
+ return dict(
+ (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
+ for i in range(out_size.value))
+
+
+def save(fname, data):
+ """Saves a list of arrays or a dict of str->array to file.
+
+ Examples of filenames:
+
+ - ``/path/to/file``
+ - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 support)
+ - ``hdfs://path/to/file`` (if compiled with HDFS support)
+
+ Parameters
+ ----------
+ fname : str
+ The filename.
+ data : NDArray, RowSparseNDArray or CSRNDArray, \
+ or list of NDArray, RowSparseNDArray or CSRNDArray, \
+ or dict of str to NDArray, RowSparseNDArray or CSRNDArray
+ The data to save.
+
+ Examples
+ --------
+ >>> x = mx.nd.zeros((2,3))
+ >>> y = mx.nd.ones((1,4))
+ >>> mx.nd.save('my_list', [x,y])
+ >>> mx.nd.save('my_dict', {'x':x, 'y':y})
+ >>> mx.nd.load('my_list')
+ [<NDArray 2x3 @cpu(0)>, <NDArray 1x4 @cpu(0)>]
+ >>> mx.nd.load('my_dict')
+ {'y': <NDArray 1x4 @cpu(0)>, 'x': <NDArray 2x3 @cpu(0)>}
+ """
+ if isinstance(data, NDArray):
+ data = [data]
+ handles = []
+ if isinstance(data, dict):
+ keys = []
+ for key, val in data.items():
+ if not isinstance(key, string_types):
+ raise TypeError('save only accepts dict str->NDArray or list of NDArray')
+ if not isinstance(val, NDArray):
+ raise TypeError('save only accepts dict str->NDArray or list of NDArray')
+ keys.append(c_str(key))
+ handles.append(val.handle)
+ keys = c_array(ctypes.c_char_p, keys)
+ elif isinstance(data, list):
+ for val in data:
+ if not isinstance(val, NDArray):
+ raise TypeError('save only accepts dict str->NDArray or list of NDArray')
+ handles.append(val.handle)
+ keys = None
+ else:
+ raise ValueError("data needs to either be a NDArray, dict of str, NDArray pairs "
+ "or a list of NDarrays.")
+ check_call(_LIB.MXNDArraySave(c_str(fname),
+ mx_uint(len(handles)),
+ c_array(NDArrayHandle, handles),
+ keys))
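A save/load round trip combining dense and sparse arrays, sketched under the assumption (implied by the return types documented above) that serialization preserves the storage type:

```python
import mxnet as mx

x = mx.nd.zeros((2, 3), stype='csr')
y = mx.nd.ones((1, 4))
mx.nd.save('mixed.ndarray', {'x': x, 'y': y})

loaded = mx.nd.load('mixed.ndarray')
assert loaded['x'].stype == 'csr'                  # storage type survives the round trip
assert loaded['y'].stype == 'default'
```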
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 1ef9cc845036..e7e283f88e43 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -339,8 +339,8 @@ class SGD(Optimizer):
state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
weight = weight - state
- For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` and
- :class:`~mxnet.ndarray.sgd_mom_update`.
+ Sparse updating is supported. For details of the update algorithm see
+ :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.
This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.
@@ -367,7 +367,8 @@ def create_state(self, index, weight):
if self.multi_precision and weight.dtype == numpy.float16:
weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
if self.momentum != 0.0:
- momentum = zeros(weight.shape, weight.context, dtype=numpy.float32)
+ momentum = zeros(weight.shape, weight.context, dtype=numpy.float32,
+ stype=weight.stype)
return (momentum, weight_master_copy)
if weight.dtype == numpy.float16 and not self.multi_precision:
warnings.warn("Accumulating with float16 in optimizer can lead to "
@@ -375,7 +376,7 @@ def create_state(self, index, weight):
"Consider using multi_precision=True option of the "
"SGD optimizer")
if self.momentum != 0.0:
- momentum = zeros(weight.shape, weight.context, dtype=weight.dtype)
+ momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
return momentum
def update(self, index, weight, grad, state):
@@ -563,8 +564,10 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
self.epsilon = epsilon
def create_state(self, index, weight):
- return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
- zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance
+ return (zeros(weight.shape, weight.context, dtype=weight.dtype,
+ stype=weight.stype), # mean
+ zeros(weight.shape, weight.context, dtype=weight.dtype,
+ stype=weight.stype)) # variance
def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
@@ -669,11 +672,11 @@ def __init__(self, learning_rate=0.001, gamma1=0.9, gamma2=0.9,
def create_state(self, index, weight):
if self.centered:
return (
- zeros(weight.shape, weight.context), # n
- zeros(weight.shape, weight.context), # g
- zeros(weight.shape, weight.context)) # delta
+ zeros(weight.shape, weight.context, stype=weight.stype), # n
+ zeros(weight.shape, weight.context, stype=weight.stype), # g
+ zeros(weight.shape, weight.context, stype=weight.stype)) # delta
else:
- return (zeros(weight.shape, weight.context), ) # n
+ return (zeros(weight.shape, weight.context, stype=weight.stype),) # n
def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
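The effect of these optimizer changes, as a sketch: state arrays now mirror the weight's storage type instead of always being allocated dense.

```python
import mxnet as mx

opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9)
weight = mx.nd.zeros((4, 2), stype='row_sparse')
momentum = opt.create_state(0, weight)
assert momentum.stype == 'row_sparse'              # was always 'default' before this patch
```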
diff --git a/python/mxnet/random.py b/python/mxnet/random.py
index 29b250d980ce..14bfc2731bd6 100644
--- a/python/mxnet/random.py
+++ b/python/mxnet/random.py
@@ -22,13 +22,13 @@
import ctypes
from .base import _LIB, check_call
-from ._ndarray_internal import _sample_uniform as uniform
-from ._ndarray_internal import _sample_normal as normal
-from ._ndarray_internal import _sample_gamma as gamma
-from ._ndarray_internal import _sample_exponential as exponential
-from ._ndarray_internal import _sample_poisson as poisson
-from ._ndarray_internal import _sample_negbinomial as negative_binomial
-from ._ndarray_internal import _sample_gennegbinomial as generalized_negative_binomial
+from .ndarray._internal import _sample_uniform as uniform
+from .ndarray._internal import _sample_normal as normal
+from .ndarray._internal import _sample_gamma as gamma
+from .ndarray._internal import _sample_exponential as exponential
+from .ndarray._internal import _sample_poisson as poisson
+from .ndarray._internal import _sample_negbinomial as negative_binomial
+from .ndarray._internal import _sample_gennegbinomial as generalized_negative_binomial
def seed(seed_state):
"""Seeds the random number generators in MXNet.
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
new file mode 100644
index 000000000000..d93a230f490d
--- /dev/null
+++ b/python/mxnet/symbol/__init__.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Symbol API of MXNet."""
+
+from . import _internal, sparse, op
+# pylint: disable=wildcard-import, redefined-builtin
+from .symbol import *
+from ..ndarray import _GRAD_REQ_MAP
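After this reorganization the symbol API is a package rather than a single module; a quick sketch of what it exposes:

```python
import mxnet as mx

data = mx.sym.Variable('data')                     # top-level API is unchanged
print(type(mx.sym.sparse))                         # dedicated namespace for sparse ops
```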
diff --git a/python/mxnet/_symbol_internal.py b/python/mxnet/symbol/_internal.py
similarity index 100%
rename from python/mxnet/_symbol_internal.py
rename to python/mxnet/symbol/_internal.py
diff --git a/python/mxnet/symbol/op.py b/python/mxnet/symbol/op.py
new file mode 100644
index 000000000000..82884a5cc6a2
--- /dev/null
+++ b/python/mxnet/symbol/op.py
@@ -0,0 +1,242 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Register backend ops in mxnet.symbol namespace."""
+
+import sys as _sys
+import os as _os
+import ctypes
+import numpy as _numpy # pylint: disable=unused-import
+
+from mxnet.base import mx_uint, check_call, _LIB, py_str, OpHandle, c_str
+from mxnet.symbol_doc import _build_doc
+
+# Use different version of SymbolBase
+# When possible, use cython to speedup part of computation.
+# pylint: disable=unused-import
+try:
+ if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
+ from .._ctypes.symbol import SymbolBase, _set_symbol_class
+ from .._ctypes.symbol import _symbol_creator
+ elif _sys.version_info >= (3, 0):
+ from .._cy3.symbol import SymbolBase, _set_symbol_class
+ from .._cy3.symbol import _symbol_creator
+ else:
+ from .._cy2.symbol import SymbolBase, _set_symbol_class
+ from .._cy2.symbol import _symbol_creator
+except ImportError:
+ if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
+ raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
+ from .._ctypes.symbol import SymbolBase, _set_symbol_class
+ from .._ctypes.symbol import _symbol_creator
+
+from ..base import _Null
+from ..name import NameManager
+from ..attribute import AttrScope
+# pylint: enable=unused-import
+
+
+def _make_atomic_symbol_function(handle, name):
+ """Create an atomic symbol function by handle and function name."""
+ real_name = ctypes.c_char_p()
+ desc = ctypes.c_char_p()
+ num_args = mx_uint()
+ arg_names = ctypes.POINTER(ctypes.c_char_p)()
+ arg_types = ctypes.POINTER(ctypes.c_char_p)()
+ arg_descs = ctypes.POINTER(ctypes.c_char_p)()
+ key_var_num_args = ctypes.c_char_p()
+ ret_type = ctypes.c_char_p()
+
+ check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
+ handle, ctypes.byref(real_name), ctypes.byref(desc),
+ ctypes.byref(num_args),
+ ctypes.byref(arg_names),
+ ctypes.byref(arg_types),
+ ctypes.byref(arg_descs),
+ ctypes.byref(key_var_num_args),
+ ctypes.byref(ret_type)))
+ narg = int(num_args.value)
+ arg_names = [py_str(arg_names[i]) for i in range(narg)]
+ arg_types = [py_str(arg_types[i]) for i in range(narg)]
+ func_name = name
+ key_var_num_args = py_str(key_var_num_args.value)
+ ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
+ doc_str = _build_doc(func_name,
+ py_str(desc.value),
+ arg_names,
+ arg_types,
+ [py_str(arg_descs[i]) for i in range(narg)],
+ key_var_num_args,
+ ret_type)
+
+ dtype_name = None
+ arr_name = None
+ ndsignature = []
+ signature = []
+ ndarg_names = []
+ kwarg_names = []
+ for i in range(narg):
+ name, atype = arg_names[i], arg_types[i]
+ if name == 'dtype':
+ dtype_name = name
+ signature.append('%s=_Null'%name)
+ elif atype.startswith('NDArray') or atype.startswith('Symbol'):
+ assert not arr_name, \
+ "Op can only have one argument with variable " \
+ "size and it must be the last argument."
+ if atype.endswith('[]'):
+ ndsignature.append('*%s'%name)
+ arr_name = name
+ else:
+ ndsignature.append('%s=None'%name)
+ ndarg_names.append(name)
+ else:
+ signature.append('%s=_Null'%name)
+ kwarg_names.append(name)
+ #signature.append('is_train=False')
+ signature.append('name=None')
+ signature.append('attr=None')
+ signature.append('out=None')
+ signature.append('**kwargs')
+ signature = ndsignature + signature
+
+ code = []
+ if arr_name:
+ code.append("""
+def %s(*%s, **kwargs):"""%(func_name, arr_name))
+ code.append("""
+ sym_args = []
+ for i in {}:
+ assert isinstance(i, SymbolBase), \\
+ "Positional arguments must be Symbol instances, " \\
+ "but got %s"%str(i)
+ sym_args.append(i)""".format(arr_name))
+ if dtype_name is not None:
+ code.append("""
+ if '%s' in kwargs:
+ kwargs['%s'] = _numpy.dtype(kwargs['%s']).name"""%(
+ dtype_name, dtype_name, dtype_name))
+ code.append("""
+ attr = kwargs.pop('attr', None)
+ kwargs.update(AttrScope.current.get(attr))
+ name = kwargs.pop('name', None)
+ name = NameManager.current.get(name, '%s')
+ _ = kwargs.pop('out', None)
+ keys = []
+ vals = []
+ sym_kwargs = dict()
+ for k, v in kwargs.items():
+ if isinstance(v, SymbolBase):
+ sym_kwargs[k] = v
+ else:
+ keys.append(k)
+ vals.append(v)"""%(func_name.lower()))
+ if key_var_num_args:
+ code.append("""
+ if '%s' not in kwargs:
+ keys.append('%s')
+ vals.append(len(sym_args) + len(sym_kwargs))"""%(
+ key_var_num_args, key_var_num_args))
+
+ code.append("""
+ return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%(
+ handle.value))
+ else:
+ code.append("""
+def %s(%s):
+ kwargs.update(AttrScope.current.get(attr))
+ sym_kwargs = dict()
+ keys = []
+ vals = []"""%(func_name, ', '.join(signature)))
+ code.append("""
+ for k, v in kwargs.items():
+ if isinstance(v, SymbolBase):
+ sym_kwargs[k] = v
+ else:
+ keys.append(k)
+ vals.append(v)""")
+ # NDArray args
+ for name in ndarg_names: # pylint: disable=redefined-argument-from-local
+ code.append("""
+ if {name} is not None:
+ assert isinstance({name}, SymbolBase), \\
+ "Argument {name} must be Symbol instances, but got %s"%str({name})
+ sym_kwargs['{name}'] = {name}""".format(name=name))
+ # kwargs
+ for name in kwarg_names: # pylint: disable=redefined-argument-from-local
+ code.append("""
+ if %s is not _Null:
+ keys.append('%s')
+ vals.append(%s)"""%(name, name, name))
+ # dtype
+ if dtype_name is not None:
+ code.append("""
+ if %s is not _Null:
+ keys.append('%s')
+ vals.append(_numpy.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
+
+ code.append("""
+ name = NameManager.current.get(name, '%s')
+ return _symbol_creator(%d, None, sym_kwargs, keys, vals, name)"""%(
+ func_name.lower(), handle.value))
+
+ local = {}
+ exec(''.join(code), None, local) # pylint: disable=exec-used
+ symbol_function = local[func_name]
+ symbol_function.__name__ = func_name
+ symbol_function.__doc__ = doc_str
+ symbol_function.__module__ = 'mxnet.symbol'
+ return symbol_function
+
+
+def _init_symbol_module(root_namespace):
+ """List and add all the atomic symbol functions to current module."""
+ plist = ctypes.POINTER(ctypes.c_char_p)()
+ size = ctypes.c_uint()
+
+ check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
+ ctypes.byref(plist)))
+ op_names = []
+ for i in range(size.value):
+ op_names.append(py_str(plist[i]))
+
+ module_obj = _sys.modules["%s.symbol" % root_namespace]
+ module_sparse = _sys.modules["%s.symbol.sparse" % root_namespace]
+ module_internal = _sys.modules["%s.symbol._internal" % root_namespace]
+ module_contrib = _sys.modules["%s.contrib.symbol" % root_namespace]
+ for name in op_names:
+ hdl = OpHandle()
+ check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+ function = _make_atomic_symbol_function(hdl, name)
+ if function.__name__.startswith('_contrib_'):
+ function.__name__ = function.__name__[9:]
+ function.__module__ = 'mxnet.contrib.symbol'
+ setattr(module_contrib, function.__name__, function)
+ elif function.__name__.startswith('_'):
+ setattr(module_internal, function.__name__, function)
+ else:
+ setattr(module_obj, function.__name__, function)
+
+ # register sparse ops under mxnet.symbol.sparse
+ if function.__name__.startswith('_sparse_'):
+ function.__name__ = function.__name__[8:]
+ function.__module__ = 'mxnet.symbol.sparse'
+ setattr(module_sparse, function.__name__, function)
+
+
+# Initialize the atomic symbol in startups
+_init_symbol_module("mxnet")
diff --git a/python/mxnet/symbol/sparse.py b/python/mxnet/symbol/sparse.py
new file mode 100644
index 000000000000..1d94f2b85bc7
--- /dev/null
+++ b/python/mxnet/symbol/sparse.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Sparse Symbol API of MXNet."""
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol/symbol.py
similarity index 90%
rename from python/mxnet/symbol.py
rename to python/mxnet/symbol/symbol.py
index 14cb3811deeb..aa8ca0b8dd53 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -29,39 +29,19 @@
import warnings
from numbers import Number
-import os as _os
-import sys as _sys
import numpy as _numpy
-from .base import _LIB, numeric_types
-from .base import c_array, c_str, mx_uint, py_str, string_types
-from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle
-from .base import check_call, MXNetError, NotImplementedForSymbol, _Null # pylint: disable=unused-import
-from .context import Context
-from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
-from .name import NameManager # pylint: disable=unused-import
-from .executor import Executor
-from . import _symbol_internal as _internal
-from .attribute import AttrScope
-from .symbol_doc import _build_doc
-
-# Use different version of SymbolBase
-# When possible, use cython to speedup part of computation.
-try:
- if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
- from ._ctypes.symbol import SymbolBase, _set_symbol_class
- from ._ctypes.symbol import _symbol_creator # pylint: disable=unused-import
- elif _sys.version_info >= (3, 0):
- from ._cy3.symbol import SymbolBase, _set_symbol_class
- from ._cy3.symbol import _symbol_creator # pylint: disable=unused-import
- else:
- from ._cy2.symbol import SymbolBase, _set_symbol_class
- from ._cy2.symbol import _symbol_creator # pylint: disable=unused-import
-except ImportError:
- if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
- raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
- from ._ctypes.symbol import SymbolBase, _set_symbol_class
- from ._ctypes.symbol import _symbol_creator # pylint: disable=unused-import
+from ..base import _LIB, numeric_types
+from ..base import c_array, c_str, mx_uint, py_str, string_types
+from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
+from ..base import check_call, MXNetError, NotImplementedForSymbol
+from ..context import Context
+from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
+from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
+from ..ndarray import _ndarray_cls
+from ..executor import Executor
+from . import _internal, reshape
+from .op import SymbolBase, _set_symbol_class, AttrScope, _Null # pylint: disable=unused-import
class Symbol(SymbolBase):
@@ -1263,8 +1243,9 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing):
raise TypeError('Only accept list of NDArrays or dict of str to NDArray')
return c_array(NDArrayHandle, arg_handles), arg_arrays
- def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
- shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs):
+ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
+ group2ctx=None, shared_arg_names=None, shared_exec=None,
+ shared_buffer=None, **kwargs):
"""Bind current symbol to get an executor, allocate all the arguments needed.
Allows specifying data types.
@@ -1306,6 +1287,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
type_dict : Dict of str->numpy.dtype
Input type dictionary, name->dtype
+ stype_dict : Dict of str->str
+ Input storage type dictionary, name->storage_type
+
group2ctx : Dict of string to mx.Context
The dict mapping the `ctx_group` attribute to the context assignment.
@@ -1320,7 +1304,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_buffer : Dict of string to `NDArray`
The dict mapping argument names to the `NDArray` that can be reused for initializing
the current executor. This buffer will be checked for reuse if one argument name
- of the current executor is not found in `shared_arg_names`.
+ of the current executor is not found in `shared_arg_names`. The `NDArray`s are
+ expected to have default storage type.
kwargs : Dict of str->shape
Input shape dictionary, name->shape
@@ -1330,6 +1315,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
executor : mxnet.Executor
The generated executor
"""
+ # data types
num_provided_arg_types = 0
provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names
provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types
@@ -1345,6 +1331,22 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names)
provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data)
+ # storage types
+ num_provided_arg_stypes = 0
+ # provided storage type argument names
+ provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)()
+ provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types
+ if stype_dict is not None:
+ provided_arg_stype_names = []
+ provided_arg_stype_data = []
+ for k, v in stype_dict.items():
+ if v in _STORAGE_TYPE_STR_TO_ID:
+ provided_arg_stype_names.append(c_str(k))
+ provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v]))
+ num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names))
+ provided_arg_stype_names = c_array(ctypes.c_char_p, provided_arg_stype_names)
+ provided_arg_stype_data = c_array(ctypes.c_int, provided_arg_stype_data)
+
provided_arg_shape_data = [] # shape data
# argument shape index in sdata,
# e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
@@ -1418,6 +1420,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_buffer_names = []
shared_buffer_handles = []
for k, v in shared_buffer.items():
+ assert(v.stype == 'default'), \
+ "shared_buffer is expected to only contain NDArrays with default storage"
shared_buffer_names.append(c_str(k))
shared_buffer_handles.append(v.handle)
shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names)
@@ -1457,6 +1461,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
+ num_provided_arg_stypes,
+ provided_arg_stype_names,
+ provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_array(ctypes.c_char_p, shared_arg_name_list),
ctypes.byref(shared_buffer_len),
@@ -1486,11 +1493,12 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_buffer[k] = v
# create in_args, arg_grads, and aux_states for the current executor
- arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)]
- grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i]))
+ arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) \
+ for i in range(num_in_args.value)]
+ grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i]))
if arg_grad_handles[i] is not None
else None for i in range(num_in_args.value)]
- aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i]))
+ aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i]))
for i in range(num_aux_states.value)]
executor = Executor(exe_handle, self, ctx, grad_req, group2ctx)
@@ -1767,7 +1775,8 @@ def detach(self):
def backward(self):
raise NotImplementedForSymbol(self.backward, None)
-def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs):
+def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
+ init=None, stype=None, **kwargs):
"""Creates a symbolic variable with specified name.
Example usage:
@@ -1794,6 +1803,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini
The dtype for input variable. If not specified, this value will be inferred.
init : initializer (mxnet.init.*)
Initializer for this variable to (optionally) override the default initializer.
+ stype : str, optional
+ The storage type of the variable.
kwargs : Additional attribute variables
Additional attributes must start and end with double underscores.
@@ -1821,6 +1832,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini
if not isinstance(init, string_types):
init = init.dumps()
attr['__init__'] = init
+ if stype is not None:
+ attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[stype])
for k, v in kwargs.items():
if k.startswith('__') and k.endswith('__'):
attr[k] = str(v)
@@ -2195,188 +2208,4 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
name=name, dtype=dtype)
-
-def _make_atomic_symbol_function(handle, name):
- """Create an atomic symbol function by handle and function name."""
- real_name = ctypes.c_char_p()
- desc = ctypes.c_char_p()
- num_args = mx_uint()
- arg_names = ctypes.POINTER(ctypes.c_char_p)()
- arg_types = ctypes.POINTER(ctypes.c_char_p)()
- arg_descs = ctypes.POINTER(ctypes.c_char_p)()
- key_var_num_args = ctypes.c_char_p()
- ret_type = ctypes.c_char_p()
-
- check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
- handle, ctypes.byref(real_name), ctypes.byref(desc),
- ctypes.byref(num_args),
- ctypes.byref(arg_names),
- ctypes.byref(arg_types),
- ctypes.byref(arg_descs),
- ctypes.byref(key_var_num_args),
- ctypes.byref(ret_type)))
- narg = int(num_args.value)
- arg_names = [py_str(arg_names[i]) for i in range(narg)]
- arg_types = [py_str(arg_types[i]) for i in range(narg)]
- func_name = name
- key_var_num_args = py_str(key_var_num_args.value)
- ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
- doc_str = _build_doc(func_name,
- py_str(desc.value),
- arg_names,
- arg_types,
- [py_str(arg_descs[i]) for i in range(narg)],
- key_var_num_args,
- ret_type)
-
- dtype_name = None
- arr_name = None
- ndsignature = []
- signature = []
- ndarg_names = []
- kwarg_names = []
- for i in range(narg):
- name, atype = arg_names[i], arg_types[i]
- if name == 'dtype':
- dtype_name = name
- signature.append('%s=_Null'%name)
- elif atype.startswith('NDArray') or atype.startswith('Symbol'):
- assert not arr_name, \
- "Op can only have one argument with variable " \
- "size and it must be the last argument."
- if atype.endswith('[]'):
- ndsignature.append('*%s'%name)
- arr_name = name
- else:
- ndsignature.append('%s=None'%name)
- ndarg_names.append(name)
- else:
- signature.append('%s=_Null'%name)
- kwarg_names.append(name)
- #signature.append('is_train=False')
- signature.append('name=None')
- signature.append('attr=None')
- signature.append('out=None')
- signature.append('**kwargs')
- signature = ndsignature + signature
-
- code = []
- if arr_name:
- code.append("""
-def %s(*%s, **kwargs):"""%(func_name, arr_name))
- code.append("""
- sym_args = []
- for i in {}:
- assert isinstance(i, SymbolBase), \\
- "Positional arguments must be Symbol instances, " \\
- "but got %s"%str(i)
- sym_args.append(i)""".format(arr_name))
- if dtype_name is not None:
- code.append("""
- if '%s' in kwargs:
- kwargs['%s'] = _numpy.dtype(kwargs['%s']).name"""%(
- dtype_name, dtype_name, dtype_name))
- code.append("""
- attr = kwargs.pop('attr', None)
- kwargs.update(AttrScope.current.get(attr))
- name = kwargs.pop('name', None)
- name = NameManager.current.get(name, '%s')
- _ = kwargs.pop('out', None)
- keys = []
- vals = []
- sym_kwargs = dict()
- for k, v in kwargs.items():
- if isinstance(v, SymbolBase):
- sym_kwargs[k] = v
- else:
- keys.append(k)
- vals.append(v)"""%(func_name.lower()))
- if key_var_num_args:
- code.append("""
- if '%s' not in kwargs:
- keys.append('%s')
- vals.append(len(sym_args) + len(sym_kwargs))"""%(
- key_var_num_args, key_var_num_args))
-
- code.append("""
- return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%(
- handle.value))
- else:
- code.append("""
-def %s(%s):
- kwargs.update(AttrScope.current.get(attr))
- sym_kwargs = dict()
- keys = []
- vals = []"""%(func_name, ', '.join(signature)))
- code.append("""
- for k, v in kwargs.items():
- if isinstance(v, SymbolBase):
- sym_kwargs[k] = v
- else:
- keys.append(k)
- vals.append(v)""")
- # NDArray args
- for name in ndarg_names: # pylint: disable=redefined-argument-from-local
- code.append("""
- if {name} is not None:
- assert isinstance({name}, SymbolBase), \\
- "Argument {name} must be Symbol instances, but got %s"%str({name})
- sym_kwargs['{name}'] = {name}""".format(name=name))
- # kwargs
- for name in kwarg_names: # pylint: disable=redefined-argument-from-local
- code.append("""
- if %s is not _Null:
- keys.append('%s')
- vals.append(%s)"""%(name, name, name))
- # dtype
- if dtype_name is not None:
- code.append("""
- if %s is not _Null:
- keys.append('%s')
- vals.append(_numpy.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
-
- code.append("""
- name = NameManager.current.get(name, '%s')
- return _symbol_creator(%d, None, sym_kwargs, keys, vals, name)"""%(
- func_name.lower(), handle.value))
-
- local = {}
- exec(''.join(code), None, local) # pylint: disable=exec-used
- symbol_function = local[func_name]
- symbol_function.__name__ = func_name
- symbol_function.__doc__ = doc_str
- symbol_function.__module__ = 'mxnet.symbol'
- return symbol_function
-
-
-def _init_symbol_module(symbol_class, root_namespace):
- """List and add all the atomic symbol functions to current module."""
- _set_symbol_class(symbol_class)
- plist = ctypes.POINTER(ctypes.c_char_p)()
- size = ctypes.c_uint()
-
- check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
- ctypes.byref(plist)))
- op_names = []
- for i in range(size.value):
- op_names.append(py_str(plist[i]))
-
- module_obj = _sys.modules["%s.symbol" % root_namespace]
- module_internal = _sys.modules["%s._symbol_internal" % root_namespace]
- module_contrib = _sys.modules["%s.contrib.symbol" % root_namespace]
- for name in op_names:
- hdl = OpHandle()
- check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
- function = _make_atomic_symbol_function(hdl, name)
- if function.__name__.startswith('_contrib_'):
- function.__name__ = function.__name__[9:]
- function.__module__ = 'mxnet.contrib.symbol'
- setattr(module_contrib, function.__name__, function)
- elif function.__name__.startswith('_'):
- setattr(module_internal, function.__name__, function)
- else:
- setattr(module_obj, function.__name__, function)
-
-
-# Initialize the atomic symbol in startups
-_init_symbol_module(Symbol, "mxnet")
+_set_symbol_class(Symbol)
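A sketch of the new storage-type hooks in the symbol API, combining `var(stype=...)` with `simple_bind(stype_dict=...)`; shapes are still passed through **kwargs as before:

```python
import mxnet as mx

w = mx.sym.var('w', stype='row_sparse')
x = mx.sym.var('x')
out = mx.sym.elemwise_add(w, x)

exe = out.simple_bind(mx.cpu(), w=(4, 2), x=(4, 2),
                      stype_dict={'w': 'row_sparse'})
print([arr.stype for arr in exe.arg_arrays])
```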
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index c5587f8d80a8..e1210fbd3e6e 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -31,15 +31,17 @@
from contextlib import contextmanager
import numpy as np
import numpy.testing as npt
-import mxnet as mx
-from .context import Context
-from .ndarray import array
-from .symbol import Symbol
+import numpy.random as rnd
try:
import requests
except ImportError:
# in rare cases requests may be not installed
pass
+import mxnet as mx
+from .context import Context
+from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
+from .ndarray import array
+from .symbol import Symbol
_rng = np.random.RandomState(1234)
@@ -85,6 +87,184 @@ def random_arrays(*shapes):
return arrays
+def random_sample(population, k):
+ """Return a k length list of the elements chosen from the population sequence."""
+ assert 0 <= k <= len(population)
+ population_copy = population[:]
+ np.random.shuffle(population_copy)
+ return population_copy[0:k]
+
+
+def _validate_csr_generation_inputs(num_rows, num_cols, density,
+ distribution="uniform"):
+ """Validates inputs for csr generation helper functions
+ """
+ total_nnz = int(num_rows * num_cols * density)
+ if density < 0 or density > 1:
+ raise ValueError("density has to be between 0 and 1")
+
+ if num_rows <= 0 or num_cols <= 0:
+ raise ValueError("num_rows or num_cols should be greater than 0")
+
+ if distribution == "powerlaw":
+ if total_nnz < 2 * num_rows:
+ raise ValueError("not supported for this density: %s"
+ " for this shape (%s, %s)"
+ " Please keep :"
+ " num_rows * num_cols * density >= 2 * num_rows"
+ % (density, num_rows, num_cols))
+
+
+def _get_uniform_dataset_csr(num_rows, num_cols, density=0.1, dtype=None):
+ """Returns CSRNDArray with uniform distribution
+ This generates a csr matrix with totalnnz unique randomly chosen numbers
+ from num_rows*num_cols and arranges them in the 2d array in the
+ following way: row_index = (random_number_generated / num_rows)
+ col_index = random_number_generated - row_index * num_cols
+ """
+ _validate_csr_generation_inputs(num_rows, num_cols, density,
+ distribution="uniform")
+ from scipy import sparse as sp
+ csr = sp.rand(num_rows, num_cols, density, dtype=dtype, format="csr")
+ result = mx.nd.sparse.csr_matrix(csr.data, csr.indptr, csr.indices,
+ (num_rows, num_cols), dtype=dtype)
+ return result
+
+
+def _get_powerlaw_dataset_csr(num_rows, num_cols, density=0.1, dtype=None):
+ """Returns CSRNDArray with powerlaw distribution
+ with exponentially increasing number of non zeros in each row.
+ Not supported for cases where total_nnz < 2*num_rows. This is because
+ the algorithm first tries to ensure that there are rows with no zeros by
+ putting non zeros at beginning of each row.
+ """
+
+ _validate_csr_generation_inputs(num_rows, num_cols, density,
+ distribution="powerlaw")
+
+ total_nnz = int(num_rows * num_cols * density)
+
+ unused_nnz = total_nnz
+ output_arr = np.zeros((num_rows, num_cols), dtype=dtype)
+ # Start with ones on each row so that no row is empty
+ for row in range(num_rows):
+ output_arr[row][0] = 1 + rnd.uniform(0.001, 2)
+ unused_nnz = unused_nnz - 1
+ if unused_nnz <= 0:
+ return mx.nd.array(output_arr).tostype("csr")
+
+ # Populate rest of matrix with 2^i items in ith row.
+ # if we have used all total nnz return the sparse matrix
+ # else if we reached max column size then fill up full columns until we use all nnz
+ col_max = 2
+ for row in range(num_rows):
+ col_limit = min(num_cols, col_max)
+ # In case col_limit reached assign same value to all elements, which is much faster
+ if col_limit == num_cols and unused_nnz > col_limit:
+ output_arr[row] = 1 + rnd.uniform(0.001, 2)
+ unused_nnz = unused_nnz - col_limit + 1
+ if unused_nnz <= 0:
+ return mx.nd.array(output_arr).tostype("csr")
+ else:
+ continue
+ for col_index in range(1, col_limit):
+ output_arr[row][col_index] = 1 + rnd.uniform(0.001, 2)
+ unused_nnz = unused_nnz - 1
+ if unused_nnz <= 0:
+ return mx.nd.array(output_arr).tostype("csr")
+ col_max = col_max * 2
+
+ if unused_nnz > 0:
+ raise ValueError("not supported for this density: %s"
+ " for this shape (%s,%s)" % (density, num_rows, num_cols))
+ else:
+ return mx.nd.array(output_arr).tostype("csr")
+
+
+def rand_sparse_ndarray(shape, stype, density=None, distribution=None, dtype=None):
+ """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)
+ Parameters
+ ----------
+ shape: list or tuple
+ stype: str, valid values: "csr" or "row_sparse"
+ density: float, optional, should be between 0 and 1
+ distribution: str, optional, valid values: "uniform" or "powerlaw"
+ dtype: numpy.dtype, optional, default value is None
+ Returns
+ -------
+ Result of type CSRNDArray or RowSparseNDArray
+ Examples
+ --------
+ Below is an example of the powerlaw distribution with csr as the stype.
+ It calculates the nnz using the shape and density.
+ It fills up the ndarray with an exponentially increasing number of elements.
+ If there are enough unused nnzs, the (n+1)th row will have twice as many nnzs as the nth row;
+ otherwise, the remaining unused nnzs are placed in the (n+1)th row.
+ If the number of columns is too small and the column limit has already been reached, it fills
+ up all following columns in all following rows until the required density is reached.
+
+ >>> csr_arr, _ = rand_sparse_ndarray(shape=(5, 16), stype="csr",
+ density=0.50, distribution="powerlaw")
+ >>> indptr = csr_arr.indptr.asnumpy()
+ >>> indices = csr_arr.indices.asnumpy()
+ >>> data = csr_arr.data.asnumpy()
+ >>> row2nnz = len(data[indptr[1]:indptr[2]])
+ >>> row3nnz = len(data[indptr[2]:indptr[3]])
+ >>> assert(row3nnz == 2*row2nnz)
+ >>> row4nnz = len(data[indptr[3]:indptr[4]])
+ >>> assert(row4nnz == 2*row3nnz)
+ """
+ density = rnd.rand() if density is None else density
+ dtype = default_dtype() if dtype is None else dtype
+ distribution = "uniform" if distribution is None else distribution
+ if stype == 'row_sparse':
+ assert (distribution == "uniform"), \
+ "Distribution %s not supported for row_sparse" % (distribution)
+ # sample index
+ idx_sample = rnd.rand(shape[0])
+ indices = np.argwhere(idx_sample < density).flatten()
+ if indices.shape[0] == 0:
+ result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype)
+ return result, (np.array([], dtype=dtype), np.array([], dtype='int64'))
+ # generate random values
+ val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
+ arr = mx.nd.sparse.row_sparse_array(val, indices, shape, indices_type=np.int64, dtype=dtype)
+ return arr, (val, indices)
+ elif stype == 'csr':
+ assert len(shape) == 2
+ if distribution == "uniform":
+ csr = _get_uniform_dataset_csr(shape[0], shape[1], density, dtype=dtype)
+ return csr, (csr.indptr, csr.indices, csr.data)
+ elif distribution == "powerlaw":
+ csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density, dtype=dtype)
+ return csr, (csr.indptr, csr.indices, csr.data)
+ else:
+ assert(False), "Distribution not supported: %s" % (distribution)
+ else:
+ assert(False), "unknown storage type"
+
+
+def rand_ndarray(shape, stype, density=None, dtype=None, distribution=None):
+ """Generate a random dense or sparse ndarray of the given storage type."""
+ if stype == 'default':
+ arr = mx.nd.array(random_arrays(shape), dtype=dtype)
+ else:
+ arr, _ = rand_sparse_ndarray(shape, stype, density=density, dtype=dtype,
+ distribution=distribution)
+ return arr
+
+
+def rand_shape_2d(dim0=10, dim1=10):
+ return rnd.randint(1, dim0 + 1), rnd.randint(1, dim1 + 1)
+
+
+def rand_shape_3d(dim0=10, dim1=10, dim2=10):
+ return rnd.randint(1, dim0 + 1), rnd.randint(1, dim1 + 1), rnd.randint(1, dim2 + 1)
+
+
+def rand_shape_nd(n, dim=10):
+ return rnd.randint(1, dim+1, size=n)
+
+
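A short sketch of how these helpers combine in the sparse operator tests:

```python
from mxnet.test_utils import rand_shape_2d, rand_ndarray

shape = rand_shape_2d(50, 50)                      # random (rows, cols) within bounds
csr = rand_ndarray(shape, 'csr', density=0.2)      # sparse input, uniform by default
dns = rand_ndarray(shape, 'default')               # dense counterpart
assert csr.stype == 'csr' and dns.stype == 'default'
```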
def np_reduce(dat, axis, keepdims, numpy_reduce_func):
"""Compatible reduce for old version of NumPy.
@@ -316,7 +496,8 @@ def _parse_location(sym, location, ctx):
% (str(set(sym.list_arguments())), str(set(location.keys()))))
else:
location = {k: v for k, v in zip(sym.list_arguments(), location)}
- location = {k: mx.nd.array(v, ctx=ctx) for k, v in location.items()}
+ location = {k: mx.nd.array(v, ctx=ctx) if isinstance(v, np.ndarray) \
+ else v for k, v in location.items()}
return location
@@ -437,7 +618,8 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai
def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2,
- atol=None, grad_nodes=None, use_forward_train=True, ctx=None):
+ atol=None, grad_nodes=None, use_forward_train=True, ctx=None,
+ grad_stype_dict=None):
"""Verify an operation by checking backward pass via finite difference method.
Based on Theano's `theano.gradient.verify_grad` [1]
@@ -454,7 +636,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
- if type is dict of str -> numpy.ndarray
maps the name of arguments to the corresponding numpy.ndarray.
*In either case, value of all the arguments must be provided.*
- aux_states : ist or tuple or dict, optional
+ aux_states : list or tuple or dict, optional
The auxiliary states required when generating the executor for the symbol.
numeric_eps : float, optional
Delta for the finite difference method that approximates the gradient.
@@ -466,6 +648,8 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
Whether to use is_train=True when computing the finite-difference.
ctx : Context, optional
Check the gradient computation on the specified device.
+ grad_stype_dict : dict of str->str, optional
+ Storage type dictionary for gradient ndarrays.
References
---------
..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py
@@ -489,7 +673,7 @@ def random_projection(shape):
location_npy = {k:v.asnumpy() for k, v in location.items()}
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx)
if aux_states is not None:
- aux_states_npy = {k:v.asnumpy() for k, v in aux_states.items()}
+ aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()}
else:
aux_states_npy = None
if grad_nodes is None:
@@ -516,6 +700,14 @@ def random_projection(shape):
+ [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))])
args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()}
+ if grad_stype_dict is not None:
+ assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict"
+ for k, v in grad_stype_dict.items():
+ if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default':
+ # create an uninitialized sparse ndarray for executor
+ # if the symbolic grad is expected to be zero, it should not be initialized at all
+ args_grad[k] = mx.nd.zeros(args_grad[k].shape, args_grad[k].context,
+ args_grad[k].dtype, v)
executor = out.bind(ctx, grad_req=grad_req,
args=location, args_grad=args_grad, aux_states=aux_states)
@@ -607,15 +799,15 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,
g[:] = 0
executor.forward(is_train=False)
- outputs = [x.asnumpy() for x in executor.outputs]
+ outputs = [x.asnumpy() for x in executor.outputs]
for output_name, expect, output in zip(sym.list_outputs(), expected, outputs):
assert_almost_equal(expect, output, rtol, atol,
("EXPECTED_%s"%output_name, "FORWARD_%s"%output_name))
def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=None,
- aux_states=None, grad_req='write', ctx=None):
+ aux_states=None, grad_req='write', ctx=None, grad_stypes=None):
"""Compares a symbol's backward results with the expected ones.
Prints error messages if the backward results are not the same as the expected results.
@@ -651,6 +843,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
Gradient requirements. 'write', 'add' or 'null'.
ctx : Context, optional
Running context.
+ grad_stypes: dict of str->str
+ dictionary mapping argument names to the stype of the corresponding gradient
Example
-------
@@ -676,14 +870,23 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
if isinstance(expected, (list, tuple)):
expected = {k:v for k, v in zip(sym.list_arguments(), expected)}
args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()}
- args_grad_data = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()}
+ args_grad_data = {}
+ for k, v in args_grad_npy.items():
+ nd = mx.nd.array(v, ctx=ctx)
+ if grad_stypes is not None and k in grad_stypes:
+ args_grad_data[k] = nd.tostype(grad_stypes[k])
+ else:
+ args_grad_data[k] = nd
+
if isinstance(grad_req, str):
grad_req = {k:grad_req for k in sym.list_arguments()}
elif isinstance(grad_req, (list, tuple)):
grad_req = {k:v for k, v in zip(sym.list_arguments(), grad_req)}
- executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
+ executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data,
+ aux_states=aux_states, grad_req=grad_req)
executor.forward(is_train=True)
+
if isinstance(out_grads, (tuple, list)):
out_grads = [mx.nd.array(v, ctx=ctx) for v in out_grads]
elif isinstance(out_grads, (dict)):
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 93458d21ac5a..0fe3fe3e302e 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -172,6 +172,39 @@ int MXNDArrayCreateEx(const mx_uint *shape,
API_END();
}
+int MXNDArrayCreateSparseEx(int storage_type,
+ const mx_uint *shape,
+ mx_uint ndim,
+ int dev_type,
+ int dev_id,
+ int delay_alloc,
+ int dtype,
+ mx_uint num_aux,
+ int *aux_type,
+ mx_uint *aux_ndims,
+ const mx_uint *aux_shape,
+ NDArrayHandle *out) {
+ API_BEGIN();
+ std::vector<int> aux_types;
+ std::vector<TShape> aux_shapes;
+ auto shape_start = aux_shape;
+ for (size_t i = 0; i < num_aux; i++) {
+ // types
+ aux_types.push_back(aux_type[i]);
+ // shapes
+ aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]);
+ shape_start += aux_ndims[i];
+ }
+ *out = new NDArray(
+ NDArrayStorageType(storage_type),
+ TShape(shape, shape + ndim),
+ Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
+ delay_alloc != 0,
+ dtype, aux_types, aux_shapes);
+ API_END();
+}
+
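On the Python side this entry point backs sparse construction; using the argument order that this patch's own test utilities use for `mx.nd.sparse.csr_matrix`:

```python
import numpy as np
import mxnet as mx

data = np.array([1., 2., 3.])
indptr = np.array([0, 1, 1, 3])                    # row 1 is empty
indices = np.array([0, 1, 2])
a = mx.nd.sparse.csr_matrix(data, indptr, indices, (3, 3))
print(a.asnumpy())
```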
+
int MXNDArrayLoadFromRawBytes(const void *buf,
size_t size,
NDArrayHandle *out) {
@@ -215,6 +248,23 @@ int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
API_END();
}
+/*!
+ * \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0
+ * This function blocks. Do not use it in performance critical code.
+ * \param handle_dst handle of a dst ndarray whose data/aux_data has been allocated
+ * \param handle_src handle of a src ndarray which has default storage type
+ * \param i dst data blob indicator
+ */
+int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst,
+ const NDArrayHandle handle_src,
+ const int i) {
+ API_BEGIN();
+ NDArray* dst = static_cast<NDArray*>(handle_dst);
+ NDArray* src = static_cast<NDArray*>(handle_src);
+ dst->SyncCopyFromNDArray(*src, -1, i);
+ API_END();
+}
+
int MXNDArrayWaitToRead(NDArrayHandle handle) {
API_BEGIN();
static_cast<NDArray*>(handle)->WaitToRead();
@@ -351,6 +401,18 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
API_END_HANDLE_ERROR(delete ptr);
}
+int MXNDArrayGetStorageType(NDArrayHandle handle,
+ int *out_storage_type) {
+ API_BEGIN();
+ NDArray *arr = static_cast<NDArray*>(handle);
+ if (!arr->is_none()) {
+ *out_storage_type = arr->storage_type();
+ } else {
+ *out_storage_type = kUndefinedStorage;
+ }
+ API_END();
+}
+
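This query surfaces in the front end as the `stype` property used throughout the Python changes above:

```python
import mxnet as mx

arr = mx.nd.zeros((2, 2), stype='row_sparse')
assert arr.stype == 'row_sparse'                   # backed by MXNDArrayGetStorageType
```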
int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata) {
@@ -400,6 +462,42 @@ int MXNDArrayGetDType(NDArrayHandle handle,
API_END();
}
+int MXNDArrayGetAuxType(NDArrayHandle handle,
+ mx_uint i,
+ int *out_type) {
+ API_BEGIN();
+ NDArray *arr = static_cast<NDArray*>(handle);
+ *out_type = arr->aux_type(i);
+ API_END();
+}
+
+/*!
+ * \brief Get a deep copy of the ith aux data blob
+ * in the form of an NDArray of default storage type.
+ * This function blocks. Do not use it in performance critical code.
+ */
+int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
+ mx_uint i,
+ NDArrayHandle *out) {
+ API_BEGIN();
+ NDArray *arr = static_cast<NDArray*>(handle);
+ *out = new NDArray(arr->aux_ndarray(i));
+ API_END();
+}
+
+/*!
+ * \brief Get a deep copy of the data blob
+ * in the form of an NDArray of default storage type.
+ * This function blocks. Do not use it in performance critical code.
+ */
+int MXNDArrayGetDataNDArray(NDArrayHandle handle,
+ NDArrayHandle *out) {
+ API_BEGIN();
+ NDArray *arr = static_cast<NDArray*>(handle);
+ *out = new NDArray(arr->data_ndarray());
+ API_END();
+}
+
int MXNDArrayGetContext(NDArrayHandle handle,
int *out_dev_type,
int *out_dev_id) {
@@ -735,6 +833,24 @@ int MXKVStorePullEx(KVStoreHandle handle,
API_END();
}
+int MXKVStorePullRowSparse(KVStoreHandle handle,
+ mx_uint num,
+ const char** keys,
+ NDArrayHandle* vals,
+ const NDArrayHandle* row_ids,
+ int priority) {
+ API_BEGIN();
+ std::vector<std::string> v_keys(num);
+ std::vector<std::pair<NDArray*, NDArray>> v_val_rowids(num);
+ for (mx_uint i = 0; i < num; ++i) {
+ v_keys[i] = keys[i];
+ v_val_rowids[i] = std::make_pair(static_cast<NDArray*>(vals[i]),
+ *static_cast<NDArray*>(row_ids[i]));
+ }
+ static_cast<KVStore*>(handle)->PullRowSparse(v_keys, v_val_rowids, priority);
+ API_END();
+}
+
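A front-end sketch of the row-sparse pull path; the wrapper name `row_sparse_pull` is an assumption here (it is the name used by later MXNet releases):

```python
import mxnet as mx

kv = mx.kv.create('local')
kv.init('w', mx.nd.ones((4, 2)).tostype('row_sparse'))
out = mx.nd.zeros((4, 2), stype='row_sparse')
row_ids = mx.nd.array([0, 2], dtype='int64')
kv.row_sparse_pull('w', out=out, row_ids=row_ids)  # pull only rows 0 and 2
```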
int MXKVStoreSetUpdater(KVStoreHandle handle,
MXKVStoreUpdater updater,
void* updater_handle) {
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 846b53973b07..fee3f03f6db0 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -76,6 +76,8 @@ struct MXAPIThreadLocalEntry {
std::vector<TShape> arg_shapes, out_shapes, aux_shapes;
/*! \brief result holder for returning type flags */
std::vector<int> arg_types, out_types, aux_types;
+ /*! \brief result holder for returning storage types */
+ std::vector<int> arg_storage_types, out_storage_types, aux_storage_types;
/*! \brief result holder for returning shape dimensions */
std::vector<mx_uint> arg_shape_ndim, out_shape_ndim, aux_shape_ndim;
/*! \brief result holder for returning shape pointer */
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index a4c48e426879..631c1a7d93eb 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -198,6 +198,9 @@ int MXExecutorBindEX(SymbolHandle symbol_handle,
* \param num_provided_arg_dtypes number of user provided in_arg and aux_state dtypes
* \param provided_arg_dtype_names argument name list of provided dtypes
* \param provided_arg_dtypes data of provided dtypes
+ * \param num_provided_arg_stypes number of user provided in_arg and aux_state storage types
+ * \param provided_arg_stype_names argument name list of provided storage types
+ * \param provided_arg_stypes data of provided storage types
* \param num_shared_arg_names number of parameter names passed from _bind_ith_exec
* \param shared_arg_name_list parameter name list passed from _bind_ith_exec
* \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec
@@ -230,6 +233,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
+ const mx_uint num_provided_arg_stypes,
+ const char** provided_arg_stype_names,
+ const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
@@ -254,7 +260,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
// attr_dict for setting up type_dict and arg/aux ctx
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> attr_dict;
- if (nullptr == provided_arg_dtypes || nullptr != g2c_keys) {
+ if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) {
std::vector<std::tuple<std::string, std::string, std::string>> attrs =
sym->ListAttrsRecursive();
attr_dict.reserve(attrs.size());
@@ -280,6 +286,23 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
}
}
+ // setup arg_stype_map
+ std::unordered_map<std::string, int> arg_stype_map;
+ if (nullptr == provided_arg_stypes) { // use attr_dict
+ for (const auto& arg_name : in_arg_names) {
+ const auto it = attr_dict.find(arg_name);
+ if (it == attr_dict.end() || !it->second.count("__storage_type__")) {
+ arg_stype_map[arg_name] = kDefaultStorage;
+ }
+ }
+ } else { // use user provided stype_dict
+ // create stype map for in_args and aux_states
+ arg_stype_map.reserve(num_provided_arg_stypes);
+ for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) {
+ arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i];
+ }
+ }
+
// create default ctx
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
// create ctx map
@@ -420,9 +443,10 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
std::vector<NDArray> aux_state_vec;
*out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
- aux_state_ctx_vec, arg_shape_map, arg_dtype_map, grad_req_type_vec,
- shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec,
- use_shared_buffer? &shared_buffer_map : nullptr,
+ aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
+ grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
+ &arg_grad_vec, &aux_state_vec,
+ use_shared_buffer ? &shared_buffer_map : nullptr,
reinterpret_cast(shared_exec_handle));
// copy ndarray ptrs to ret->handles so that front end
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 3202f55abea7..d392baf45d3e 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -18,7 +18,8 @@
*/
/*!
- * \file c_api_symbolic.cc
+ * Copyright (c) 2016 by Contributors
+ * \file c_api_ndarray.cc
* \brief C API of mxnet
*/
@@ -150,14 +151,17 @@ void SetContext(Context* p_ctx,
#endif // MXNET_USE_CUDA
}
+// Set the shape, dtype and storage type
void SetShapeType(const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
const Context& ctx,
const std::vector<NDArray>& ndinputs,
- std::vector<NDArray>* p_ndoutputs) {
+ std::vector<NDArray>* p_ndoutputs,
+ int* dispatch_stype) {
std::vector<NDArray>& ndoutputs = *p_ndoutputs;
static auto& infershape = nnvm::Op::GetAttr<nnvm::FInferShape>("FInferShape");
static auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+ static auto& inferstorage = nnvm::Op::GetAttr<FInferStorageType>("FInferStorageType");
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
// infer shape
std::vector<TShape>& in_shapes = ret->arg_shapes;
@@ -193,9 +197,35 @@ void SetShapeType(const nnvm::Op* op,
CHECK(infertype[op](attrs, &in_types, &out_types));
CHECK_EQ(out_types.size(), ndoutputs.size());
+ // infer storage type
+ auto& in_storage_types = ret->arg_storage_types;
+ auto& out_storage_types = ret->out_storage_types;
+ in_storage_types.clear();
+ out_storage_types.clear();
+ for (auto& i : ndinputs) {
+ in_storage_types.push_back(i.storage_type());
+ }
+ for (auto& i : ndoutputs) {
+ out_storage_types.push_back(i.storage_type());
+ }
+ if (inferstorage.count(op)) {
+ CHECK(inferstorage[op](attrs, ctx, &in_storage_types, &out_storage_types));
+ CHECK_EQ(out_storage_types.size(), ndoutputs.size());
+ }
+
+ bool contains_non_default = common::ContainsNonDefaultStorage(in_storage_types);
+ contains_non_default |= common::ContainsNonDefaultStorage(out_storage_types);
+ int kNonDefaultStorage = -2;
+ *dispatch_stype = contains_non_default ? kNonDefaultStorage : kDefaultStorage;
for (size_t i = 0; i < ndoutputs.size(); ++i) {
+ NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
if (ndoutputs[i].is_none()) {
- ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]);
+ // if failed to infer the storage type, assume the output storage is dense
+ if (storage_type == kDefaultStorage || out_storage_types[i] == kUndefinedStorage) {
+ ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]);
+ } else {
+ ndoutputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]);
+ }
} else {
CHECK_EQ(ndoutputs[i].shape(), out_shapes[i])
<< i << "th output has invalid shape. "
@@ -212,7 +242,7 @@ void SetShapeType(const nnvm::Op* op,
void SetDependency(std::vector<engine::VarHandle> *p_read_vars,
std::vector<engine::VarHandle> *p_write_vars,
std::vector<Resource> *p_requested,
- std::vector<uint32_t> *p_auxidx,
+ std::vector<uint32_t> *p_mutate_idx,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
const Context& ctx,
@@ -224,7 +254,7 @@ void SetDependency(std::vector *p_read_vars,
std::vector<engine::VarHandle>& read_vars = *p_read_vars;
std::vector<engine::VarHandle>& write_vars = *p_write_vars;
std::vector<Resource>& requested = *p_requested;
- std::vector<uint32_t>& auxidx = *p_auxidx;
+ std::vector<uint32_t>& mutate_idx = *p_mutate_idx;
if (tmp_resource.count(op)) {
int ntmp = 0;
@@ -250,15 +280,30 @@ void SetDependency(std::vector *p_read_vars,
write_vars.push_back(i.var());
}
if (mutate.count(op)) {
- auxidx = mutate[op](attrs);
- std::sort(auxidx.begin(), auxidx.end());
- for (auto & i : auxidx) {
+ mutate_idx = mutate[op](attrs);
+ std::sort(mutate_idx.begin(), mutate_idx.end());
+ for (auto & i : mutate_idx) {
write_vars.push_back(ndinputs[i].var());
}
}
Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars);
}
+inline void SetWriteInplaceReq(const std::vector<NDArray> &ndinputs,
+ const std::vector<NDArray> &ndoutputs,
+ std::vector<OpReqType> *req) {
+ std::unordered_set<engine::VarHandle> in_vars;
+ for (auto &nd : ndinputs) {
+ in_vars.insert(nd.var());
+ }
+ for (size_t i = 0; i < ndoutputs.size(); i++) {
+ // output NDArray shares the memory with the input NDArray
+ if (in_vars.find(ndoutputs[i].var()) != in_vars.end()) {
+ req->at(i) = kWriteInplace;
+ }
+ }
+}
+
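A Python-level case that exercises this in-place request: when an op's output is one of its inputs, the shared engine variable makes the request `kWriteInplace` (a sketch):

```python
import mxnet as mx

a = mx.nd.ones((2, 2))
mx.nd.elemwise_add(a, mx.nd.ones((2, 2)), out=a)   # output shares a.var() with an input
print(a.asnumpy())
```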
void PushFCompute(const FCompute& fn,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
@@ -267,24 +312,75 @@ void PushFCompute(const FCompute& fn,
const std::vector<engine::VarHandle>& write_vars,
const std::vector<Resource>& requested,
const std::vector<NDArray>& ndinputs,
- const std::vector<NDArray>& ndoutputs) {
+ const std::vector<NDArray>& ndoutputs,
+ const std::vector<uint32_t>& mutate_idx) {
+ using namespace common;
bool is_train = AutogradRuntime::Get()->IsTraining();
Engine::Get()->PushAsync(
- [ctx, attrs, fn, ndinputs, ndoutputs, requested, is_train](
+ [ctx, attrs, fn, ndinputs, ndoutputs, requested, is_train, mutate_idx](
RunContext rctx,
engine::CallbackOnComplete on_complete) {
std::vector<TBlob> input_blobs, output_blobs;
- for (auto& i : ndinputs) {
- input_blobs.push_back(i.data());
- }
- for (auto& i : ndoutputs) {
- output_blobs.push_back(i.data());
+ // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
+ std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
+ // mapping from index in input_blobs to index in pre_temp_dst
+ std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
+ // populate input blobs and output blobs
+ SetupDefaultBlobs(ndinputs, &input_blobs, &pre_temp_src, &pre_temp_dst, &in_temp_idx_map);
+ SetupDefaultBlobs(ndoutputs, &output_blobs, &post_temp_dst, &post_temp_src);
+ // add mutable inputs to post temp list
+ for (const auto idx : mutate_idx) {
+ auto map_iter = in_temp_idx_map.find(idx);
+ if (map_iter != in_temp_idx_map.end()) {
+ post_temp_src.push_back(pre_temp_dst[map_iter->second]);
+ post_temp_dst.push_back(ndinputs[idx]);
+ }
}
OpContext opctx{is_train, rctx,
engine::CallbackOnComplete(),
requested};
std::vector<OpReqType> req(output_blobs.size(), kWriteTo);
- fn(attrs, opctx, input_blobs, req, output_blobs);
+ if (ctx.dev_mask() == gpu::kDevMask) {
+#if MXNET_USE_CUDA
+ CastNonDefaultStorage<gpu>(pre_temp_src, pre_temp_dst, opctx);
+ fn(attrs, opctx, input_blobs, req, output_blobs);
+ // cast to original storage type, if necessary
+ CastNonDefaultStorage<gpu>(post_temp_src, post_temp_dst, opctx);
+ rctx.get_stream<gpu>()->Wait();
+#else
+ LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+#endif
+ } else {
+ CastNonDefaultStorage<cpu>(pre_temp_src, pre_temp_dst, opctx);
+ fn(attrs, opctx, input_blobs, req, output_blobs);
+ // cast to original storage type, if necessary
+ CastNonDefaultStorage<cpu>(post_temp_src, post_temp_dst, opctx);
+ }
+ on_complete();
+ }, ctx, read_vars, write_vars, FnProperty::kNormal,
+ 0, PROFILER_MESSAGE(op->name.c_str()));
+}
+
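The pre-/post-fcompute casts above implement the dense fallback path: non-default inputs are densified, the dense kernel runs on plain blobs, and mutated buffers are cast back afterwards. A minimal standalone sketch of that flow, with a std::map standing in for a sparse NDArray (illustrative types, not MXNet's):

// Toy model of the fallback in PushFCompute: cast sparse -> dense,
// run a dense-only kernel, cast the result back to sparse.
#include <cstdio>
#include <map>
#include <vector>

using Dense  = std::vector<double>;
using Sparse = std::map<std::size_t, double>;  // index -> value

Dense ToDense(const Sparse& s, std::size_t n) {
  Dense d(n, 0.0);
  for (const auto& kv : s) d[kv.first] = kv.second;
  return d;
}

Sparse ToSparse(const Dense& d) {
  Sparse s;
  for (std::size_t i = 0; i < d.size(); ++i)
    if (d[i] != 0.0) s[i] = d[i];
  return s;
}

void AddOneDenseKernel(Dense* x) {  // the "FCompute" that only sees dense data
  for (auto& v : *x) v += 1.0;
}

int main() {
  Sparse in = {{2, 3.0}};       // non-default storage input
  Dense tmp = ToDense(in, 4);   // pre-fcompute cast
  AddOneDenseKernel(&tmp);      // dense compute
  Sparse out = ToSparse(tmp);   // post-fcompute cast back
  std::printf("nnz after fallback: %zu\n", out.size());
  return 0;
}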
+void PushFComputeEx(const FComputeEx& fn,
+ const nnvm::Op* op,
+ const nnvm::NodeAttrs& attrs,
+ const Context& ctx,
+ const std::vector<engine::VarHandle>& read_vars,
+ const std::vector<engine::VarHandle>& write_vars,
+ const std::vector<Resource>& requested,
+ const std::vector<NDArray>& ndinputs,
+ const std::vector<NDArray>& ndoutputs) {
+ Engine::Get()->PushAsync(
+ [ctx, attrs, fn, ndinputs, ndoutputs, requested](
+ RunContext rctx,
+ engine::CallbackOnComplete on_complete) {
+ std::vector<TBlob> input_blobs, output_blobs;
+ OpContext opctx{false, rctx,
+ engine::CallbackOnComplete(),
+ requested};
+ std::vector<OpReqType> req(ndoutputs.size(), kWriteTo);
+ SetWriteInplaceReq(ndinputs, ndoutputs, &req);
+ fn(attrs, opctx, ndinputs, req, ndoutputs);
if (ctx.dev_mask() == gpu::kDevMask) {
rctx.get_stream<gpu>()->Wait();
}
@@ -301,7 +397,9 @@ void PushOperator(const OpStatePtr& state,
const std::vector<engine::VarHandle>& write_vars,
const std::vector<Resource>& requested,
const std::vector<NDArray>& ndinputs,
- const std::vector<NDArray>& ndoutputs) {
+ const std::vector<NDArray>& ndoutputs,
+ const std::vector<uint32_t>& mutate_idx) {
+ using namespace common;
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
bool is_train = AutogradRuntime::Get()->IsTraining();
@@ -314,15 +412,40 @@ void PushOperator(const OpStatePtr& state,
if (fcompute != nullptr) {
CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync);
Engine::Get()->PushAsync(
- [state, fcompute, ndinputs, ndoutputs, requested, is_train, exec_type](
+ [state, fcompute, ndinputs, ndoutputs, requested, is_train, exec_type, mutate_idx](
RunContext rctx,
engine::CallbackOnComplete on_complete) {
OpContext opctx{is_train, rctx, on_complete, requested};
+
std::vector<TBlob> input_blobs, output_blobs;
- for (const auto& i : ndinputs) input_blobs.push_back(i.data());
- for (const auto& i : ndoutputs) output_blobs.push_back(i.data());
+ // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
+ std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
+ // mapping from index in input_blobs to index in pre_temp_dst
+ std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
+ // populate input blobs and output blobs
+ SetupDefaultBlobs(ndinputs, &input_blobs, &pre_temp_src, &pre_temp_dst, &in_temp_idx_map);
+ SetupDefaultBlobs(ndoutputs, &output_blobs, &post_temp_dst, &post_temp_src);
+ // add mutable inputs to post temp list
+ for (const auto idx : mutate_idx) {
+ if (in_temp_idx_map.find(idx) != in_temp_idx_map.end()) {
+ post_temp_src.push_back(pre_temp_dst[in_temp_idx_map[idx]]);
+ post_temp_dst.push_back(ndinputs[idx]);
+ }
+ }
std::vector<OpReqType> req(output_blobs.size(), kWriteTo);
- fcompute(state, opctx, input_blobs, req, output_blobs);
+ if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
+#if MXNET_USE_CUDA
+ CastNonDefaultStorage<gpu>(pre_temp_src, pre_temp_dst, opctx);
+ fcompute(state, opctx, input_blobs, req, output_blobs);
+ CastNonDefaultStorage<gpu>(post_temp_src, post_temp_dst, opctx);
+#else
+ LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+#endif
+ } else {
+ CastNonDefaultStorage<cpu>(pre_temp_src, pre_temp_dst, opctx);
+ fcompute(state, opctx, input_blobs, req, output_blobs);
+ CastNonDefaultStorage<cpu>(post_temp_src, post_temp_dst, opctx);
+ }
if (exec_type == ExecType::kSync) {
if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
rctx.get_stream<gpu>()->Wait();
@@ -342,6 +465,7 @@ void PushOperator(const OpStatePtr& state,
engine::CallbackOnComplete on_complete) {
OpContext opctx{is_train, rctx, on_complete, requested};
std::vector<OpReqType> req(ndoutputs.size(), kWriteTo);
+ SetWriteInplaceReq(ndinputs, ndoutputs, &req);
fcompute_ex(state, opctx, ndinputs, req, ndoutputs);
if (exec_type == ExecType::kSync) {
if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
@@ -363,8 +487,6 @@ void ImperativeInvokeImpl(const Context& default_ctx,
const nnvm::NodeAttrs& attrs,
std::vector<NDArray>* p_ndinputs,
std::vector<NDArray>* p_ndoutputs) {
- static auto& fcpu = nnvm::Op::GetAttr<FCompute>("FCompute<cpu>");
- static auto& fgpu = nnvm::Op::GetAttr<FCompute>("FCompute<gpu>");
static auto& ndfunc = nnvm::Op::GetAttr<FNDArrayFunction>("FNDArrayFunction");
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
@@ -379,29 +501,32 @@ void ImperativeInvokeImpl(const Context& default_ctx,
} else {
// TODO(piiswrong): infer ctx
Context ctx;
+ int stype;
SetContext(&ctx, attrs, ndinputs, ndoutputs, default_ctx);
- SetShapeType(op, attrs, ctx, ndinputs, &ndoutputs);
+ SetShapeType(op, attrs, ctx, ndinputs, &ndoutputs, &stype);
std::vector<engine::VarHandle> read_vars, write_vars;
std::vector<Resource> requested;
- std::vector<uint32_t> auxidx;
- SetDependency(&read_vars, &write_vars, &requested, &auxidx,
+ std::vector<uint32_t> mutate_idx;
+ SetDependency(&read_vars, &write_vars, &requested, &mutate_idx,
op, attrs, ctx, ndinputs, ndoutputs);
- FCompute fn;
- if (ctx.dev_mask() == cpu::kDevMask && fcpu.count(op)) {
- fn = fcpu[op];
- } else if (ctx.dev_mask() == gpu::kDevMask && fgpu.count(op)) {
- fn = fgpu[op];
- }
-
- if (fn) {
+ FCompute fn = common::GetFCompute<FCompute>(op, "FCompute", ctx);
+ FComputeEx fn_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", ctx);
+ if (fn_ex && stype != kDefaultStorage) {
if (AutogradRuntime::Get()->IsRecording()) {
AutogradRuntime::Get()->RecordImperativeFCompute(op,
attrs, &ndinputs, &ndoutputs);
}
- PushFCompute(fn, op, attrs, ctx, read_vars, write_vars,
+ PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars,
requested, ndinputs, ndoutputs);
+ } else if (fn) {
+ if (AutogradRuntime::Get()->IsRecording()) {
+ AutogradRuntime::Get()->RecordImperativeFCompute(op,
+ attrs, &ndinputs, &ndoutputs);
+ }
+ PushFCompute(fn, op, attrs, ctx, read_vars, write_vars,
+ requested, ndinputs, ndoutputs, mutate_idx);
} else if (createop.count(op)) {
auto state =
createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types);
@@ -411,7 +536,7 @@ void ImperativeInvokeImpl(const Context& default_ctx,
}
write_vars.push_back(state.get_var());
PushOperator(state, op, attrs, ctx, read_vars, write_vars,
- requested, ndinputs, ndoutputs);
+ requested, ndinputs, ndoutputs, mutate_idx);
} else {
LOG(FATAL)
<< "Operator " << op->name << " is not implemented for "
@@ -461,6 +586,28 @@ int MXImperativeInvoke(AtomicSymbolCreator creator,
API_END();
}
+int MXImperativeInvokeEx(AtomicSymbolCreator creator,
+ int num_inputs,
+ NDArrayHandle *inputs,
+ int *num_outputs,
+ NDArrayHandle **outputs,
+ int num_params,
+ const char **param_keys,
+ const char **param_vals,
+ const int **out_stypes) { // outputs storage types
+ API_BEGIN();
+ MXImperativeInvoke(creator, num_inputs, inputs, num_outputs, outputs,
+ num_params, param_keys, param_vals);
+ MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+ NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs);
+ ret->out_types.resize(*num_outputs);
+ for (int i = 0; i < *num_outputs; ++i) {
+ ret->out_types[i] = output_nds[i]->storage_type();
+ }
+ *out_stypes = dmlc::BeginPtr(ret->out_types);
+ API_END();
+}
+
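MXImperativeInvokeEx hands back `out_stypes` as a pointer into a thread-local buffer, so the array stays valid until the next C API call on the same thread. A standalone sketch of that staging pattern (names here are illustrative, not the MXNet API):

// Results are staged in thread-local storage; the caller must consume the
// pointer before issuing another call on the same thread.
#include <cstdio>
#include <vector>

const int* StageStorageTypes(const std::vector<int>& stypes, int* num) {
  thread_local std::vector<int> buf;  // outlives the call, one copy per thread
  buf = stypes;
  *num = static_cast<int>(buf.size());
  return buf.data();
}

int main() {
  int n = 0;
  const int* out = StageStorageTypes({0, 1, 2}, &n);
  for (int i = 0; i < n; ++i) std::printf("stype[%d] = %d\n", i, out[i]);
  return 0;
}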
int MXCreateCachedOp(SymbolHandle handle,
CachedOpHandle *out) {
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
@@ -540,6 +687,24 @@ int MXInvokeCachedOp(CachedOpHandle handle,
API_END();
}
+int MXInvokeCachedOpEx(CachedOpHandle handle,
+ int num_inputs,
+ NDArrayHandle *inputs,
+ int *num_outputs,
+ NDArrayHandle **outputs,
+ const int **out_stypes) { // outputs storage types
+ API_BEGIN();
+ MXInvokeCachedOp(handle, num_inputs, inputs, num_outputs, outputs);
+ MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+ NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs);
+ ret->out_types.resize(*num_outputs);
+ for (int i = 0; i < *num_outputs; ++i) {
+ ret->out_types[i] = output_nds[i]->storage_type();
+ }
+ *out_stypes = dmlc::BeginPtr(ret->out_types);
+ API_END();
+}
+
int MXAutogradIsTraining(bool* curr) {
API_BEGIN();
*curr = AutogradRuntime::Get()->IsTraining();
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index e2c29b888ada..d526aea0d35f 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -29,6 +29,7 @@
#include
#include "./c_api_common.h"
#include "../operator/operator_common.h"
+#include "../executor/exec_pass.h"
namespace mxnet {
namespace op {
@@ -459,7 +460,7 @@ int MXSymbolInferShape(SymbolHandle sym,
}
try {
- g = nnvm::pass::InferShape(std::move(g), arg_shapes, "__shape__");
+ g = mxnet::exec::InferShape(std::move(g), arg_shapes, "__shape__");
} catch (const mxnet::op::InferShapeError &err) {
throw dmlc::Error(err.msg);
}
@@ -544,7 +545,7 @@ int MXSymbolInferType(SymbolHandle sym,
mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType");
}
- g = nnvm::pass::InferType(std::move(g), arg_types, "__dtype__");
+ g = mxnet::exec::InferType(std::move(g), arg_types, "__dtype__");
// copy back
CopyAttr(g.indexed_graph(), g.GetAttr<nnvm::DTypeVector>("dtype"),
&(ret->arg_types), &(ret->out_types), &(ret->aux_types));
diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc
index 5ca01492800e..dda4fda1ed8f 100644
--- a/src/c_api/c_predict_api.cc
+++ b/src/c_api/c_predict_api.cc
@@ -32,6 +32,7 @@
#include
#include "./c_api_common.h"
#include "../operator/operator_common.h"
+#include "../executor/exec_pass.h"
using namespace mxnet;
@@ -194,7 +195,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
}
}
nnvm::Graph g; g.outputs = sym.outputs;
- g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__");
+ g = mxnet::exec::InferShape(std::move(g), in_shapes, "__shape__");
bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
CHECK(infer_complete)
<< "The shape information of is not enough to get the shapes";
diff --git a/src/common/utils.cc b/src/common/utils.cc
new file mode 100644
index 000000000000..125e4e5dc7d7
--- /dev/null
+++ b/src/common/utils.cc
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file utils.cc
+ * \brief cpu implementation of util functions
+ */
+
+#include "./utils.h"
+#include "../operator/tensor/cast_storage-inl.h"
+
+namespace mxnet {
+namespace common {
+
+template<>
+void CastStorageDispatch<cpu>(const OpContext& ctx,
+ const NDArray& input,
+ const NDArray& output) {
+ mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
+}
+
+} // namespace common
+} // namespace mxnet
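
utils.cc and utils.cu supply per-device explicit specializations of the CastStorageDispatch template declared in utils.h, so only the translation unit compiled for a device pulls in that device's cast kernels. A standalone sketch of the pattern, with empty tag structs standing in for mxnet::cpu and mxnet::gpu:

// Header declares the primary template; each .cc/.cu file defines the
// explicit specialization for its device.
#include <cstdio>

struct cpu {};
struct gpu {};

template <typename xpu>
void Dispatch(int value);  // declaration only, as in the header

template <>
void Dispatch<cpu>(int value) {  // definition as it would appear in the .cc file
  std::printf("cpu dispatch of %d\n", value);
}

int main() {
  Dispatch<cpu>(42);  // linking Dispatch<gpu> would require the .cu object file
  return 0;
}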
diff --git a/src/common/utils.cu b/src/common/utils.cu
new file mode 100644
index 000000000000..093480a98907
--- /dev/null
+++ b/src/common/utils.cu
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file utils.cu
+ * \brief gpu implementation of util functions
+ */
+
+#include "./utils.h"
+#include "../operator/tensor/cast_storage-inl.h"
+
+namespace mxnet {
+namespace common {
+
+template<>
+void CastStorageDispatch<gpu>(const OpContext& ctx,
+ const NDArray& input,
+ const NDArray& output) {
+ mxnet::op::CastStorageComputeImpl<gpu>(ctx, input, output);
+}
+
+} // namespace common
+} // namespace mxnet
diff --git a/src/common/utils.h b/src/common/utils.h
index 85e30970f1a0..92631a9b5c34 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -24,7 +24,14 @@
#ifndef MXNET_COMMON_UTILS_H_
#define MXNET_COMMON_UTILS_H_
-#if DMLC_USE_CXX11
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
#include
#include
#include
@@ -33,15 +40,100 @@
#include
#include
#include
-#endif // DMLC_USE_CXX11
-
-#include
-#include
+#include
namespace mxnet {
namespace common {
-#if DMLC_USE_CXX11
+template<typename xpu>
+void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);
+
+/*
+ * \brief setup default-storage tblobs from source NDArrays. If any source NDArray has non-default
+ * storage, it creates a temp NDArray with default storage and uses the temp tblob. The
+ * function also records the indices of non-default source NDArrays and the indices of
+ * their corresponding temporary NDArrays in the temp array.
+ * \param src list of source NDArray
+ * \param blobs list of tblobs to return
+ * \param temp_src list of source NDArrays which requires temporary default storage representation
+ * \param temp_dst list of temporary destination NDArrays for default storage representation
+ * \param idx_map mapping from indices in source NDArrays to indices in temp_dst. When not set,
+ * indices are not recorded
+ * \return true if any source NDArray needs a cast to default storage
+ */
+inline bool SetupDefaultBlobs(const std::vector<NDArray>& src,
+ std::vector<TBlob> *blobs,
+ std::vector<NDArray> *temp_src,
+ std::vector<NDArray> *temp_dst,
+ std::unordered_map<uint32_t, uint32_t> *idx_map = nullptr) {
+ bool require_cast = false;
+ for (size_t i = 0; i < src.size(); i++) {
+ auto& nd = src[i];
+ if (nd.storage_type() != kDefaultStorage) {
+ if (idx_map != nullptr) {
+ (*idx_map)[i] = temp_dst->size();
+ }
+ NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype());
+ temp_src->emplace_back(nd);
+ temp_dst->emplace_back(temp);
+ blobs->emplace_back(temp.data());
+ require_cast = true;
+ } else {
+ blobs->push_back(nd.data());
+ }
+ }
+ return require_cast;
+}
+
+/*
+ * \brief cast the NDArrays in `src` and store the result in NDArrays in `dst`.
+ * This is only used for storage fallback in executor.
+ * When `storage_fallback` is false and the environment variable
+ * `MXNET_EXEC_STORAGE_FALLBACK` is set to 0, storage fallback is disallowed.
+ * \param src list of source NDArrays to cast
+ * \param dst list of destination NDArrays which hold the results of the cast_storage operation
+ * \param ctx operator context for the cast_storage operation
+ * \param storage_fallback whether storage fallback is allowed; when set to false,
+ * the effective value is read from `MXNET_EXEC_STORAGE_FALLBACK`.
+ */
+template <typename xpu>
+inline void CastNonDefaultStorage(const std::vector<NDArray>& src,
+ const std::vector<NDArray>& dst,
+ const OpContext& ctx,
+ bool storage_fallback = false) {
+ CHECK_GE(dst.size(), src.size());
+ if (src.size() == 0) return;
+ if (storage_fallback == false) {
+ storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true);
+ }
+ if (storage_fallback == false) {
+ LOG(FATAL) << "Storage type conversion detected during execution. "
+ << "You are probably executing an operator which "
+ << "doesn't support NDArray inputs with non-default storage.";
+ }
+ for (size_t i = 0; i < src.size(); i++) {
+ CastStorageDispatch<xpu>(ctx, src[i], dst[i]);
+ }
+}
+
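CastNonDefaultStorage gates the fallback on the MXNET_EXEC_STORAGE_FALLBACK environment variable, defaulting to enabled. A standalone sketch of that gate; GetEnvFlag below is an illustrative stand-in for dmlc::GetEnv:

// Fallback is on unless MXNET_EXEC_STORAGE_FALLBACK is explicitly set to 0.
#include <cstdio>
#include <cstdlib>

bool GetEnvFlag(const char* name, bool dflt) {
  const char* v = std::getenv(name);
  return v == nullptr ? dflt : std::atoi(v) != 0;
}

int main() {
  bool allow = GetEnvFlag("MXNET_EXEC_STORAGE_FALLBACK", true);
  std::printf("storage fallback %s\n", allow ? "enabled" : "disabled");
  return 0;
}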
+// Check if any storage type is not default storage
+inline bool ContainsNonDefaultStorage(const StorageTypeVector& vstorage) {
+ for (const auto& i : vstorage) {
+ if (i != kUndefinedStorage && i != kDefaultStorage) return true;
+ }
+ return false;
+}
+
+// Check if any NDArray in the list has default storage
+inline bool ContainsDefaultStorage(const std::vector<NDArray>& ndarrays) {
+ for (const auto &nd : ndarrays) {
+ if (nd.storage_type() == kDefaultStorage) {
+ return true;
+ }
+ }
+ return false;
+}
+
// heuristic to determine number of threads per GPU
inline int GetNumThreadPerGPU() {
// This is resource efficient option.
@@ -56,6 +148,67 @@ inline int GetExecNumMatchColor() {
return std::min(num_match_color, GetNumThreadPerGPU());
}
+template<typename T, typename V>
+V ParallelAccumulate(const T* a, const int n, V start) {
+ V sum = start;
+#pragma omp parallel for reduction(+:sum)
+ for (int i = 0; i < n; ++i) {
+ sum += a[i];
+ }
+ return sum;
+}
+
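A usage sketch for ParallelAccumulate; the template is repeated here so the example compiles on its own (build with -fopenmp for the parallel reduction; it runs serially otherwise):

#include <cstdio>

template <typename T, typename V>
V ParallelAccumulate(const T* a, const int n, V start) {
  V sum = start;
#pragma omp parallel for reduction(+:sum)
  for (int i = 0; i < n; ++i) {
    sum += a[i];
  }
  return sum;
}

int main() {
  int data[5] = {1, 2, 3, 4, 5};
  std::printf("sum = %d\n", ParallelAccumulate(data, 5, 0));  // prints sum = 15
  return 0;
}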
+/*!
+ * \brief
+ * Helper function for ParallelSort.
+ * DO NOT call this function directly.
+ * Use the interface ParallelSort instead.
+ * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h
+ */
+template<typename RandomIt, typename Compare>
+void ParallelSortHelper(RandomIt first, size_t len,
+ size_t grainsize, const Compare& comp) {
+ if (len < grainsize) {
+ std::sort(first, first+len, comp);
+ } else {
+ std::thread thr(ParallelSortHelper<RandomIt, Compare>, first, len/2, grainsize, comp);
+ ParallelSortHelper(first+len/2, len - len/2, grainsize, comp);
+ thr.join();
+ std::inplace_merge(first, first+len/2, first+len, comp);
+ }
+}
+
+/*!
+ * \brief
+ * Sort the elements in the range [first, last) into the ascending order defined by
+ * the comparator comp.
+ * If the length of the range [first, last) is greater than a certain threshold,
+ * the range will be recursively divided in two, with a thread assigned
+ * to sort each half.
+ * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h
+ */
+template<typename RandomIt, typename Compare>
+void ParallelSort(RandomIt first, RandomIt last, size_t num_threads, Compare comp) {
+ const auto num = std::distance(first, last);
+ size_t grainsize = std::max(num / num_threads + 5, static_cast<size_t>(1024*16));
+ ParallelSortHelper(first, num, grainsize, comp);
+}
+
+/*!
+ * \brief
+ * Sort the elements in the range [first, last) into ascending order.
+ * The elements are compared using the default < operator.
+ * If the length of the range [first, last) is greater than a certain threshold,
+ * the range will be recursively divided in two, with a thread assigned
+ * to sort each half.
+ * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h
+ */
+template<typename RandomIt>
+void ParallelSort(RandomIt first, RandomIt last, size_t num_threads) {
+ ParallelSort(first, last, num_threads,
+ std::less<typename std::iterator_traits<RandomIt>::value_type>());
+}
+
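The divide-and-merge strategy behind ParallelSortHelper, reduced to a single split: sort each half on its own thread, then merge in place. A minimal standalone sketch (build with -pthread), not the recursive MXNet implementation:

#include <algorithm>
#include <cassert>
#include <thread>
#include <vector>

template <typename It>
void TwoThreadSort(It first, It last) {
  It mid = first + std::distance(first, last) / 2;
  std::thread t([=] { std::sort(first, mid); });  // left half on a new thread
  std::sort(mid, last);                           // right half on this thread
  t.join();
  std::inplace_merge(first, mid, last);           // combine the sorted halves
}

int main() {
  std::vector<int> v = {5, 3, 8, 1, 9, 2, 7, 4, 6, 0};
  TwoThreadSort(v.begin(), v.end());
  assert(std::is_sorted(v.begin(), v.end()));
  return 0;
}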
/*!
* \brief Random Engine
*/
@@ -159,8 +312,6 @@ FCompType GetFCompute(const nnvm::Op* op, const std::string& name,
}
}
-#endif // DMLC_USE_CXX11
-
} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 47b74758d702..fe8cc653bbc3 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -24,6 +24,7 @@
#include
#include
#include
+#include
#include
#include "../common/utils.h"
#include "./exec_pass.h"
@@ -40,33 +41,98 @@ const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs);
namespace exec {
-// forward executor
-class StatefulComputeExecutor : public OpExecutor {
+// abstract OpExecutor which provides storage fallback procedure on
+// non-default inputs and outputs
+// FComputeExecutor and FStatefulComputeExecutor inherit from this class
+class StorageFallbackOpExecutor : public OpExecutor {
public:
- void Run(RunContext rctx) override {
+ explicit StorageFallbackOpExecutor(const std::vector<uint32_t> &mutate_idx)
+ : mutate_idx_(mutate_idx) {}
+
+ void Setup() override {
+ init_ = false;
+ }
+
+ protected:
+ // initialize the data blobs
+ void InitBlobs() {
+ using namespace common;
if (!init_) {
- in_data_.clear();
- for (size_t i = 0; i < in_array.size(); ++i) {
- in_data_.push_back(in_array[i].data());
- }
- out_data_.clear();
- for (size_t i = 0; i < out_array.size(); ++i) {
- out_data_.push_back(out_array[i].data());
+ in_data_.clear(); out_data_.clear();
+ pre_temp_src_.clear(); pre_temp_dst_.clear();
+ post_temp_src_.clear(); post_temp_dst_.clear();
+ in_temp_idx_map_.clear();
+ SetupDefaultBlobs(in_array, &in_data_, &pre_temp_src_, &pre_temp_dst_, &in_temp_idx_map_);
+ SetupDefaultBlobs(out_array, &out_data_, &post_temp_dst_, &post_temp_src_);
+ for (const auto idx : mutate_idx_) {
+ auto map_iter = in_temp_idx_map_.find(idx);
+ if (map_iter != in_temp_idx_map_.end()) {
+ post_temp_src_.push_back(pre_temp_dst_[map_iter->second]);
+ post_temp_dst_.push_back(in_array[idx]);
+ }
}
init_ = true;
}
+ }
+
+ // storage fallback before fcompute is launched
+ void PreFCompute(bool is_gpu) {
+ using namespace common;
+ InitBlobs();
+ if (is_gpu) {
+#if MXNET_USE_CUDA
+ CastNonDefaultStorage<gpu>(pre_temp_src_, pre_temp_dst_, op_ctx);
+#else
+ LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+#endif
+ } else {
+ CastNonDefaultStorage<cpu>(pre_temp_src_, pre_temp_dst_, op_ctx);
+ }
+ }
+
+ // storage fallback after fcompute is completed
+ void PostFCompute(bool is_gpu) {
+ using namespace common;
+ if (is_gpu) {
+#if MXNET_USE_CUDA
+ CastNonDefaultStorage<gpu>(post_temp_src_, post_temp_dst_, op_ctx);
+#else
+ LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+#endif
+ } else {
+ CastNonDefaultStorage<cpu>(post_temp_src_, post_temp_dst_, op_ctx);
+ }
+ }
+
+ // default storage tensor blobs for fcompute
+ std::vector<TBlob> in_data_, out_data_;
+ // source NDArray for cast storage
+ std::vector<NDArray> pre_temp_src_, post_temp_src_;
+ // destination NDArray for cast storage
+ std::vector<NDArray> pre_temp_dst_, post_temp_dst_;
+ // mapping from index in input_blobs to index in pre_temp_dst
+ std::unordered_map<uint32_t, uint32_t> in_temp_idx_map_;
+ // indices of mutable inputs
+ std::vector<uint32_t> mutate_idx_;