This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Merge pull request #27 from mavenlin/master
symbolic autodiff
antinucleon committed Aug 23, 2015
2 parents 8d95f52 + 1e0459c commit 6812764
Showing 10 changed files with 169 additions and 53 deletions.
9 changes: 8 additions & 1 deletion Makefile
@@ -24,7 +24,14 @@ include $(DMLC_CORE)/make/dmlc.mk
# all the possible warning treatments
WARNFLAGS= -Wall
CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS)
CFLAGS += -g -O3 -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)

# CFLAGS for debug
ifeq ($(DEBUG),0)
CFLAGS += -O3
else
CFLAGS += -g -O0
endif
CFLAGS += -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
NVCCFLAGS = --use_fast_math -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
ROOTDIR = $(CURDIR)
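Taken together with the DEBUG flag that make/config.mk now defaults to 0 (see that file further down), this presumably means a release build keeps -O3, while setting DEBUG = 1 in the config (or overriding it on the make command line) switches the compile flags to -g -O0.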
13 changes: 13 additions & 0 deletions include/mxnet/c_api.h
@@ -347,6 +347,19 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
mx_uint num_args,
const char** keys,
SymbolHandle* args);
/*!
* \brief Get the gradient graph of the symbol
*
* \param sym the symbol to take the gradient of
* \param num_wrt number of arguments to take the gradient with respect to
* \param wrt the names of the arguments to take the gradient with respect to
* \param out the returned symbol that contains the gradient graph
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
mx_uint num_wrt,
const char** wrt,
SymbolHandle* out);
/*!
* \brief infer shape of unknown input shapes given the known one.
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
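The new MXSymbolGrad entry point above maps directly onto a ctypes call. Below is a minimal sketch, assuming lib is an already-loaded libmxnet and sym an existing SymbolHandle (both hypothetical names); it simply mirrors the Python binding this same commit adds to python/mxnet/symbol.py.

import ctypes

wrt = [b'data', b'weight']                   # argument names to differentiate with respect to
c_wrt = (ctypes.c_char_p * len(wrt))(*wrt)   # const char** wrt
out = ctypes.c_void_p()                      # receives the new SymbolHandle
ret = lib.MXSymbolGrad(sym, ctypes.c_uint(len(wrt)), c_wrt, ctypes.byref(out))
assert ret == 0                              # 0 on success, -1 on failure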
11 changes: 9 additions & 2 deletions include/mxnet/symbolic.h
@@ -3,7 +3,7 @@
* \file symbolic.h
* \brief Symbolic interface of mxnet.
* \author Min Lin, Bing Xu
*/
*/
#ifndef MXNET_SYMBOLIC_H_
#define MXNET_SYMBOLIC_H_

@@ -161,7 +161,7 @@ class StaticGraph {
* \param arg_grads used to store gradients to args; there can be multiple gradients if an argument is used by more than one operator
*/
void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
std::vector<std::vector<DataEntry> > *arg_grads);
std::vector<DataEntry> *arg_grads);

/*!
* \brief create a sum node that aggregates gradient together
@@ -254,6 +254,13 @@ class Symbol {
*/
Symbol operator () (const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name) const;
/*!
* \brief get the gradient graph of this symbol
* \param wrt names of the input arguments to take the gradient with respect to
* \return a new symbol containing the gradient graph
*/
Symbol Grad(const std::vector<std::string>& wrt) const;

/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param arg_shapes the shape of input arguments of the operator
3 changes: 3 additions & 0 deletions make/config.mk
@@ -14,6 +14,9 @@ export CC = gcc
export CXX = g++
export NVCC = nvcc

# whether to compile with debug
DEBUG = 0

# whether use CUDA during compile
USE_CUDA = 0

15 changes: 15 additions & 0 deletions python/mxnet/symbol.py
@@ -243,6 +243,21 @@ def bind(self, ctx, args, args_grad, reqs):
                                     ctypes.byref(handle)))
        return Executor(handle)

    def grad(self, wrt):
        """Get the autodiff (gradient graph) of the current symbol.

        Parameters
        ----------
        wrt : list of str
            Names of the arguments with respect to which the gradient is taken.
        """
        handle = SymbolHandle()
        c_wrt = c_array(ctypes.c_char_p, [c_str(key) for key in wrt])
        check_call(_LIB.MXSymbolGrad(self.handle,
                                     mx_uint(len(wrt)),
                                     c_wrt,
                                     ctypes.byref(handle)))
        return Symbol(handle)

def Variable(name):
"""Create a symbolic variable with specified name.
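As a usage illustration of the new method, here is a minimal sketch; only Variable and grad come from the code above, while the FullyConnected operator and its parameters are assumptions used purely to build a symbol with a named 'data' argument.

import mxnet as mx

data = mx.symbol.Variable('data')
# Hypothetical operator; any composed symbol with named arguments behaves the same way.
net = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
# Build the gradient graph of the network with respect to the 'data' argument.
grad_net = net.grad(wrt=['data'])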
14 changes: 13 additions & 1 deletion src/c_api.cc
@@ -484,6 +484,19 @@ int MXSymbolCompose(SymbolHandle sym,
API_END();
}

int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) {
  API_BEGIN();
  Symbol* s = static_cast<Symbol*>(sym);
  std::vector<std::string> wrts(num_wrt);
  for (mx_uint i = 0; i < num_wrt; ++i) {
    wrts[i] = wrt[i];
  }
  Symbol* ret = new Symbol;
  *ret = s->Grad(wrts);
  *out = ret;
  API_END();
}

int MXSymbolInferShape(SymbolHandle sym,
mx_uint num_args,
const char** keys,
@@ -596,4 +609,3 @@ int MXExecutorBind(SymbolHandle symbol_handle,
*out = Executor::Bind(*symb, ctx, in_args_vec, arg_grad_vec, grad_req_vec);
API_END();
}

26 changes: 9 additions & 17 deletions src/symbol/graph_executor.cc
@@ -2,7 +2,7 @@
* Copyright (c) 2015 by Contributors
* \file graph_executor.cc
* \brief Executor to execute the Graph.
*/
*/
#include <dmlc/logging.h>
#include <mxnet/symbolic.h>
#include <memory>
@@ -200,7 +200,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) {
}

void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) {
// initialize all internal daa structures
// initialize all internal data structures
symbol.ToStaticGraph(&graph_);
num_forward_nodes_ = graph_.nodes.size();
if (need_backward) {
@@ -252,21 +252,13 @@ void GraphExecutor::InitDataEntryInfo(const std::vector<NArray> &in_args,
if (grad_req_type[i] == kNullOp) continue;
CHECK_NE(grad_req_type[i], kWriteInplace)
<< "Gradient request can only be nullop, add, write";
std::vector<StaticGraph::DataEntry> &grad_source = arg_grads_[i];
CHECK_GE(grad_source.size(), 1);
// TODO(bing) add a aggregation node here
if (grad_source.size() > 1) {
CHECK_EQ(grad_req_type[i], kAddTo)
<< "The gradient contains multiple variables,";
}
for (StaticGraph::DataEntry e : grad_source) {
DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index];
info.type = kBindByExternal;
info.op_req = grad_req_type[i];
info.data = arg_grad_store[i];
++info.ref_count;
op_nodes_[e.source_id].activated = true;
}
StaticGraph::DataEntry &grad_source = arg_grads_[i];
DataEntryInfo &info = op_nodes_[grad_source.source_id].outputs[grad_source.index];
info.type = kBindByExternal;
info.op_req = grad_req_type[i];
info.data = arg_grad_store[i];
++info.ref_count;
op_nodes_[grad_source.source_id].activated = true;
}
// setup head gradient
for (uint32_t nid : head_grad_nodes_) {
2 changes: 1 addition & 1 deletion src/symbol/graph_executor.h
@@ -177,7 +177,7 @@ class GraphExecutor : public Executor {
// head gradient node in the graph, if there is backward pass
std::vector<uint32_t> head_grad_nodes_;
// argument node in the graph, if there is backward pass
std::vector<std::vector<StaticGraph::DataEntry> > arg_grads_;
std::vector<StaticGraph::DataEntry> arg_grads_;
// operational nodes
std::vector<OpNode> op_nodes_;
// head NArrays
14 changes: 12 additions & 2 deletions src/symbol/static_graph.cc
@@ -165,7 +165,7 @@ StaticGraph::Node StaticGraph::CreateSumNode(
}

void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
std::vector<std::vector<DataEntry> > *arg_grads) {
std::vector<DataEntry> *arg_grads) {
arg_grads->clear();
head_grad_nodes->clear();
// get topo order of nodes, before new nodes are added
@@ -254,7 +254,17 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
DataEntry odata(arg_nodes[i], 0);
auto it = grad_map.find(odata);
CHECK(it != grad_map.end()) << "bad graph";
arg_grads->at(i) = it->second;
if (it->second.size() == 1) {
arg_grads->at(i) = it->second[0];
} else {
std::ostringstream os_name;
Node agg_node = StaticGraph::CreateSumNode(it->second);
os_name << nodes[arg_nodes[i]].name << "_grad_agg";
agg_node.name = os_name.str();
uint32_t agg_node_id = static_cast<uint32_t>(nodes.size());
nodes.push_back(std::move(agg_node));
arg_grads->at(i) = DataEntry(agg_node_id, 0);
}
}
}
} // namespace mxnet
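The change above moves gradient aggregation into the graph itself: when an argument receives gradients from more than one consumer, MakeBackwardPass now inserts a sum node named '<argument>_grad_agg' instead of handing the executor a list of gradient entries and requiring a kAddTo request (the code removed from graph_executor.cc earlier in this diff). A minimal sketch of the case this covers; the ElementWiseSum operator name is an assumption, the point is only that 'x' is consumed twice.

import mxnet as mx

x = mx.symbol.Variable('x')
# 'x' feeds the operator twice, so two gradients flow back into it.
y = mx.symbol.ElementWiseSum(x, x, name='y')
# The gradient graph now exposes a single aggregated entry for 'x'
# (produced by a sum node named roughly 'x_grad_agg').
gx = y.grad(wrt=['x'])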