Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/py executor test #4922

Merged
merged 77 commits into from
Oct 19, 2017
Merged
Show file tree
Hide file tree
Changes from 73 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
f5d9005
Implement FC layer with helper
reyoung Oct 11, 2017
5488ec9
Merge branch 'develop' of github.com:baidu/Paddle into feature/fc_lay…
reyoung Oct 12, 2017
153d9a8
Update LayerHelper
reyoung Oct 12, 2017
e016726
Merge branch 'develop' of github.com:baidu/Paddle into feature/fc_lay…
reyoung Oct 13, 2017
f6570b5
Add debug string for Python ProtoBuf
JiayiFeng Oct 13, 2017
f7cffb7
Add check of ProtoBuf initialization
JiayiFeng Oct 13, 2017
e017ba2
Layer wrapper for FC
reyoung Oct 13, 2017
1cf33cb
Merge remote-tracking branch 'pr/4800' into feature/fc_layer_with_helper
reyoung Oct 13, 2017
cd93f12
Fix unittest
reyoung Oct 13, 2017
3e613de
Merge branch 'develop' of github.com:baidu/Paddle into feature/fc_lay…
reyoung Oct 13, 2017
3ab53e4
Merge branch 'develop' of github.com:baidu/Paddle into feature/fc_lay…
reyoung Oct 15, 2017
a281c39
Fix CI
reyoung Oct 15, 2017
03fc36c
Add code generator
reyoung Oct 15, 2017
32cdc7b
AttributeChecker Better error log and speicalize bool
reyoung Oct 16, 2017
d28c2c7
Complete mlp, fit_a_line
reyoung Oct 16, 2017
647e1eb
Merge branch 'develop' of github.com:baidu/Paddle into feature/fc_lay…
reyoung Oct 16, 2017
216979d
Merge branch 'feature/fc_layer_with_helper' of https://github.com/rey…
JiayiFeng Oct 16, 2017
50bd700
Expose get global scope
reyoung Oct 16, 2017
0e5ba8c
Make global scope not thread-safe
reyoung Oct 16, 2017
6723273
Merge branch 'feature/fix_global_scope' into feature/py_executor_test
reyoung Oct 16, 2017
122bd2a
Merge branch 'develop' of github.com:baidu/Paddle into feature/py_exe…
reyoung Oct 16, 2017
686ac1e
Fix
reyoung Oct 17, 2017
e18e79d
Implementation of simple conv_2d layer
JiayiFeng Oct 17, 2017
d87e137
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiayiFeng Oct 17, 2017
5c778e7
Merge branch 'develop' of github.com:baidu/Paddle into feature/py_exe…
reyoung Oct 17, 2017
93643c3
Merge branch 'develop' of github.com:baidu/Paddle into feature/py_exe…
reyoung Oct 17, 2017
9310a3b
Stash
reyoung Oct 17, 2017
d0d1172
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiayiFeng Oct 17, 2017
8c05974
Remove private data members in OpRegister
reyoung Oct 17, 2017
dad5769
Merge branch 'feature/fix_segsig' into feature/py_executor_test
reyoung Oct 17, 2017
aa00ab8
Fix bugs
JiayiFeng Oct 17, 2017
a307501
Stash
reyoung Oct 17, 2017
df9d100
Expose FeedFetchList as VarType
reyoung Oct 17, 2017
0f5731a
Merge branch 'feature/expose_feed_fetch_var_list' into feature/py_exe…
reyoung Oct 17, 2017
ac6b11b
Change ProgramDesc not a global variable
reyoung Oct 18, 2017
072d6d0
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
f92dc30
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
71ec313
Polish code style
reyoung Oct 18, 2017
e6f0924
Stash
reyoung Oct 18, 2017
87938e4
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
3e1ecf7
Correct implement BlockDesc destructor
reyoung Oct 18, 2017
9f99377
Correct implement BlockDesc destructor
reyoung Oct 18, 2017
0af66c5
Merge branch 'feature/correct_block_desc_dtor' of https://github.com/…
JiayiFeng Oct 18, 2017
b337e18
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
175bfd5
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
3d684bc
Unify program as parameter name
reyoung Oct 18, 2017
f73b9f2
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
291e7d2
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
6a4f6d9
Merge branch 'feature/remove_global_instance_of_program' into feature…
reyoung Oct 18, 2017
4f6d3c6
Merge branch 'develop' of github.com:baidu/Paddle into feature/remove…
reyoung Oct 18, 2017
60f96d1
Merge branch 'feature/remove_global_instance_of_program' into feature…
reyoung Oct 18, 2017
d5d025a
Fix bugs
JiayiFeng Oct 18, 2017
06ea1b7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiayiFeng Oct 18, 2017
7842b2c
Add unittest
reyoung Oct 18, 2017
c36464f
Fix unit test error
JiayiFeng Oct 18, 2017
3ccaf48
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiayiFeng Oct 18, 2017
ab55cbe
Merge branch 'develop' of github.com:baidu/Paddle into feature/copy_c…
reyoung Oct 18, 2017
e61be25
Merge remote-tracking branch 'PR/4894' into feature/copy_ctor_program…
reyoung Oct 18, 2017
36ca498
Remove unused functions
reyoung Oct 18, 2017
c0c5fdd
Add clone for Python Program
reyoung Oct 18, 2017
1c184a9
Merge branch 'feature/copy_ctor_program_des' into feature/py_executor…
reyoung Oct 18, 2017
6dbc038
Working on executor
reyoung Oct 18, 2017
a64031d
Merge branch 'develop' of github.com:baidu/Paddle into feature/copy_c…
reyoung Oct 18, 2017
53c5bcc
Merge branch 'feature/copy_ctor_program_des' into feature/py_executor…
reyoung Oct 18, 2017
81fb44c
Stash
reyoung Oct 18, 2017
20e1297
Add glog as dependencies of ops
reyoung Oct 18, 2017
5f426cf
Expose VarDesc::persistable to Python
reyoung Oct 18, 2017
68cd25b
Merge branch 'feature/add_glog_in_operator' into feature/py_executor_…
reyoung Oct 18, 2017
32342f9
Merge branch 'feature/expose_persistable' into feature/py_executor_test
reyoung Oct 18, 2017
a4868c6
Test executor
reyoung Oct 19, 2017
1fb8e4e
Complete unittest
reyoung Oct 19, 2017
dc94946
Merge branch 'develop' of github.com:baidu/Paddle into feature/py_exe…
reyoung Oct 19, 2017
82e8f65
Polish code
reyoung Oct 19, 2017
1ca9002
Merge branch 'develop' of github.com:baidu/Paddle into feature/py_exe…
reyoung Oct 19, 2017
62cb355
Fix merge error
reyoung Oct 19, 2017
58ccf55
Follow comment
reyoung Oct 19, 2017
9bec178
Polish Python Code
reyoung Oct 19, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)

cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)

cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
Expand Down
8 changes: 6 additions & 2 deletions paddle/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {

for (auto& var : block.vars()) {
if (var.persistable()) {
scope->Var(var.name());
auto* ptr = scope->Var(var.name());
VLOG(3) << "Create Variable " << var.name()
<< " global, which pointer is " << ptr;
} else {
local_scope.Var(var.name());
auto* ptr = local_scope.Var(var.name());
VLOG(3) << "Create Variable " << var.name()
<< " locally, which pointer is " << ptr;
}
}

Expand Down
14 changes: 11 additions & 3 deletions paddle/framework/feed_fetch_method.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/variable.h"

Expand All @@ -24,6 +26,7 @@ void SetFeedVariable(const LoDTensor& input, const std::string& var_name,
size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will
// be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = GetGlobalScope().Var(var_name);
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
Expand All @@ -40,10 +43,15 @@ LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) {
// Since we want to fetch LodTensor from a variable, the variable must
// be created alreadly.
Variable* g_fetch_value = GetGlobalScope().FindVar(var_name);
auto& fetch_outputs =
*(g_fetch_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
"Only %s can be invoked by GetFetchVariable",
typeid(FeedFetchList).name());
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
auto& tensor = fetch_outputs[index];
VLOG(3) << "Fetch " << var_name << " with index " << index
<< " shape= " << tensor.dims();
PADDLE_ENFORCE_LT(index, fetch_outputs.size());
return fetch_outputs[index];
return tensor;
}

} // namespace framework
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ message VarDesc {
enum VarType {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
FEED_MINIBATCH = 3;
FETCH_LIST = 4;
}
required string name = 1;
required VarType type = 2;
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/program_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ TEST(ProgramDesc, copy_ctor) {
// different and it is correct.
}
} // namespace framework
} // namespace paddle
} // namespace paddle
5 changes: 4 additions & 1 deletion paddle/framework/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class Variable {
public:
template <typename T>
const T& Get() const {
PADDLE_ENFORCE(IsType<T>(), "Variable must be type %s", typeid(T).name());
PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
PADDLE_ENFORCE(IsType<T>(),
"Variable must be type %s, the holding type is %s",
typeid(T).name(), holder_->Type().name());
return *static_cast<const T*>(holder_->Ptr());
}

Expand Down
23 changes: 19 additions & 4 deletions paddle/operators/feed_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class FeedOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto feed_var_name = Input("Input");
auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);

PADDLE_ENFORCE(feed_var != nullptr,
"Cannot find feed_var in scope, feed_var_name is %s",
feed_var_name);
Expand All @@ -40,6 +41,9 @@ class FeedOp : public framework::OperatorBase {

auto col = Attr<int>("col");

VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var"
<< out_name;

auto &feed_list = feed_var->Get<framework::FeedFetchList>();
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
Expand All @@ -48,10 +52,21 @@ class FeedOp : public framework::OperatorBase {
}
};

class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpInfoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op");
AddComment("feed op, it should not be configured by users directly");
AddAttr<int>("col", "column of feed");
}
};

} // namespace operators
} // namespace paddle

// We do not need to register OpInfoMaker,
// since feed operator will not be used by end users directly
REGISTER_OPERATOR(feed, paddle::operators::FeedOp,
paddle::framework::EmptyGradOpMaker);
paddle::framework::EmptyGradOpMaker,
paddle::operators::FeedOpInfoMaker);
20 changes: 16 additions & 4 deletions paddle/operators/fetch_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FetchOp : public framework::OperatorBase {

void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto fetch_var_name = Input("Input");
auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr,
"Cannot find fetch variable in scope, fetch_var_name is %s",
Expand All @@ -52,13 +52,25 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
dst_item.CopyFromTensor(src_item, platform::CPUPlace(), dev_ctx);

VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
}
};

class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpInfoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fetch op");
AddOutput("Out", "The output of fetch op");
AddComment("fetch op, it should not be configured by users directly");
AddAttr<int>("col", "column of feed");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

column of feed --> column of fetch

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
};
} // namespace operators
} // namespace paddle

// We do not need to register OpInfoMaker,
// since fetch operator will not be used by end users directly
REGISTER_OPERATOR(fetch, paddle::operators::FetchOp,
paddle::framework::EmptyGradOpMaker);
paddle::framework::EmptyGradOpMaker,
paddle::operators::FetchOpInfoMaker);
4 changes: 3 additions & 1 deletion paddle/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ void BindVarDsec(py::module &m) {

py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR)
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS);
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS)
.value("FEED_MINIBATCH", VarDesc::FEED_MINIBATCH)
.value("FETCH_LIST", VarDesc::FETCH_LIST);
}

void BindOpDesc(py::module &m) {
Expand Down
20 changes: 17 additions & 3 deletions paddle/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ PYBIND11_PLUGIN(core) {
new (&instance) LoDTensor(new_lod);
#endif
})
.def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
.def("set_lod",
[](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
#ifndef PADDLE_WITH_CUDA
Expand Down Expand Up @@ -216,7 +217,8 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>())
.def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
py::return_value_policy::reference)
.def("drop_kids", &Scope::DropKids);
.def("drop_kids", &Scope::DropKids)
.def_static("global_scope", &GetGlobalScope);

//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
Expand Down Expand Up @@ -264,6 +266,17 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>);

py::class_<platform::Place>(m, "Place")
.def(py::init<>())
.def("set_place",
[](platform::Place &self, const platform::CPUPlace &cpu_place) {
self = cpu_place;
})
.def("set_place",
[](platform::Place &self, const platform::GPUPlace &gpu_place) {
self = gpu_place;
});

py::class_<OperatorBase>(m, "Operator")
.def_static("create",
[](py::bytes protobin) {
Expand Down Expand Up @@ -437,14 +450,15 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor")
.def(py::init<std::vector<platform::Place> &>())
.def("run",
[](Executor &self, const ProgramDesc &program_desc, int block_id) {
[](Executor &self, ProgramDescBind *program_bind, int block_id) {
framework::Scope &global_scope = GetGlobalScope();
self.Run(program_desc, &global_scope, block_id);
self.Run(*program_bind->Proto(), &global_scope, block_id);
});

m.def("unique_integer", UniqueIntegerGenerator);

m.def("is_compile_gpu", IsCompileGPU);
//! FIXME: it is no need to `set_xxx_float/double/int`
m.def("set_feed_variable_float", framework::SetFeedVariable<float>);
m.def("set_feed_variable_double", framework::SetFeedVariable<double>);
m.def("set_feed_variable_int", framework::SetFeedVariable<int>);
Expand Down
59 changes: 59 additions & 0 deletions python/paddle/v2/framework/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program


class Executor(object):
def __init__(self, places):
if not isinstance(places, list) and not isinstance(places, tuple):
places = [places]

act_places = []
for each in places:
p = core.Place()
p.set_place(each)
act_places.append(p)

self.executor = core.Executor(act_places)

def run(self,
program,
feed,
fetch_list,
feed_var_name='feed',
fetch_var_name='fetch'):
if not isinstance(program, Program):
raise TypeError()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error message?


program = program.clone()
global_block = program.global_block()
assert isinstance(global_block, Block)
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)

for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
'feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
# FIXME
core.set_feed_variable_float(feed[name], feed_var.name, i)

fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
for i, var in enumerate(fetch_list):
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})

assert isinstance(global_block, Block)
self.executor.run(program.desc, 0)
for i, _ in enumerate(fetch_list):
yield core.get_fetch_variable(fetch_var_name, i)
10 changes: 7 additions & 3 deletions python/paddle/v2/framework/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ def __init__(self,
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)

self.desc.check_attrs()
self.desc.infer_shape(self.block.desc)
if type not in {'feed', 'fetch'}:
self.desc.infer_shape(self.block.desc)
Copy link

@tonyyang-svail tonyyang-svail Oct 19, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't we infer shape a feed/fetch?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InferVarType is also needed.

self.desc.infer_var_type(self.block.desc)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here infer_shape is a compile-time method. Feed/Fetch op is a run-time operator, do not need to infer shape.

Copy link
Collaborator Author

@reyoung reyoung Oct 19, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InferVarType is also needed.

Add in following PRs


def __str__(self):
protostr = self.desc.serialize_to_string()
Expand Down Expand Up @@ -323,9 +324,12 @@ def idx(self):
return self.desc.id

def var(self, name):
if name not in self.vars:
if not isinstance(name, basestring):
raise TypeError()
v = self.vars.get(name, None)
if v is None:
raise ValueError("var %s not in this block" % name)
return self.vars[name]
return v

def all_parameters(self):
return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)}
Expand Down
15 changes: 9 additions & 6 deletions python/paddle/v2/framework/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ def fc(input,
return helper.append_activation(pre_activation)


def data(name,
shape,
data_type='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
program=None):
def data_layer(name,
shape,
data_type='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
append_batch_size=True,
program=None):
helper = LayerHelper('data', **locals())
shape = [-1] + shape # append batch size as -1
if append_batch_size:
shape = [-1] + shape # append batch size as -1
return helper.create_global_variable(
name=name, shape=shape, dtype=data_type, type=type)

Expand Down Expand Up @@ -112,6 +114,7 @@ def func(**kwargs):


_create_op_func_('mean')
_create_op_func_('mul')
_create_op_func_('pool2d')


Expand Down
37 changes: 37 additions & 0 deletions python/paddle/v2/framework/tests/test_executor_and_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest
from paddle.v2.framework.layers import mul, data_layer
import paddle.v2.framework.core as core
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import g_program
import numpy


class TestExecutor(unittest.TestCase):
def test_mul(self):
a = data_layer(name='a', shape=[784], data_type='float32')
b = data_layer(
name='b',
shape=[784, 100],
data_type='float32',
append_batch_size=False)
out = mul(x=a, y=b)
place = core.CPUPlace()
a_np = numpy.random.random((100, 784)).astype('float32')
tensor_a = core.LoDTensor()
tensor_a.set(a_np, place)
b_np = numpy.random.random((784, 100)).astype('float32')
tensor_b = core.LoDTensor()
tensor_b.set(b_np, place)
# del input_tensor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this comment for?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo

exe = Executor(place)
outs = list(
exe.run(g_program,
feed={'a': tensor_a,
'b': tensor_b},
fetch_list=[out]))
out = numpy.array(outs[0])
self.assertEqual((100, 100), out.shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to check the result

self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.



if __name__ == '__main__':
unittest.main()