Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

a bit more minor change #16

Merged
merged 1 commit into from
Aug 15, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 0 additions & 37 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,6 @@ typedef mshadow::index_t index_t;
/*! \brief data type that will be used to store ndarray */
typedef mshadow::default_real_t real_t;

/*! \brief context information about the execution environment (device type and device id) */
struct Context {
/*! \brief the device type we run the op on, can be cpu::kDevMask or gpu::kDevMask */
int dev_mask;
/*! \brief the device id we are going to run the op on */
int dev_id;
/*! \brief default constructor: selects CPU device 0 */
Context() : dev_mask(cpu::kDevMask), dev_id(0) {}
/*!
 * \brief construct a context for a specific device
 * \param dev_mask the device mask, cpu::kDevMask or gpu::kDevMask
 * \param dev_id the device id
 */
Context(int dev_mask, int dev_id)
: dev_mask(dev_mask), dev_id(dev_id) {}
/*!
 * \brief check if current context equals another one
 * \param b another context to compare
 * \return true when both the device mask and the device id are the same
 */
inline bool operator==(const Context &b) const {
return dev_mask == b.dev_mask && dev_id == b.dev_id;
}
};


/*!
 * \brief execution context: the information needed
 *  at runtime to actually execute the operation
 */
struct RunContext {
/*!
 * \brief the stream of the device, untyped here;
 *  can be NULL, or a Stream<gpu>* in GPU mode
 */
void *stream;
};

/*! \brief dynamic shape type */
typedef mshadow::TShape TShape;
/*! \brief storage container type */
Expand Down
80 changes: 80 additions & 0 deletions include/mxnet/context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*!
* Copyright (c) 2015 by Contributors
* \file context.h
* \brief Context information and resources in mxnet.
*/
#ifndef MXNET_CONTEXT_H_
#define MXNET_CONTEXT_H_

#include <cstddef>
namespace mxnet {

/*! \brief Context information about the execution environment (device type and device id) */
struct Context {
/*! \brief the device type we run the op on, can be cpu::kDevMask or gpu::kDevMask */
int dev_mask;
/*! \brief the device id we are going to run the op on */
int dev_id;
/*! \brief default constructor: selects CPU device 0 */
Context() : dev_mask(cpu::kDevMask), dev_id(0) {}
/*!
 * \brief construct a context for a specific device
 * \param dev_mask the device mask, cpu::kDevMask or gpu::kDevMask
 * \param dev_id the device id
 */
Context(int dev_mask, int dev_id)
: dev_mask(dev_mask), dev_id(dev_id) {}
/*!
 * \brief check if current context equals another one
 * \param b another context to compare
 * \return true when both the device mask and the device id are the same
 */
inline bool operator==(const Context &b) const {
return dev_mask == b.dev_mask && dev_id == b.dev_id;
}
};

/*!
 * \brief execution time context.
 *  Carries the information needed at runtime for the actual execution
 *  of an operation (see also OpContext, which is a superset of this).
 */
struct RunContext {
/*!
 * \brief the stream of the device, untyped here;
 *  can be NULL, or a Stream<gpu>* in GPU mode
 */
void *stream;
};

/*!
 * \brief Additional runtime resources an Operator can use,
 *  delivered through OpContext.requested.
 */
struct Resource {
  /*! \brief Resource type, indicating what the untyped pointer ptr actually points to */
  enum Type {
    /*! \brief a mshadow::Random<xpu> object */
    kRandom,
    /*! \brief a temporary workspace */
    kTempSpace
  };
  /*! \brief pointer to the resource; its concrete type is described by Type */
  void *ptr;
};

/*!
 * \brief A resource that can be requested by an Operator.
 * \sa Resource
 */
struct ResourceRequest {
  /*! \brief type of the requested resource */
  Resource::Type type;
  /*! \brief size requirement, only meaningful for a kTempSpace request */
  size_t space_size;
  /*!
   * \brief default constructor.
   *  Members are explicitly initialized (kRandom, size 0) so a
   *  default-constructed request never carries indeterminate values.
   */
  ResourceRequest() : type(Resource::kRandom), space_size(0) {}
  /*!
   * \brief constructor, allow implicit conversion from Resource::Type
   *  space_size is zero-initialized; set it separately for sized requests.
   * \param type type of resources
   */
  ResourceRequest(Resource::Type type) // NOLINT(*)
      : type(type), space_size(0) {}
};

} // namespace mxnet
#endif // MXNET_CONTEXT_H_
1 change: 1 addition & 0 deletions include/mxnet/dag_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <functional>
#include <vector>
#include "./base.h"
#include "./context.h"

namespace mxnet {
/*!
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/narray.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include <dmlc/logging.h>
#include <memory>
#include "./base.h"
#include "./context.h"
#include "./storage.h"
#include "./context.h"
#include "./dag_engine.h"
// check c++11
#if DMLC_USE_CXX11 == 0
Expand Down
83 changes: 67 additions & 16 deletions include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@
#include <string>
#include <utility>
#include "./base.h"
#include "./context.h"

namespace mxnet {
/*! \brief option to pass into the forward function */
struct Option {
/*! \brief whether it is training phase*/
int is_train;
};

/*! \brief operation request type to Forward and Backward */
enum OpReqType {
/*! \brief no operation, do not write anything */
Expand All @@ -36,6 +31,28 @@ enum OpReqType {
kAddTo
};

/*!
 * \brief All the possible information needed by Operator.Forward and Backward.
 *  This is the superset of RunContext.
 *  We use this data structure to bookkeep everything needed by Forward and Backward.
 * \sa Resource
 */
struct OpContext {
/*! \brief whether it is training phase (nonzero) or inference phase (zero) */
int is_train;
/*! \brief the stream we are running on, same meaning as RunContext::stream */
void *stream;
/*! \brief resources requested by the operator, in the order declared by
 *  OperatorProperty::ForwardResource / BackwardResource */
std::vector<Resource> requested;
/*!
 * \brief copy the RunContext related parts into this OpContext
 * \param ctx the runtime context to copy from
 */
inline void SetRunContext(const RunContext &ctx) {
stream = ctx.stream;
}
};

/*!
* \brief Operator interface.
* Operator defines the basic operation unit of an optimized computation graph in mxnet.
Expand All @@ -54,30 +71,28 @@ class Operator {
virtual ~Operator() {}
/*!
* \brief perform a forward operation of Operator, save the output to TBlob.
* \param opt option on Forward such as whether this is training phase.
* \param ctx runtime context
* \param ctx runtime context available to this call
* \param in_data array of input data, it is const
* \param req the request types of saving operation, can only be kWriteTo or kWriteInplace.
* \param out_data array of output data, pointer is used to indicate that this is holder
* the space of TBlob in out_data must be pre-allocated with InferShape
* \sa OpReqType
* \sa OpReqType, OpContext
*/
virtual void Forward(Option opt,
RunContext ctx,
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data) = 0;
/*!
* \brief Perform a backward Operation, write gradient to the in_grad.
* \param ctx runtime context
* \brief Perform a Backward Operation, write gradient to the in_grad.
* \param ctx runtime context available to this call
* \param out_grad the gradient value we get from output of the Operator
* \param in_data the array of input data.
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
* \param in_grad the array of gradient we need to write to.
* \sa OpReqType
* \sa OpReqType, OpContext
*/
virtual void Backward(RunContext ctx,
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
Expand Down Expand Up @@ -114,10 +129,25 @@ class OperatorProperty {
virtual std::vector<std::string> ListReturns() const {
return {"output"};
}
/*! \return number of outputs of the Operator */
/*!
 * \brief Get the total number of return values of the Operator,
 *  including any that are hidden during Symbol creation (see NumVisibleReturns).
 * \return number of real return values of the Operator
 */
virtual int NumReturns() const {
return 1;
}
/*!
 * \brief get the number of visible return values during Symbol creation.
 * If NumVisibleReturns() = k and NumReturns() = n,
 * the first k returns will be presented in the resulting symbol.
 *
 * The rest of the returns can be used as auxiliary states for Backward.
 * For example, Dropout will return [data, mask], with NumVisibleReturns() == 1.
 * So when the user calls sym = Dropout(input), only data is presented in sym.
 * But all the returns will be presented in the out_data parameter of Backward if requested.
 *
 * \return number of visible return values; defaults to NumReturns()
 */
virtual int NumVisibleReturns() const {
return NumReturns();
}
/*!
* \brief Set the parameters of the Operator.
* \param name parameter name
Expand Down Expand Up @@ -154,6 +184,27 @@ class OperatorProperty {
* subclasses override this function.
*/
virtual std::string TypeString() const = 0;
//--------------------------------------------------------
// All the below functions are optional to override.
//--------------------------------------------------------
/*!
 * \brief Declare the additional resources required by the forward pass.
 * These additional resources will be presented in OpContext.requested
 * in the same order as the returned requests.
 * \return additional resource requests; defaults to none
 */
virtual std::vector<ResourceRequest> ForwardResource() const {
return std::vector<ResourceRequest>();
}
/*!
 * \brief Declare the additional resources required by the backward pass.
 * These additional resources will be presented in OpContext.requested
 * in the same order as the returned requests.
 * \return additional resource requests; defaults to none
 */
virtual std::vector<ResourceRequest> BackwardResource() const {
return std::vector<ResourceRequest>();
}
/*!
* \brief Declare the input requirement of Backward pass.
*
Expand Down
1 change: 1 addition & 0 deletions include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef MXNET_STORAGE_H_
#define MXNET_STORAGE_H_
#include "./base.h"
#include "./context.h"

namespace mxnet {
/*! \brief memory allocator of storage */
Expand Down
1 change: 1 addition & 0 deletions src/narray/narray_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <dmlc/logging.h>
#include <mshadow/tensor.h>
#include <mxnet/base.h>
#include <mxnet/context.h>

namespace mxnet {
/*! \brief namespace to support all possible NArray operator */
Expand Down
5 changes: 2 additions & 3 deletions src/operator/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class FullyConnectedOp : public Operator {
this->param_ = p;
}

virtual void Forward(Option opt,
RunContext ctx,
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data) {
Expand All @@ -57,7 +56,7 @@ class FullyConnectedOp : public Operator {
}
}

virtual void Backward(RunContext ctx,
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
Expand Down
2 changes: 1 addition & 1 deletion src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ bool Symbol::InferShape(std::vector<TShape> *in_shape,
Symbol Symbol::Create(OperatorProperty *op) {
// use special representation for atomic symbol
auto node = std::make_shared<Node>(op, "");
size_t nret = op->NumReturns();
size_t nret = op->NumVisibleReturns();
Symbol s;
for (uint32_t i = 0; i < nret; ++i) {
s.heads_.push_back(DataEntry(node, i));
Expand Down