conv is able to work
antinucleon committed Aug 25, 2015
1 parent 4bd535f commit 1045c18
Showing 6 changed files with 174 additions and 29 deletions.
6 changes: 4 additions & 2 deletions Makefile
@@ -64,14 +64,14 @@ endif
 #BIN = test/test_threaded_engine test/api_registry_test
 OBJ = narray_function_cpu.o
 # add threaded engine after it is done
-OBJCXX11 = engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o
+OBJCXX11 = flatten_cpu.o engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o
 CUOBJ =
 SLIB = lib/libmxnet.so
 ALIB = lib/libmxnet.a
 LIB_DEP = $(DMLC_CORE)/libdmlc.a

 ifeq ($(USE_CUDA), 1)
-CUOBJ += narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
+CUOBJ += flatten_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
 endif

 .PHONY: clean all test lint doc
@@ -103,6 +103,8 @@ softmax_cpu.o: src/operator/softmax.cc
 softmax_gpu.o: src/operator/softmax.cu
 convolution_cpu.o: src/operator/convolution.cc
 convolution_gpu.o: src/operator/convolution.cu
+flatten_cpu.o: src/operator/flatten.cc
+flatten_gpu.o: src/operator/flatten.cu

 lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
 lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
49 changes: 22 additions & 27 deletions python/test_mnist.py
@@ -3,28 +3,14 @@
 import numpy as np
 import os, cPickle, gzip

-def Softmax(x):
-    batch, nhidden = x.shape
-    maxes = np.max(x, axis=1)
-    x -= maxes.reshape(batch, 1)
-    x = np.exp(x)
-    norm = np.sum(x, axis=1)
-    prob = x / norm.reshape((batch, 1))
-    return prob
-
 def CalAcc(out, label):
     pred = np.argmax(out, axis=1)
     return np.sum(pred == label) * 1.0 / out.shape[0]

-def SetGradient(out_grad, label):
-    assert(out_grad.shape[0] == label.shape[0])
-    for i in xrange(label.shape[0]):
-        k = label[i]
-        out_grad[i][k] -= 1.0
-
 # load data
 class MNISTIter(object):
-    def __init__(self, which_set, batch_size=100):
+    def __init__(self, which_set, batch_size=100, flatten=True):
         if not os.path.exists('mnist.pkl.gz'):
             os.system("wget http://deeplearning.net/data/mnist/mnist.pkl.gz")
         f = gzip.open('mnist.pkl.gz', 'rb')
@@ -39,6 +25,7 @@ def __init__(self, which_set, batch_size=100):
         else:
             self.data = test_set[0]
             self.label = np.asarray(test_set[1])
+        self.flatten = flatten
         self.batch_size = batch_size
         self.nbatch = self.data.shape[0] / batch_size
         assert(self.data.shape[0] % batch_size == 0) # I am lazy
@@ -57,25 +44,34 @@ def Get(self):
             raise Exception("Iterator is at end")
         start = self.now_idx * self.batch_size
         end = (self.now_idx + 1) * self.batch_size
-        return (self.data[start:end, :], self.label[start:end])
+        if self.flatten:
+            return (self.data[start:end, :], self.label[start:end])
+        else:
+            return (self.data[start:end, :].reshape(self.batch_size, 1, 28, 28),
+                    self.label[start:end])



 # symbol net
 batch_size = 100
 data = mx.symbol.Variable('data')
-fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=160)
+fc1 = mx.symbol.Convolution(data = data, name='conv1', nb_filter=32, kernel=(7,7), stride=(2,2), nstep=10, no_bias=1)
 act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-fc2 = mx.symbol.FullyConnected(data = act1, name='fc2', num_hidden=10)
-args_list = fc2.list_arguments()
+mp = mx.symbol.Pooling(data = act1, name = 'mp', kernel=(2,2), stride=(2,2), pool_type='avg')
+fl = mx.symbol.Flatten(data = mp, name="flatten")
+fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10)
+softmax = mx.symbol.Softmax(data = fc2, name = 'sm')
+args_list = softmax.list_arguments()
 # infer shape
-data_shape = (batch_size, 784)
-arg_shapes, out_shapes = fc2.infer_shape(data=data_shape)
+#data_shape = (batch_size, 784)
+
+data_shape = (batch_size, 1, 28, 28)
+arg_shapes, out_shapes = softmax.infer_shape(data=data_shape)
 arg_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 grad_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 mom_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 inputs = dict(zip(args_list, arg_narrays))

+print zip(args_list, arg_shapes)
 np.random.seed(0)
 # set random weight
 for name, narray in inputs.items():
@@ -87,7 +83,7 @@ def Get(self):
 req = ['write_to' for i in range(len(arg_narrays))]
 # bind executor
 # TODO(bing): think of a better bind interface
-executor = fc2.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req)
+executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req)
 # update

 out_narray = executor.heads()[0]
@@ -104,8 +100,8 @@ def Update(mom, grad, weight):
 block = zip(mom_narrays, grad_narrays, arg_narrays)


-train = MNISTIter("train", batch_size)
-valid = MNISTIter("valid", batch_size)
+train = MNISTIter("train", batch_size, False)
+valid = MNISTIter("valid", batch_size, False)

 for i in xrange(epoch):
     # train
@@ -115,11 +111,10 @@ def Update(mom, grad, weight):
     while train.Next():
         data, label = train.Get()
         inputs["data"].numpy[:] = data
+        inputs["sm_label"].numpy[:] = label
         executor.forward()
-        out_narray.numpy[:] = Softmax(out_narray.numpy)
         train_acc += CalAcc(out_narray.numpy, label)
         grad_narray.numpy[:] = out_narray.numpy
-        SetGradient(grad_narray.numpy, label)
         executor.backward([grad_narray])

         for mom, grad, weight in block:
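As a sanity check on the new network, the output shapes can be worked out by hand. A minimal sketch follows; the conv_out helper is hypothetical (not an mxnet API) and assumes the no-padding arithmetic implied by the kernel and stride arguments above:

    # Output size of a no-padding conv/pool layer: floor((in - kernel) / stride) + 1
    def conv_out(size, kernel, stride):
        return (size - kernel) // stride + 1

    c = conv_out(28, 7, 2)   # conv1: 7x7 kernel, stride 2 -> 11
    p = conv_out(c, 2, 2)    # mp: 2x2 avg pool, stride 2  -> 5
    print("conv1:   (100, 32, %d, %d)" % (c, c))      # (100, 32, 11, 11)
    print("mp:      (100, 32, %d, %d)" % (p, p))      # (100, 32, 5, 5)
    print("flatten: (100, 1, 1, %d)" % (32 * p * p))  # (100, 1, 1, 800)
    print("fc2/sm:  (100, 10)")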
101 changes: 101 additions & 0 deletions src/operator/flatten-inl.h
@@ -0,0 +1,101 @@
/*!
 * Copyright (c) 2015 by Contributors
 * \file flatten-inl.h
 * \brief Flatten operator implementation.
 * \author Bing Xu
 */
#ifndef MXNET_OPERATOR_FLATTEN_INL_H_
#define MXNET_OPERATOR_FLATTEN_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <algorithm>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./operator_common.h"

namespace mxnet {
namespace op {

enum FlattenOpInputs {kData};
enum FlattenOpOutputs {kOut};

template<typename xpu>
class FlattenOp : public Operator {
 public:
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 1);
    CHECK_EQ(req.size(), 1);
    CHECK_EQ(out_data.size(), 1);
    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 4> data = in_data[kData].get<xpu, 4, real_t>(s);
    Tensor<xpu, 4> out = out_data[kOut].get<xpu, 4, real_t>(s);
    Assign(out, req[kOut], reshape(data, out.shape_));
  }

  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad) {
    using namespace mshadow;
    using namespace mshadow::expr;
    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 4> grad_out = out_grad[kOut].get<xpu, 4, real_t>(s);
    Tensor<xpu, 4> grad_in = in_grad[kData].get<xpu, 4, real_t>(s);
    Assign(grad_in, req[kData], reshape(grad_out, grad_in.shape_));
  }
};  // class FlattenOp

template<typename xpu>
Operator* CreateOp();

#if DMLC_USE_CXX11
class FlattenProp : public OperatorProperty {
 public:
  FlattenProp() {}

  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {}

  virtual std::string TypeString() const {
    return "Flatten";
  }

  virtual bool InferShape(std::vector<TShape> *in_shape,
                          std::vector<TShape> *out_shape) const {
    CHECK_EQ(in_shape->size(), 1) << "Input: [data]";
    const TShape &dshape = in_shape->at(kData);
    if (dshape.ndim() == 0) return false;
    out_shape->clear();
    out_shape->push_back(mshadow::Shape4(dshape[0], 1, 1, dshape[1] * dshape[2] * dshape[3]));
    return true;
  }

  virtual OperatorProperty* Copy() const {
    auto ptr = new FlattenProp();
    return ptr;
  }

  virtual std::vector<int> DeclareBackwardDependency(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const {
    return {out_grad[kOut]};
  }

  Operator* CreateOperator(Context ctx) const;
};  // class FlattenProp
#endif // DMLC_USE_CXX11

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_FLATTEN_INL_H_
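The semantics of FlattenOp are easiest to see in a plain numpy sketch (illustrative only, not the mxnet implementation): forward reshapes (N, C, H, W) to the (N, 1, 1, C*H*W) layout that InferShape declares, and backward is just the inverse reshape of the output gradient.

    import numpy as np

    def flatten_forward(data):
        # (N, C, H, W) -> (N, 1, 1, C*H*W), mirroring FlattenProp::InferShape
        return data.reshape(data.shape[0], 1, 1, -1)

    def flatten_backward(grad_out, in_shape):
        # The gradient of a reshape is a reshape back to the input shape
        return grad_out.reshape(in_shape)

    x = np.random.randn(100, 32, 5, 5)
    y = flatten_forward(x)
    assert y.shape == (100, 1, 1, 800)
    assert flatten_backward(np.ones_like(y), x.shape).shape == x.shape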
27 changes: 27 additions & 0 deletions src/operator/flatten.cc
@@ -0,0 +1,27 @@
/*!
 * Copyright (c) 2015 by Contributors
 * \file flatten.cc
 * \brief Flatten operator (CPU registration).
 * \author Bing Xu
 */

#include "./flatten-inl.h"


namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>() {
  return new FlattenOp<cpu>();
}

Operator* FlattenProp::CreateOperator(Context ctx) const {
  DO_BIND_DISPATCH(CreateOp);
}

MXNET_REGISTER_OP_PROPERTY(Flatten, FlattenProp)
.add_argument("data", "Symbol", "Input data to flatten.")
.describe("Flatten 4D input to form batch-1-1-feature format");

} // namespace op
} // namespace mxnet
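With this registration in place, the operator is exposed through the Python symbol API; test_mnist.py above uses it exactly this way (a minimal usage sketch against this commit's API):

    import mxnet as mx

    # Flatten pooled conv features before the fully connected layer,
    # as in python/test_mnist.py; 'mp' stands in for the pooling output.
    mp = mx.symbol.Variable('mp')
    fl = mx.symbol.Flatten(data=mp, name="flatten")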
19 changes: 19 additions & 0 deletions src/operator/flatten.cu
@@ -0,0 +1,19 @@
/*!
 * Copyright (c) 2015 by Contributors
 * \file flatten.cu
 * \brief Flatten operator (GPU registration).
 * \author Bing Xu
 */

#include "./flatten-inl.h"


namespace mxnet {
namespace op {
template<>
Operator *CreateOp<gpu>() {
  return new FlattenOp<gpu>();
}

} // namespace op
} // namespace mxnet
1 change: 1 addition & 0 deletions src/operator/fully_connected-inl.h
@@ -60,6 +60,7 @@ class FullyConnectedOp : public Operator {
     CHECK_EQ(out_data.size(), 1);
     // TODO(bing): check the BLAS Handle, be careful
     // maybe need blas handle from context
+    // TODO(bing): judge shape to remove flatten op
     Stream<xpu> *s = ctx.get_stream<xpu>();
     Tensor<xpu, 2> data = in_data[kData].FlatTo2D<xpu, real_t>(s);
     Tensor<xpu, 2> wmat = in_data[kWeight].get<xpu, 2, real_t>(s);
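The new TODO points at why Flatten exists at all here: FlatTo2D already collapses any input to (batch, features), so once FullyConnected's shape inference accepts non-2D inputs the explicit Flatten symbol becomes redundant. A numpy sketch of the two layouts (illustrative only):

    import numpy as np

    x = np.zeros((100, 32, 5, 5))
    flat_to_2d = x.reshape(x.shape[0], -1)        # FlatTo2D inside FullyConnected: (100, 800)
    flatten_op = x.reshape(x.shape[0], 1, 1, -1)  # standalone Flatten op: (100, 1, 1, 800)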
