Conv2d Per EG Gradient Weights #2

Open · wants to merge 26 commits into base: myelin
4 changes: 4 additions & 0 deletions nnvm/include/nnvm/graph.h
@@ -28,6 +28,10 @@ class Graph {
public:
/*! \brief outputs of the computation graph. */
std::vector<NodeEntry> outputs;
/*! \brief extra outputs of the computation graph. */
std::vector<NodeEntry> extra_outputs;
/*! \brief number of visible outputs of the computation graph. */
uint32_t num_vis_outputs;
/*!
* \brief attributes of a graph
* Note that attribute is shared pointer and can be shared across graphs.
15 changes: 15 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
@@ -152,6 +152,21 @@ using FInplaceIdentity = std::function<std::vector<bool> (const NodeAttrs& attrs
using FIgnoreInputs = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;

/*!
* \brief Get the expanded outputs of the op node
* This function generates the subgraph that replaces the node
* \param nodeptr The current node
* \param inputs The current node's inputs
* \param input_shapes The current node's input shapes
* \return the node's expanded outputs
*
* \note Register under "FExpandCompute"
*/
using FExpandCompute = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr,
const std::vector<NodeEntry>& inputs,
const std::vector<TShape>& input_shapes)>;

/*!
* \brief Get the gradient node of the op node
* This function generates the backward graph of the node
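Note (not part of this diff): a minimal sketch of how an op could register the new FExpandCompute attribute so that the ExpandCompute pass replaces the node with a subgraph of existing ops. The op name my_square and the use of the MakeNode helper from src/compiler/pattern_util.h are assumptions for illustration only.

// Sketch: register FExpandCompute for a hypothetical "my_square" op,
// expanding square(x) into elemwise_mul(x, x).
NNVM_REGISTER_OP(my_square)
.set_num_inputs(1)
.set_attr<nnvm::FExpandCompute>(
  "FExpandCompute",
  [](const nnvm::NodePtr& n,
     const std::vector<nnvm::NodeEntry>& inputs,
     const std::vector<nnvm::TShape>& input_shapes) {
    // input_shapes[i] is the inferred shape of inputs[i]; it is unused here,
    // but it lets shape-dependent expansions (e.g. weight gradients) size nodes.
    nnvm::NodeEntry x = inputs[0];
    return std::vector<nnvm::NodeEntry>{
        nnvm::compiler::MakeNode("elemwise_mul", n->attrs.name + "_expand", {x, x})};
  });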
2 changes: 2 additions & 0 deletions nnvm/include/nnvm/symbolic.h
@@ -52,6 +52,8 @@ class NNVM_DLL Symbol {

/*! \brief output entries contained in the symbol */
std::vector<NodeEntry> outputs;
/*! \brief additional output entries contained in the symbol */
std::vector<NodeEntry> extra_outputs;

/*!
* \brief Copy the symbol.
34 changes: 34 additions & 0 deletions nnvm/include/nnvm/top/nn.h
@@ -43,13 +43,19 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
};

struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
bool training;
int axis;
double epsilon;
double momentum;
bool center;
bool scale;

DMLC_DECLARE_PARAMETER(BatchNormParam) {
DMLC_DECLARE_FIELD(training).set_default(false)
.describe("If True, output will be centered and the moving statistics "
"will be updated.");
DMLC_DECLARE_FIELD(momentum).set_default(0.1)
.describe("Dampening parameter for moving_(mean|var) when training.");
DMLC_DECLARE_FIELD(axis).set_default(1)
.describe("Specify which shape axis the channel is specified.");
DMLC_DECLARE_FIELD(epsilon).set_default(1e-5)
@@ -71,6 +77,34 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
static const constexpr int kMovingVariance = 4;
};

struct InstanceNormParam : public dmlc::Parameter<InstanceNormParam> {
int axis;
double epsilon;
double momentum;
bool center;
bool scale;

DMLC_DECLARE_PARAMETER(InstanceNormParam) {
DMLC_DECLARE_FIELD(axis).set_default(1)
.describe("Specify which shape axis the channel is specified.");
DMLC_DECLARE_FIELD(momentum).set_default(0.1)
.describe("Dampening parameter for moving_(mean|var) when training.");
DMLC_DECLARE_FIELD(epsilon).set_default(1e-5)
.describe("Small float added to variance to avoid dividing by zero.");
DMLC_DECLARE_FIELD(center).set_default(true)
.describe("If True, add offset of `beta` to normalized tensor."
"If False, `beta` is ignored.");
DMLC_DECLARE_FIELD(scale).set_default(true)
.describe("If True, multiply by `gamma`. If False, `gamma` is not used."
"When the next layer is piecewise linear (also e.g. `nn.relu`),"
"this can be disabled since the scaling"
"will be done by the next layer.");
}
// constants
static const constexpr int kData = 0;
static const constexpr int kGamma = 1;
static const constexpr int kBeta = 2;
};

// Shared by softmax and log_softmax
struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
7 changes: 7 additions & 0 deletions nnvm/python/nnvm/compiler/build_module.py
@@ -160,6 +160,11 @@ def optimize(graph, shape, dtype="float32", layout=None):
# pylint: disable=unused-argument
cfg = BuildConfig.current

if any(map(cfg.pass_enabled,
["AlterOpLayout", "SimplifyInference", "FoldScaleAxis"])):
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply(["InferShape", "ExpandCompute"])

if cfg.pass_enabled("AlterOpLayout"):
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
@@ -276,6 +281,8 @@ def build(graph, target=None, shape=None, dtype="float32",
init_var = {}
if _all_var_init:
init_var = initialize_variables(shape, dtype)
# Expand FComputes
graph = graph.apply("InferShape").apply("ExpandCompute")
# Apply optimization
with target:
graph = optimize(graph, shape, dtype, layout)
17 changes: 17 additions & 0 deletions nnvm/python/nnvm/compiler/graph_util.py
@@ -111,9 +111,26 @@ def get_gradient_graph(ys, xs, grad_ys=None):
ret : Graph
Generated gradient graph.
"""
if isinstance(xs, list):
xs = Group(xs)
if isinstance(ys, list):
ys = Group(ys)
g = create(ys)

# expand the graph, then replace each requested grad symbol with
# its counterpart in the expanded graph
g = g.apply('InferShape').apply('ExpandCompute')
def _get_outs(s):
return [s[i] for i in range(len(s.list_output_names()))]
g_name2sym = {s.attr('name'): s for s in _get_outs(g.symbol.get_internals())}
def _tx_symb(symb):
tx = [g_name2sym[s.attr('name')] for s in _get_outs(symb)]
if len(tx) > 1:
return Group(tx)
return tx[0]
ys = _tx_symb(ys)
xs = _tx_symb(xs)

g._set_symbol_list_attr('grad_ys', ys)
g._set_symbol_list_attr('grad_xs', xs)
ny = len(ys.list_output_names())
5 changes: 4 additions & 1 deletion nnvm/python/nnvm/compiler/optimizer.py
@@ -89,7 +89,10 @@ def minimize(self, obj, var=None):
g = self.rescale_grad * g
if self.clip_gradient is not None:
g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
updates.append(sym._assign(v, v - lr_t * (g + self.wd * v)))
step = g
if self.wd:
step += self.wd * v
updates.append(sym._assign(v, v - lr_t * step))
return sym.Group(updates)


7 changes: 7 additions & 0 deletions nnvm/python/nnvm/top/nn.py
@@ -157,8 +157,15 @@ def schedule_conv2d(attrs, outs, target):
@reg.register_alter_op_layout("conv2d")
def alter_conv2d_layout(attrs, inputs, tinfos):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos)
@reg.register_schedule("_conv2d_grad_weight")
def schedule_conv2d_grad_weight(attrs, outs, target):
"""Schedule definition of conv2d_grad_weight"""
assert attrs["layout"] == "NCHW" and attrs.get_int("groups") == 1
with tvm.target.create(target):
return topi.generic.schedule_conv2d_grad_weight_nchw(outs)

reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_pattern("_conv2d_grad_weight", OpPattern.OUT_ELEMWISE_FUSABLE)

# convolution NCHWc
@reg.register_compute("_contrib_conv2d_NCHWc")
17 changes: 16 additions & 1 deletion nnvm/src/c_api/c_api_graph.cc
@@ -16,7 +16,22 @@ using namespace nnvm;
int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) {
Graph* g = new Graph();
API_BEGIN();
g->outputs = static_cast<Symbol*>(symbol)->outputs;
Symbol* s = static_cast<Symbol*>(symbol); // may actually be a Graph
// visible outputs first, then extra outputs; num_vis_outputs marks the split
for (const auto& out : s->outputs) {
g->outputs.push_back(out);
}
for (const auto& out : s->extra_outputs) {
g->outputs.push_back(out);
}
g->num_vis_outputs = s->outputs.size();
*graph = g;
API_END_HANDLE_ERROR(delete g);
}
72 changes: 72 additions & 0 deletions nnvm/src/compiler/expand_fcompute.cc
@@ -0,0 +1,72 @@
/*!
* Copyright (c) 2018 by Contributors
* \file expand_fcompute.cc
* \author Nick Hynes
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./graph_transform.h"
#include "./pattern_util.h"

namespace nnvm {
namespace compiler {

Graph ExpandCompute(nnvm::Graph src) {
const IndexedGraph& idx = src.indexed_graph();
std::map<std::string, TShape> name2shape;
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
for (uint32_t i = 0, j = 0; i < idx.num_nodes(); ++i) {
const Node* node = idx[i].source;
uint32_t num_outputs = node->num_outputs();
std::string name = node->attrs.name;
// CHECK(name2shape.count(name) == 0 || name2shape[name] == shape_vec[j])
// << "Reassigning shape of " << name << ". prev: "
// << shape_vec[j] << ", new: " << name2shape[name];
name2shape[name] = shape_vec[j];
j += num_outputs;
}
bool needs_expand = false;
auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
static auto& fcompute = Op::GetAttr<FExpandCompute>("FExpandCompute");
if (!fcompute.count(n->op())) return false;
std::vector<TShape> input_shapes;
for (const NodeEntry& inp : n->inputs) {
CHECK_GT(name2shape.count(inp.node->attrs.name), 0)
<< "Input " << inp.node->attrs.name << " as input to "
<< n->attrs.name << " does not exist.";
input_shapes.push_back(name2shape[inp.node->attrs.name]);
}
std::vector<NodeEntry> exp = fcompute[n->op()](n, n->inputs, input_shapes);
needs_expand = true;
*ret = exp;
return true;
};

// preserve input shapes
Graph egraph = GraphTransform(src, transform);
const IndexedGraph& eidx = egraph.indexed_graph();
ShapeVector ishapes;
for (const auto& nid : eidx.input_nodes()) {
std::string name = eidx[nid].source->attrs.name;
if (name2shape.count(name)) {
ishapes.push_back(name2shape[name]);
} else {
ishapes.emplace_back();
}
}
egraph.attrs["shape_inputs"] = std::make_shared<any>(std::move(ishapes));

if (needs_expand)
return ApplyPasses(egraph, {"InferShape", "ExpandCompute"});
return egraph;
}

NNVM_REGISTER_PASS(ExpandCompute)
.set_body(ExpandCompute);

} // namespace compiler
} // namespace nnvm
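For context, a minimal usage sketch of driving the new pass from C++, mirroring graph.apply(["InferShape", "ExpandCompute"]) on the Python side. Assumptions: the symbol is already composed, input_shapes holds one TShape per graph input in input order, and the function name ExpandGraph is hypothetical.

#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/symbolic.h>

// Sketch: attach input shapes, then run InferShape followed by ExpandCompute.
nnvm::Graph ExpandGraph(const nnvm::Symbol& sym, nnvm::ShapeVector input_shapes) {
  nnvm::Graph g;
  // visible outputs first, then extra outputs, as in NNGraphCreate
  g.outputs = sym.outputs;
  g.outputs.insert(g.outputs.end(), sym.extra_outputs.begin(), sym.extra_outputs.end());
  g.num_vis_outputs = static_cast<uint32_t>(sym.outputs.size());
  // InferShape reads per-input shapes from the "shape_inputs" graph attribute.
  g.attrs["shape_inputs"] =
      std::make_shared<dmlc::any>(std::move(input_shapes));
  return nnvm::ApplyPasses(std::move(g), {"InferShape", "ExpandCompute"});
}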
1 change: 1 addition & 0 deletions nnvm/src/compiler/graph_transform.h
@@ -115,6 +115,7 @@ Graph GraphTransform(Graph graph, FTransform ftransform) {
ret.outputs.push_back(graph.outputs[i]);
}
}
ret.num_vis_outputs = graph.num_vis_outputs;
return ret;
}

2 changes: 2 additions & 0 deletions nnvm/src/compiler/simplify_inference.cc
@@ -82,6 +82,8 @@ Graph SimplifyInference(nnvm::Graph src) {
static const Op* bn_op = Op::Get("batch_norm");
static const Op* dropout_op = Op::Get("dropout");
if (n->op() == bn_op) {
const auto& param = nnvm::get<top::BatchNormParam>(n->attrs.parsed);
if (param.training) return false;
*ret = BatchNormToInferUnpack(
n->attrs,
n->inputs[0],
24 changes: 20 additions & 4 deletions nnvm/src/core/symbolic.cc
@@ -292,12 +292,16 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
for (const auto& extra_out : args[i]->extra_outputs)
extra_outputs.push_back(extra_out);
}
for (const auto& kv : kwargs) {
if (garg_names.empty()
|| std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
for (const auto& extra_out : kv.second->extra_outputs)
extra_outputs.push_back(extra_out);
}
// assign new name
if (!name.empty()) outputs[0].node->attrs.name = name;
@@ -598,12 +602,17 @@ Symbol Symbol::CreateFunctor(const Op* op,
}

uint32_t nout = n->num_outputs();
uint32_t nvis = nout;
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
nvis = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
for (uint32_t i = 0; i < nvis; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
}
s.extra_outputs = std::vector<NodeEntry>();
for (uint32_t i = nvis; i < nout; ++i) {
s.extra_outputs.emplace_back(NodeEntry{n, i, 0});
}
return s;
}

@@ -614,19 +623,26 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) {
n->attrs = attrs;

uint32_t nout = n->num_outputs();
uint32_t nvis = nout;
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
nvis = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
for (uint32_t i = 0; i < nvis; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
}
s.extra_outputs = std::vector<NodeEntry>();
for (uint32_t i = nvis; i < nout; ++i) {
s.extra_outputs.emplace_back(NodeEntry{n, i, 0});
}
return s;
}

Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol ret;
for (const auto &s : symbols) {
ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end());
ret.extra_outputs.insert(ret.extra_outputs.end(),
s.extra_outputs.begin(), s.extra_outputs.end());
}
return ret;
}
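For reference, a hedged sketch (not in this diff) of the producer side of this mechanism: an op registering FNumVisibleOutputs, the hook CreateFunctor consults above when splitting entries between outputs and extra_outputs. The op name is hypothetical.

// Sketch: an op with two outputs where only the first is user-visible.
// CreateFunctor places outputs [nvis, nout) into extra_outputs.
NNVM_REGISTER_OP(my_conv2d_with_grad_weight)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::FNumVisibleOutputs>(
  "FNumVisibleOutputs",
  [](const nnvm::NodeAttrs& attrs) -> uint32_t {
    return 1;  // output 0 is visible; output 1 (e.g. a weight gradient) becomes extra
  });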