Commit 36b7e8f

Merge pull request #37 from antinucleon/narray

add narray scalar op

antinucleon committed Sep 3, 2015
2 parents 5644490 + 37b986e
Showing 7 changed files with 266 additions and 20 deletions.
2 changes: 1 addition & 1 deletion dmlc-core
79 changes: 78 additions & 1 deletion include/mxnet/narray.h
100644 → 100755
@@ -106,27 +106,55 @@ class NArray {
* \return reference of self
*/
NArray &operator+=(const NArray &src);
/*!
* \brief elementwise add to current space
* this mutates the current NArray
* \param src the data to add
* \return reference of self
*/
NArray &operator+=(const real_t &src);
/*!
* \brief elementwise subtract from current narray
* this mutates the current NArray
* \param src the data to subtract
* \return reference of self
*/
NArray &operator-=(const NArray &src);
/*!
* \brief elementwise subtract from current narray
* this mutates the current NArray
* \param src the data to subtract
* \return reference of self
*/
NArray &operator-=(const real_t &src);
/*!
* \brief elementwise multiplication to current narray
* this mutates the current NArray
* \param src the data to multiply by
* \return reference of self
*/
NArray &operator*=(const NArray &src);
/*!
* \brief elementwise multiplication to current narray
* this mutates the current NArray
* \param src the data to multiply by
* \return reference of self
*/
NArray &operator*=(const real_t &src);
/*!
* \brief elementwise division from current narray
* this mutates the current NArray
* \param src the data to divide by
* \return reference of self
*/
NArray &operator/=(const NArray &src);
/*!
* \brief elementwise division from current narray
* this mutates the current NArray
* \param src the data to divide by
* \return reference of self
*/
NArray &operator/=(const real_t &src);
/*!
* \brief return transpose of current NArray
* \return a new transposed NArray
@@ -241,6 +269,8 @@ class NArray {
friend void BinaryOp(const NArray &lhs, const NArray &rhs, NArray *out);
template<typename OP>
friend void UnaryOp(const NArray &lhs, const NArray &rhs, NArray *out);
template<typename OP>
friend void ScalarOp(const NArray &lhs, const real_t &rhs, NArray *out);
};

/*!
@@ -262,27 +292,55 @@ void CopyFromTo(const NArray &from, NArray *to);
* \return a new result narray
*/
NArray operator+(const NArray &lhs, const NArray &rhs);
/*!
* \brief elementwise add
* \param lhs left operand
* \param rhs right operand
* \return a new result narray
*/
NArray operator+(const NArray &lhs, const real_t &rhs);
/*!
* \brief elementwise subtraction
* \param lhs left operand
* \param rhs right operand
* \return a new result narray
*/
NArray operator-(const NArray &lhs, const NArray &rhs);
/*!
* \brief elementwise subtraction
* \param lhs left operand
* \param rhs right operand
* \return a new result narray
*/
NArray operator-(const NArray &lhs, const real_t &rhs);
/*!
* \brief elementwise multiplication
* \param lhs left operand
* \param rhs right operand
* \return a new result narray
*/
NArray operator*(const NArray &lhs, const NArray &rhs);
/*!
* \brief elementwise multiplication
* \param lhs left operand
* \param rhs right operand
* \return a new result narray
*/
NArray operator*(const NArray &lhs, const real_t &rhs);
/*!
* \brief elementwise division
* \param lhs left operand
* \param rhs right operand
* \return a new result narray
*/
NArray operator/(const NArray &lhs, const NArray &rhs);
/*!
* \brief elementwise division
* \param lhs left operand
* \param rhs right operand
* \return a new result narray
*/
NArray operator/(const NArray &lhs, const real_t &rhs);

//--------------------------------------------------------------
// The following part is the API registration of NArray functions.
@@ -346,6 +404,25 @@ struct NArrayFunctionReg
this->add_argument("rhs", "NArray", "Right operand to the function.");
return *this;
}
/*!
* \brief set the function body to an NArray-scalar function
* this will also auto set the parameters correctly
* \param fscalar function body to set
* \return ref to the registered entry, used to set properties
*/
inline NArrayFunctionReg &set_function(void fscalar(const NArray &lhs,
const real_t &rhs,
NArray *out)) {
body = [fscalar] (NArray **used_vars,
real_t *s, NArray **mutate_vars) {
fscalar(*used_vars[0], s[0], mutate_vars[0]);
};
num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget;
this->add_argument("lhs", "NArray", "Left operand to the function.");
this->add_argument("rhs", "real_t", "Right operand to the function.");
return *this;
}
/*!
* \brief set the function body to a unary NArray function
* this will also auto set the parameters correctly
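For orientation (an editorial note, not part of the diff): the declarations above give NArray a scalar counterpart to every elementwise operator. A minimal C++ usage sketch, assuming `a` is an NArray that has already been allocated on some context:

#include <mxnet/narray.h>
using namespace mxnet;

// Sketch only: `a` is assumed to be allocated elsewhere.
void ScalarArithmeticDemo(NArray a) {
  NArray b = a + 1.0f;  // operator+(const NArray&, const real_t&) returns a new NArray
  NArray c = b * 0.5f;  // operator*(const NArray&, const real_t&)
  a -= 2.0f;            // NArray::operator-=(const real_t&) mutates in place
  a /= 4.0f;            // NArray::operator/=(const real_t&)
  (void)c;              // all of these are queued asynchronously on the DAGEngine
}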
14 changes: 14 additions & 0 deletions python/mxnet/narray.py
100644 → 100755
@@ -66,6 +66,8 @@ def __del__(self):
def __add__(self, other):
if isinstance(other, NArray):
return NArray._plus(self, other)
elif isinstance(other, float) or isinstance(other, int):
return NArray._plus_scalar(self, float(other))
else:
raise TypeError('type %s not supported' % str(type(other)))

@@ -75,12 +77,16 @@ def __radd__(self, other):
def __sub__(self, other):
if isinstance(other, NArray):
return NArray._minus(self, other)
elif isinstance(other, float) or isinstance(other, int):
return NArray._minus_scalar(self, float(other))
else:
raise TypeError('type %s not supported' % str(type(other)))

def __mul__(self, other):
if isinstance(other, NArray):
return NArray._mul(self, other)
elif isinstance(other, float) or isinstance(other, int):
return NArray._mul_scalar(self, float(other))
else:
raise TypeError('type %s not supported' % str(type(other)))

@@ -90,9 +96,17 @@ def __rmul__(self, other):
def __div__(self, other):
if isinstance(other, NArray):
return NArray._div(self, other)
elif isinstance(other, float) or isinstance(other, int):
return NArray._div_scalar(self, float(other))
else:
raise TypeError('type %s not supported' % str(type(other)))

def __idiv__(self, other):
return self.__div__(other)

def __truediv__(self, other):
return self.__div__(other)

def __getstate__(self):
this = self.__dict__.copy()
handle = this['handle']
146 changes: 133 additions & 13 deletions src/narray/narray.cc
100644 → 100755
@@ -38,21 +38,95 @@ inline void BinaryOp(const NArray &lhs,
NArray ret = *out;
// redirect everything to mshadow operations
switch (lhs.ctx().dev_mask) {
case cpu::kDevMask:
DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) {
ret.ptr_->CheckAndAlloc();
TBlob tmp = ret.data();
narray::Eval<cpu, OP>(lhs.data(), rhs.data(), &tmp, ctx);
}, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var});
case cpu::kDevMask: {
auto func = [lhs, rhs, ret](RunContext ctx) {
ret.ptr_->CheckAndAlloc();
TBlob tmp = ret.data();
narray::Eval<cpu, OP>(lhs.data(), rhs.data(), &tmp, ctx);
};
if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var});
} else if (lhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var});
} else if (rhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var});
} else {
DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var});
}
break;
}
#if MXNET_USE_CUDA
case gpu::kDevMask:
DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) {
ret.ptr_->CheckAndAlloc();
TBlob tmp = ret.data();
narray::Eval<gpu, OP>(lhs.data(), rhs.data(), &tmp, ctx);
}, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var});
case gpu::kDevMask: {
auto func = [lhs, rhs, ret](RunContext ctx) {
ret.ptr_->CheckAndAlloc();
TBlob tmp = ret.data();
narray::Eval<gpu, OP>(lhs.data(), rhs.data(), &tmp, ctx);
};
if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var});
} else if (lhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var});
} else if (rhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var});
} else {
DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var});
}
break;
}
#endif
default: LOG(FATAL) << "GPU is not enabled";
}
}
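Editorial note, not part of the patch: the four-way branch above is the behavioral change to BinaryOp in this commit. The in-place operators reuse an input as the output, and a variable in the mutate list must not also appear as a read dependency, otherwise the DAGEngine would be told to wait on the very write this operation performs. A hypothetical helper expressing the same rule (assuming DAGEngine::Variable is the engine's variable handle type):

#include <vector>

// Hypothetical, for illustration only: read dependencies are the inputs
// minus any variable that is also the mutated output.
std::vector<DAGEngine::Variable> ReadDeps(DAGEngine::Variable lhs,
                                          DAGEngine::Variable rhs,
                                          DAGEngine::Variable out) {
  std::vector<DAGEngine::Variable> deps;
  if (lhs != out) deps.push_back(lhs);
  if (rhs != out) deps.push_back(rhs);
  return deps;
}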

/*!
* \brief run an elementwise NArray-scalar operation
* \param lhs left operand
* \param rhs right scalar operand
* \param out the output narray
*/
template<typename OP>
inline void ScalarOp(const NArray &lhs,
const real_t &rhs,
NArray *out) {
if (out->is_none()) {
*out = NArray(OP::GetShape(lhs.shape(), lhs.shape()), lhs.ctx(), true);
} else {
CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
CHECK(out->shape() == OP::GetShape(lhs.shape(), lhs.shape()))
<< "target shape mismatch";
}
// important: callback must always capture by value
NArray ret = *out;
// redirect everything to mshadow operations
switch (lhs.ctx().dev_mask) {
case cpu::kDevMask: {
auto func = [lhs, rhs, ret](RunContext ctx) {
ret.ptr_->CheckAndAlloc();
TBlob tmp = ret.data();
narray::Eval<cpu, OP>(lhs.data(), rhs, &tmp, ctx);
};
if (lhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var});
} else {
DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var});
}
break;
}
#if MXNET_USE_CUDA
case gpu::kDevMask: {
auto func = [lhs, rhs, ret](RunContext ctx) {
ret.ptr_->CheckAndAlloc();
TBlob tmp = ret.data();
narray::Eval<gpu, OP>(lhs.data(), rhs, &tmp, ctx);
};
if (lhs.ptr_->var == ret.ptr_->var) {
DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var});
} else {
DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var});
}
break;
}
#endif
default: LOG(FATAL) << "GPU is not enabled";
}
@@ -120,13 +194,28 @@ inline NArray BinaryOpRet(const NArray &lhs,
return ret;
}

template<typename OP>
inline NArray ScalarOpRet(const NArray &lhs,
const real_t &rhs) {
NArray ret;
ScalarOp<OP>(lhs, rhs, &ret);
return ret;
}

template<typename OP>
inline NArray &BinaryOpApply(NArray *dst,
const NArray &src) {
BinaryOp<OP>(*dst, src, dst);
return *dst;
}

template<typename OP>
inline NArray &ScalarOpApply(NArray *dst,
const real_t &src) {
ScalarOp<OP>(*dst, src, dst);
return *dst;
}
// Binary
NArray operator+(const NArray &lhs, const NArray &rhs) {
return BinaryOpRet<narray::Plus>(lhs, rhs);
}
@@ -139,7 +228,20 @@ NArray operator*(const NArray &lhs, const NArray &rhs) {
NArray operator/(const NArray &lhs, const NArray &rhs) {
return BinaryOpRet<narray::Div>(lhs, rhs);
}

// Scalar
NArray operator+(const NArray &lhs, const real_t &rhs) {
return ScalarOpRet<narray::Plus>(lhs, rhs);
}
NArray operator-(const NArray &lhs, const real_t &rhs) {
return ScalarOpRet<narray::Minus>(lhs, rhs);
}
NArray operator*(const NArray &lhs, const real_t &rhs) {
return ScalarOpRet<narray::Mul>(lhs, rhs);
}
NArray operator/(const NArray &lhs, const real_t &rhs) {
return ScalarOpRet<narray::Div>(lhs, rhs);
}
// Binary
NArray &NArray::operator+=(const NArray &src) {
return BinaryOpApply<narray::Plus>(this, src);
}
@@ -152,6 +254,19 @@ NArray &NArray::operator*=(const NArray &src) {
NArray &NArray::operator/=(const NArray &src) {
return BinaryOpApply<narray::Div>(this, src);
}
// Scalar
NArray &NArray::operator+=(const real_t &src) {
return ScalarOpApply<narray::Plus>(this, src);
}
NArray &NArray::operator-=(const real_t &src) {
return ScalarOpApply<narray::Minus>(this, src);
}
NArray &NArray::operator*=(const real_t &src) {
return ScalarOpApply<narray::Mul>(this, src);
}
NArray &NArray::operator/=(const real_t &src) {
return ScalarOpApply<narray::Div>(this, src);
}

void NArray::Save(dmlc::Stream *strm) const {
// save shape
@@ -223,6 +338,11 @@ MXNET_REGISTER_NARRAY_FUN(_minus).set_function(BinaryOp<narray::Minus>);
MXNET_REGISTER_NARRAY_FUN(_mul).set_function(BinaryOp<narray::Mul>);
MXNET_REGISTER_NARRAY_FUN(_div).set_function(BinaryOp<narray::Div>);

// scalar functions
MXNET_REGISTER_NARRAY_FUN(_plus_scalar).set_function(ScalarOp<narray::Plus>);
MXNET_REGISTER_NARRAY_FUN(_minus_scalar).set_function(ScalarOp<narray::Minus>);
MXNET_REGISTER_NARRAY_FUN(_mul_scalar).set_function(ScalarOp<narray::Mul>);
MXNET_REGISTER_NARRAY_FUN(_div_scalar).set_function(ScalarOp<narray::Div>);
// copy function is special
// that we need to remove kAcceptEmptyMutateTarget from it
MXNET_REGISTER_NARRAY_FUN(_copyto)
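Editorial note, not part of the commit: the _plus_scalar family registered above is what the new Python operators in python/mxnet/narray.py resolve to. A hedged sketch of that lookup path, assuming dmlc::Registry<NArrayFunctionReg>::Find is available and relying on the body(used_vars, scalars, mutate_vars) signature set up in include/mxnet/narray.h:

#include <dmlc/registry.h>
#include <mxnet/narray.h>

// Sketch: invoke a registered NArray-scalar function by name, roughly
// what a frontend does through the C API. A real caller would null-check reg.
mxnet::NArray PlusScalarByName(const mxnet::NArray &lhs, mxnet::real_t rhs) {
  auto *reg = ::dmlc::Registry<mxnet::NArrayFunctionReg>::Find("_plus_scalar");
  mxnet::NArray out;  // an empty target is accepted: kAcceptEmptyMutateTarget
  mxnet::NArray *used[] = {const_cast<mxnet::NArray *>(&lhs)};
  mxnet::NArray *mutated[] = {&out};
  mxnet::real_t scalars[] = {rhs};
  reg->body(used, scalars, mutated);
  return out;
}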