fix failed tests. add back 64bit support for dot
fix lint
eric-haibin-lin committed Jul 19, 2017
1 parent ea2d74f commit 4b62c8b
Showing 14 changed files with 255 additions and 363 deletions.
2 changes: 1 addition & 1 deletion dmlc-core
1 change: 1 addition & 0 deletions include/mxnet/op_attr_types.h
@@ -16,6 +16,7 @@
 #include "./base.h"
 #include "./ndarray.h"
 #include "./engine.h"
+#include "./resource.h"

 namespace mxnet {

4 changes: 2 additions & 2 deletions python/mxnet/kvstore.py
@@ -6,7 +6,7 @@
 import pickle
 from .ndarray import NDArray
 from .base import _LIB
-from .base import check_call, c_array, c_str, string_types, mx_uint, py_str
+from .base import check_call, c_array, c_str, string_types, mx_uint, py_str, integer_types
 from .base import NDArrayHandle, KVStoreHandle
 from . import optimizer as opt

@@ -16,7 +16,7 @@ def _ctype_key_value(keys, vals):
     c_keys = []
     c_vals = []
     for key, val in zip(keys, vals):
-        c_key_i, c_val_i = _ctype_str_key_value(key, val)
+        c_key_i, c_val_i = _ctype_key_value(key, val)
         c_keys += c_key_i
         c_vals += c_val_i
     return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals))
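The kvstore change above makes the key/value flattening helper recurse through the generic _ctype_key_value for every (key, value) pair, so integer and string keys share one code path; integer_types is imported so the single-key branch can test for integer keys. A minimal, self-contained sketch of that recursion pattern — illustrative only: plain Python values stand in for NDArray handles, and the single-key behaviour shown here is an assumption, not the actual kvstore code.

integer_types = (int,)          # assumption: plays the role of mxnet.base.integer_types

def ctype_key_value(keys, vals):
    # list of keys: flatten each (key, value) pair with the same generic helper
    if isinstance(keys, (tuple, list)):
        assert len(keys) == len(vals)
        c_keys, c_vals = [], []
        for key, val in zip(keys, vals):
            c_key_i, c_val_i = ctype_key_value(key, val)   # recurse, as in the fixed line
            c_keys += c_key_i
            c_vals += c_val_i
        return c_keys, c_vals
    # single key: accept an integer or string key; values may be one item or a list
    key = str(keys) if isinstance(keys, integer_types) else keys
    if isinstance(vals, (tuple, list)):
        return [key] * len(vals), list(vals)
    return [key], [vals]

print(ctype_key_value([1, "weight"], [[0.1, 0.2], 0.3]))
# -> (['1', '1', 'weight'], [0.1, 0.2, 0.3])
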
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/__init__.py
@@ -2,7 +2,7 @@

 from . import _internal
 from . import op
-from .op import CachedOp, invoke
+from .op import CachedOp
 from .ndarray import NDArray, array, concatenate, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
 from .ndarray import empty, ones, add, arange, divide, equal, full, greater, greater_equal, imdecode
 from .ndarray import lesser, lesser_equal, maximum, minimum, moveaxis, multiply, negative, not_equal
7 changes: 5 additions & 2 deletions python/mxnet/ndarray/ndarray.py
@@ -14,14 +14,17 @@
 import warnings
 import operator
 import numpy as np
-from ..base import _LIB, numeric_types
+from ..base import _LIB, numeric_types, integer_types
 from ..base import c_array, mx_real_t
 from ..base import mx_uint, NDArrayHandle, check_call
 from ..base import ctypes2buffer
 from ..context import Context
 from . import _internal
 from .op import NDArrayBase, _STORAGE_TYPE_ID_TO_STR
-from . import *
+from . import broadcast_add, broadcast_mul, transpose, broadcast_not_equal, broadcast_power
+from . import broadcast_sub, broadcast_div, broadcast_to, broadcast_equal, cast_storage
+from . import broadcast_greater, broadcast_greater_equal, broadcast_lesser, broadcast_lesser_equal
+from . import zeros_like, slice

 # pylint: disable= no-member
 _DTYPE_NP_TO_MX = {
8 changes: 4 additions & 4 deletions python/mxnet/ndarray/op.py
@@ -13,21 +13,21 @@
 try:
     if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
         from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR
-        from .._ctypes.ndarray import invoke, CachedOp, _imperative_invoke
+        from .._ctypes.ndarray import CachedOp, _imperative_invoke
     elif _sys.version_info >= (3, 0):
         from .._cy3.ndarray import NDArrayBase, _set_ndarray_class,\
             _imperative_invoke, _STORAGE_TYPE_ID_TO_STR
-        from .._cy3.ndarray import invoke, CachedOp, _imperative_invoke
+        from .._cy3.ndarray import CachedOp, _imperative_invoke
     else:
         from .._cy2.ndarray import NDArrayBase, _set_ndarray_class,\
             _imperative_invoke, _STORAGE_TYPE_ID_TO_STR
-        from .._cy2.ndarray import invoke, CachedOp, _imperative_invoke
+        from .._cy2.ndarray import CachedOp, _imperative_invoke
 except ImportError:
     if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
         raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
     from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class,\
         _imperative_invoke, _STORAGE_TYPE_ID_TO_STR
-    from .._ctypes.ndarray import invoke, CachedOp, _imperative_invoke
+    from .._ctypes.ndarray import CachedOp, _imperative_invoke

 from ..base import mx_uint, check_call, _LIB, py_str, OpHandle, c_str, _Null
 # pylint: enable=unused-import
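The op.py hunk keeps the existing backend-selection pattern while dropping the unused invoke import: prefer the compiled Cython module unless MXNET_ENABLE_CYTHON=0, and fall back to the ctypes implementation on ImportError unless MXNET_ENFORCE_CYTHON is set. A self-contained sketch of that pattern, where cPickle/pickle merely stand in for the Cython/ctypes ndarray modules:

import os

try:
    if int(os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
        raise ImportError("compiled backend disabled by MXNET_ENABLE_CYTHON=0")
    import cPickle as backend       # stands in for the compiled .._cy2/.._cy3 module
except ImportError:
    if int(os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
        raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
    import pickle as backend        # stands in for the portable .._ctypes module

print("using", backend.__name__)
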
27 changes: 15 additions & 12 deletions src/c_api/c_api_ndarray.cc
@@ -199,6 +199,7 @@ void SetShapeType(const nnvm::Op* op,
   int kNonDefaultStorage = -2;
   *dispatch_stype = contains_non_default ? kNonDefaultStorage : kDefaultStorage;
   for (size_t i = 0; i < ndoutputs.size(); ++i) {
+    NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
     if (ndoutputs[i].is_none()) {
       // if failed to infer the storage type, assume the output storage is dense
       if (storage_type == kDefaultStorage || out_storage_types[i] == kUndefinedStorage) {
@@ -349,6 +350,7 @@ void PushOperator(const OpStatePtr& state,
                   const std::vector<Resource>& requested,
                   const std::vector<NDArray>& ndinputs,
                   const std::vector<NDArray>& ndoutputs) {
+  using namespace common;
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");

   bool is_train = AutogradRuntime::Get()->IsTraining();
@@ -367,7 +369,7 @@
         OpContext opctx{is_train, rctx, on_complete, requested};
         std::vector<TBlob> input_blobs, output_blobs;
         std::vector<NDArray> temp_in, temp_out;
-        if (ctx.dev_mask() == gpu::kDevMask) {
+        if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
 #if MXNET_USE_CUDA
           GetDefaultBlobs<gpu>(ndinputs, &input_blobs, &temp_in, opctx);
           GetDefaultBlobs<gpu>(ndoutputs, &output_blobs, &temp_out, opctx);
@@ -425,8 +427,6 @@ void ImperativeInvokeImpl(const Context& default_ctx,
                           const nnvm::NodeAttrs& attrs,
                           std::vector<NDArray>* p_ndinputs,
                           std::vector<NDArray>* p_ndoutputs) {
-  static auto& fcpu = nnvm::Op::GetAttr<FCompute>("FCompute<cpu>");
-  static auto& fgpu = nnvm::Op::GetAttr<FCompute>("FCompute<gpu>");
   static auto& ndfunc = nnvm::Op::GetAttr<FNDArrayFunction>("FNDArrayFunction");
   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
   MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
@@ -441,23 +441,26 @@
   } else {
     // TODO(piiswrong): infer ctx
     Context ctx;
+    int stype;
     SetContext(&ctx, attrs, ndinputs, ndoutputs, default_ctx);
-    SetShapeType(op, attrs, ctx, ndinputs, &ndoutputs);
+    SetShapeType(op, attrs, ctx, ndinputs, &ndoutputs, &stype);

     std::vector<engine::VarHandle> read_vars, write_vars;
     std::vector<Resource> requested;
     std::vector<uint32_t> auxidx;
     SetDependency(&read_vars, &write_vars, &requested, &auxidx,
                   op, attrs, ctx, ndinputs, ndoutputs);

-    FCompute fn;
-    if (ctx.dev_mask() == cpu::kDevMask && fcpu.count(op)) {
-      fn = fcpu[op];
-    } else if (ctx.dev_mask() == gpu::kDevMask && fgpu.count(op)) {
-      fn = fgpu[op];
-    }
-
-    if (fn) {
+    FCompute fn = common::GetFCompute<FCompute>(op, "FCompute", ctx);
+    FComputeEx fn_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", ctx);
+    if (fn_ex && stype != kDefaultStorage) {
+      if (AutogradRuntime::Get()->IsTraining()) {
+        AutogradRuntime::Get()->RecordImperativeFCompute(op,
+            attrs, &ndinputs, &ndoutputs);
+      }
+      PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars,
+                     requested, ndinputs, ndoutputs);
+    } else if (fn) {
       if (AutogradRuntime::Get()->IsTraining()) {
         AutogradRuntime::Get()->RecordImperativeFCompute(op,
             attrs, &ndinputs, &ndoutputs);
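For context on the c_api_ndarray.cc hunk above: ImperativeInvokeImpl now looks up both a dense FCompute kernel and a storage-aware FComputeEx kernel for the operator, and routes to FComputeEx only when one is registered and the inferred storage type is non-default, falling back to the dense path otherwise. A toy sketch of that dispatch decision — illustrative only; the registries, kernel names, and enum values below are made up, not MXNet API.

kDefaultStorage, kRowSparseStorage = 1, 2      # assumption: toy stand-ins for the C++ enum

fcompute_registry = {("dot", "cpu"): "dot_dense_kernel"}      # fake FCompute registry
fcompute_ex_registry = {("dot", "cpu"): "dot_sparse_kernel"}  # fake FComputeEx registry

def get_fcompute(registry, op, ctx):
    return registry.get((op, ctx))

def dispatch(op, ctx, storage_type):
    fn = get_fcompute(fcompute_registry, op, ctx)        # dense kernel, may be None
    fn_ex = get_fcompute(fcompute_ex_registry, op, ctx)  # storage-aware kernel, may be None
    if fn_ex and storage_type != kDefaultStorage:
        return fn_ex     # prefer FComputeEx when the storage type is non-default
    if fn:
        return fn        # otherwise fall back to the dense FCompute path
    raise RuntimeError("no kernel registered for %s on %s" % (op, ctx))

print(dispatch("dot", "cpu", kRowSparseStorage))   # -> dot_sparse_kernel
print(dispatch("dot", "cpu", kDefaultStorage))     # -> dot_dense_kernel
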
9 changes: 4 additions & 5 deletions src/executor/attach_op_execs_pass.cc
@@ -50,7 +50,6 @@ class StatefulComputeExecutor : public OpExecutor {
       CastNonDefaultStorage<cpu>(out_array, temp_out_, op_ctx);
     }
 #if MKL_EXPERIMENTAL == 1
-    //TODO(haibin) handle MKL mem with non-default NDArray
     mkl_tblobs_prv_to_cpu(in_data_);
     mkl_tblobs_prv_to_cpu(out_data_);
 #endif
@@ -84,7 +83,7 @@
 // stateful compute_ex executor
 class StatefulComputeExExecutor : public OpExecutor {
  public:
-  void Run(RunContext rctx) override {
+  void Run(RunContext rctx, bool is_gpu) override {
     op_ctx.run_ctx = rctx;
     fcompute_(state_, op_ctx, in_array, req, out_array);
   }
@@ -115,7 +114,7 @@
 // fcompute executor
 class FComputeExecutor : public OpExecutor {
  public:
-  void Run(RunContext rctx) override {
+  void Run(RunContext rctx, bool is_gpu) override {
     using namespace common;
     // TODO(haibin) avoid repeating this if all inputs are already in default-storage
     op_ctx.run_ctx = rctx;
@@ -139,7 +138,7 @@
       CastNonDefaultStorage<cpu>(out_array, temp_out_, op_ctx);
     }
 #if MKL_EXPERIMENTAL == 1
-    //TODO(haibin) handle MKL mem with non-default NDArray
+    // TODO(haibin) handle MKL mem with non-default NDArray
     mkl_tblobs_prv_to_cpu(in_data_);
     mkl_tblobs_prv_to_cpu(out_data_);
 #endif
@@ -167,7 +166,7 @@
 // fcompute_ex executor
 class FComputeExExecutor : public OpExecutor {
  public:
-  void Run(RunContext rctx) override {
+  void Run(RunContext rctx, bool is_gpu) override {
     op_ctx.run_ctx = rctx;
     fcompute_(attrs_, op_ctx, in_array, req, out_array);
   }
3 changes: 2 additions & 1 deletion src/executor/graph_executor.cc
@@ -1433,7 +1433,8 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
     } else if (opnode.exec->exec_type() == ExecType::kLocal) {
-      opnode.exec->Run(RunContext{opnode.ctx, nullptr});
+      bool is_gpu = opnode.ctx.dev_mask() == gpu::kDevMask;
+      opnode.exec->Run(RunContext{opnode.ctx, nullptr}, is_gpu);
     } else if (opnode.cached_opr != nullptr) {
 #if MXNET_USE_PROFILER
       bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning;
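The last two files change the OpExecutor Run method to take an explicit is_gpu flag that the graph executor computes from the node's device mask, so each Run override can pick the device-specific path at run time. A tiny sketch of that calling convention — the names below are stand-ins, not MXNet API.

gpu_dev_mask = 2   # assumption: plays the role of gpu::kDevMask

class ToyOpExecutor:
    def run(self, run_ctx, is_gpu):
        # choose the device-specific cast/copy path based on the flag from the caller
        return "gpu path" if is_gpu else "cpu path"

class ToyNodeContext:
    def __init__(self, dev_mask):
        self.dev_mask = dev_mask

opnode_ctx = ToyNodeContext(dev_mask=gpu_dev_mask)
is_gpu = opnode_ctx.dev_mask == gpu_dev_mask   # mirrors: opnode.ctx.dev_mask() == gpu::kDevMask
print(ToyOpExecutor().run(run_ctx=None, is_gpu=is_gpu))   # -> gpu path
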