diff --git a/Makefile b/Makefile
index 6a01e38d4d06..2f2b14bee0a7 100644
--- a/Makefile
+++ b/Makefile
@@ -64,14 +64,14 @@ endif
 #BIN = test/test_threaded_engine test/api_registry_test
 OBJ = narray_function_cpu.o
 # add threaded engine after it is done
-OBJCXX11 = engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o
+OBJCXX11 = flatten_cpu.o engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o
 CUOBJ =
 SLIB = lib/libmxnet.so
 ALIB = lib/libmxnet.a
 LIB_DEP = $(DMLC_CORE)/libdmlc.a
 
 ifeq ($(USE_CUDA), 1)
-	CUOBJ += narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
+	CUOBJ += flatten_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
 endif
 
 .PHONY: clean all test lint doc
@@ -103,6 +103,8 @@ softmax_cpu.o: src/operator/softmax.cc
 softmax_gpu.o: src/operator/softmax.cu
 convolution_cpu.o: src/operator/convolution.cc
 convolution_gpu.o: src/operator/convolution.cu
+flatten_cpu.o: src/operator/flatten.cc
+flatten_gpu.o: src/operator/flatten.cu
 
 lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
 lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
diff --git a/python/test_mnist.py b/python/test_mnist.py
index 3a3ee85a8d3f..63153cbe7f19 100644
--- a/python/test_mnist.py
+++ b/python/test_mnist.py
@@ -3,28 +3,14 @@
 import numpy as np
 import os, cPickle, gzip
 
-def Softmax(x):
-    batch, nidden = x.shape
-    maxes = np.max(x, axis=1)
-    x -= maxes.reshape(batch, 1)
-    x = np.exp(x)
-    norm = np.sum(x, axis=1)
-    prob = x / norm.reshape((batch, 1))
-    return prob
-
 def CalAcc(out, label):
     pred = np.argmax(out, axis=1)
     return np.sum(pred == label) * 1.0 / out.shape[0]
 
-def SetGradient(out_grad, label):
-    assert(out_grad.shape[0] == label.shape[0])
-    for i in xrange(label.shape[0]):
-        k = label[i]
-        out_grad[i][k] -= 1.0
 
 # load data
 class MNISTIter(object):
-    def __init__(self, which_set, batch_size=100):
+    def __init__(self, which_set, batch_size=100, flatten=True):
         if not os.path.exists('mnist.pkl.gz'):
             os.system("wget http://deeplearning.net/data/mnist/mnist.pkl.gz")
         f = gzip.open('mnist.pkl.gz', 'rb')
@@ -39,6 +25,7 @@ def __init__(self, which_set, batch_size=100):
         else:
             self.data = test_set[0]
             self.data = np.asarray(test_set[1])
+        self.flatten = flatten
         self.batch_size = batch_size
         self.nbatch = self.data.shape[0] / batch_size
         assert(self.data.shape[0] % batch_size == 0) # I am lazy
@@ -57,25 +44,34 @@ def Get(self):
             raise Exception("Iterator is at end")
         start = self.now_idx * self.batch_size
         end = (self.now_idx + 1) * self.batch_size
-        return (self.data[start:end, :], self.label[start:end])
+        if self.flatten:
+            return (self.data[start:end, :], self.label[start:end])
+        else:
+            return (self.data[start:end, :].reshape(batch_size, 1, 28, 28),
+                    self.label[start:end])
 
 # symbol net
 batch_size = 100
 data = mx.symbol.Variable('data')
-fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=160)
+fc1 = mx.symbol.Convolution(data = data, name='conv1', nb_filter=32, kernel=(7,7), stride=(2,2), nstep=10, no_bias=1)
 act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-fc2 = mx.symbol.FullyConnected(data = act1, name='fc2', num_hidden=10)
-args_list = fc2.list_arguments()
+mp = mx.symbol.Pooling(data = act1, name = 'mp', kernel=(2,2), stride=(2,2), pool_type='avg')
+fl = mx.symbol.Flatten(data = mp, name="flatten")
+fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10)
+softmax = mx.symbol.Softmax(data = fc2, name = 'sm')
+args_list = softmax.list_arguments()
 # infer shape
-data_shape = (batch_size, 784)
-arg_shapes, out_shapes = fc2.infer_shape(data=data_shape)
+#data_shape = (batch_size, 784)
+
+data_shape = (batch_size, 1, 28, 28)
+arg_shapes, out_shapes = softmax.infer_shape(data=data_shape)
 arg_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 grad_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 mom_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 inputs = dict(zip(args_list, arg_narrays))
-
+print zip(args_list, arg_shapes)
 np.random.seed(0)
 # set random weight
 for name, narray in inputs.items():
@@ -87,7 +83,7 @@ def Get(self):
 req = ['write_to' for i in range(len(arg_narrays))]
 # bind executer
 # TODO(bing): think of a better bind interface
-executor = fc2.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req)
+executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req)
 # update
 out_narray = executor.heads()[0]
 
@@ -104,8 +100,8 @@ def Update(mom, grad, weight):
 
 block = zip(mom_narrays, grad_narrays, arg_narrays)
 
-train = MNISTIter("train", batch_size)
-valid = MNISTIter("valid", batch_size)
+train = MNISTIter("train", batch_size, False)
+valid = MNISTIter("valid", batch_size, False)
 
 for i in xrange(epoch):
     # train
@@ -115,11 +111,10 @@ def Update(mom, grad, weight):
     while train.Next():
         data, label = train.Get()
         inputs["data"].numpy[:] = data
+        inputs["sm_label"].numpy[:] = label
         executor.forward()
-        out_narray.numpy[:] = Softmax(out_narray.numpy)
         train_acc += CalAcc(out_narray.numpy, label)
         grad_narray.numpy[:] = out_narray.numpy
-        SetGradient(grad_narray.numpy, label)
         executor.backward([grad_narray])
 
         for mom, grad, weight in block:
diff --git a/src/operator/flatten-inl.h b/src/operator/flatten-inl.h
new file mode 100644
index 000000000000..da4110296909
--- /dev/null
+++ b/src/operator/flatten-inl.h
@@ -0,0 +1,101 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file flatten-inl.h
+ * \brief
+ * \author Bing Xu
+*/
+#ifndef MXNET_OPERATOR_FLATTEN_INL_H_
+#define MXNET_OPERATOR_FLATTEN_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <cstring>
+#include <map>
+#include <string>
+#include <vector>
+#include <utility>
+#include "./operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+enum FlattenOpInputs {kData};
+enum FlattenOpOutputs {kOut};
+
+template<typename xpu>
+class FlattenOp : public Operator {
+ public:
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_data.size(), 1);
+    CHECK_EQ(req.size(), 1);
+    CHECK_EQ(out_data.size(), 1);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> data = in_data[kData].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> out = out_data[kOut].get<xpu, 4, real_t>(s);
+    Assign(out, req[kOut], reshape(data, out.shape_));
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> grad_out = out_grad[kData].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> grad_in = in_grad[kOut].get<xpu, 4, real_t>(s);
+    Assign(grad_in, req[kData], reshape(grad_out, grad_in.shape_));
+  }
+};  // class FlattenOp
+
+template<typename xpu>
+Operator* CreateOp();
+
+#if DMLC_USE_CXX11
+class FlattenProp : public OperatorProperty {
+ public:
+  FlattenProp() {}
+
+  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {}
+
+  virtual std::string TypeString() const {
+    return "Flatten";
+  }
+
+  virtual bool InferShape(std::vector<TShape> *in_shape,
+                          std::vector<TShape> *out_shape) const {
+    CHECK_EQ(in_shape->size(), 1) << "Input: [data]";
+    const TShape &dshape = in_shape->at(kData);
+    if (dshape.ndim() == 0) return false;
+    out_shape->clear();
+    out_shape->push_back(mshadow::Shape4(dshape[0], 1, 1, dshape[1] * dshape[2] * dshape[3]));
+    return true;
+  }
+
+  virtual OperatorProperty* Copy() const {
+    auto ptr = new FlattenProp();
+    return ptr;
+  }
+
+  virtual std::vector<int> DeclareBackwardDependency(
+    const std::vector<int> &out_grad,
+    const std::vector<int> &in_data,
+    const std::vector<int> &out_data) const {
+    return {out_grad[kOut]};
+  }
+
+  Operator* CreateOperator(Context ctx) const;
+};  // class FlattenProp
+#endif  // DMLC_USE_CXX11
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_FLATTEN_INL_H_
diff --git a/src/operator/flatten.cc b/src/operator/flatten.cc
new file mode 100644
index 000000000000..db156def8ca2
--- /dev/null
+++ b/src/operator/flatten.cc
@@ -0,0 +1,27 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file flatten.cc
+ * \brief
+ * \author Bing Xu
+*/
+
+#include "./flatten-inl.h"
+
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<cpu>() {
+  return new FlattenOp<cpu>();
+}
+
+Operator* FlattenProp::CreateOperator(Context ctx) const {
+  DO_BIND_DISPATCH(CreateOp);
+}
+
+MXNET_REGISTER_OP_PROPERTY(Flatten, FlattenProp)
+.add_argument("data", "Symbol", "Input data to flatten.")
+.describe("Flatten 4D input to form batch-1-1-feature format");
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/flatten.cu b/src/operator/flatten.cu
new file mode 100644
index 000000000000..5bf9d47c5691
--- /dev/null
+++ b/src/operator/flatten.cu
@@ -0,0 +1,19 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file flatten.cu
+ * \brief
+ * \author Bing Xu
+*/
+
+#include "./flatten-inl.h"
+
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<gpu>() {
+  return new FlattenOp<gpu>();
+}
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h
index 532113e56c6d..ac5fd992cd82 100644
--- a/src/operator/fully_connected-inl.h
+++ b/src/operator/fully_connected-inl.h
@@ -60,6 +60,7 @@ class FullyConnectedOp : public Operator {
     CHECK_EQ(out_data.size(), 1);
     // TODO(bing): check the BLAS Handle, be careful
     // maybe need blas handle from context
+    // TODO(bing): judge shape to remove flatten op
     Stream<xpu> *s = ctx.get_stream<xpu>();
     Tensor<xpu, 2> data = in_data[kData].FlatTo2D<xpu, real_t>(s);
     Tensor<xpu, 2> wmat = in_data[kWeight].get<xpu, 2, real_t>(s);
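Usage sketch: a minimal symbol graph showing where the new Flatten operator sits, assembled only from calls that already appear in the updated python/test_mnist.py (Convolution, Activation, Pooling, Flatten, FullyConnected, infer_shape); the (batch, 1, 1, channel*height*width) output shape follows from FlattenProp::InferShape in flatten-inl.h, and the MNIST-sized shapes are the ones used in that script.

import mxnet as mx

batch_size = 100
data = mx.symbol.Variable('data')
conv = mx.symbol.Convolution(data=data, name='conv1', nb_filter=32,
                             kernel=(7, 7), stride=(2, 2), nstep=10, no_bias=1)
act = mx.symbol.Activation(data=conv, name='relu1', act_type="relu")
mp = mx.symbol.Pooling(data=act, name='mp', kernel=(2, 2), stride=(2, 2), pool_type='avg')
# Flatten reshapes the 4D pooling output (batch, channel, height, width)
# into (batch, 1, 1, channel*height*width), per FlattenProp::InferShape,
# so that FullyConnected can consume it.
fl = mx.symbol.Flatten(data=mp, name='flatten')
fc = mx.symbol.FullyConnected(data=fl, name='fc2', num_hidden=10)

# Shape inference on a MNIST-sized batch; prints each argument with its inferred shape.
arg_shapes, out_shapes = fc.infer_shape(data=(batch_size, 1, 28, 28))
print zip(fc.list_arguments(), arg_shapes)

As the added TODO in fully_connected-inl.h notes, FullyConnected may later handle the reshape itself, which would make the explicit Flatten step unnecessary.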