-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/py executor test #4922
Feature/py executor test #4922
Changes from 73 commits
f5d9005
5488ec9
153d9a8
e016726
f6570b5
f7cffb7
e017ba2
1cf33cb
cd93f12
3e613de
3ab53e4
a281c39
03fc36c
32cdc7b
d28c2c7
647e1eb
216979d
50bd700
0e5ba8c
6723273
122bd2a
686ac1e
e18e79d
d87e137
5c778e7
93643c3
9310a3b
d0d1172
8c05974
dad5769
aa00ab8
a307501
df9d100
0f5731a
ac6b11b
072d6d0
f92dc30
71ec313
e6f0924
87938e4
3e1ecf7
9f99377
0af66c5
b337e18
175bfd5
3d684bc
f73b9f2
291e7d2
6a4f6d9
4f6d3c6
60f96d1
d5d025a
06ea1b7
7842b2c
c36464f
3ccaf48
ab55cbe
e61be25
36ca498
c0c5fdd
1c184a9
6dbc038
a64031d
53c5bcc
81fb44c
20e1297
5f426cf
68cd25b
32342f9
a4868c6
1fb8e4e
dc94946
82e8f65
1ca9002
62cb355
58ccf55
9bec178
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import paddle.v2.framework.core as core | ||
from paddle.v2.framework.framework import Block, Program | ||
|
||
|
||
class Executor(object): | ||
def __init__(self, places): | ||
if not isinstance(places, list) and not isinstance(places, tuple): | ||
places = [places] | ||
|
||
act_places = [] | ||
for each in places: | ||
p = core.Place() | ||
p.set_place(each) | ||
act_places.append(p) | ||
|
||
self.executor = core.Executor(act_places) | ||
|
||
def run(self, | ||
program, | ||
feed, | ||
fetch_list, | ||
feed_var_name='feed', | ||
fetch_var_name='fetch'): | ||
if not isinstance(program, Program): | ||
raise TypeError() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error message? |
||
|
||
program = program.clone() | ||
global_block = program.global_block() | ||
assert isinstance(global_block, Block) | ||
feed_var = global_block.create_var( | ||
name=feed_var_name, | ||
type=core.VarDesc.VarType.FEED_MINIBATCH, | ||
persistable=True) | ||
|
||
for i, name in enumerate(feed): | ||
out = global_block.var(name) | ||
global_block.prepend_op( | ||
'feed', | ||
inputs={'X': [feed_var]}, | ||
outputs={'Out': [out]}, | ||
attrs={'col': i}) | ||
# FIXME | ||
core.set_feed_variable_float(feed[name], feed_var.name, i) | ||
|
||
fetch_var = global_block.create_var( | ||
name=fetch_var_name, | ||
type=core.VarDesc.VarType.FETCH_LIST, | ||
persistable=True) | ||
for i, var in enumerate(fetch_list): | ||
global_block.append_op( | ||
type='fetch', | ||
inputs={'X': [var]}, | ||
outputs={'Out': [fetch_var]}, | ||
attrs={'col': i}) | ||
|
||
assert isinstance(global_block, Block) | ||
self.executor.run(program.desc, 0) | ||
for i, _ in enumerate(fetch_list): | ||
yield core.get_fetch_variable(fetch_var_name, i) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -256,7 +256,8 @@ def __init__(self, | |
self.desc.set_block_attr(attr_name, attrs[attr_name].desc) | ||
|
||
self.desc.check_attrs() | ||
self.desc.infer_shape(self.block.desc) | ||
if type not in {'feed', 'fetch'}: | ||
self.desc.infer_shape(self.block.desc) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why can't we infer shape a feed/fetch? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. InferVarType is also needed.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Add in following PRs |
||
|
||
def __str__(self): | ||
protostr = self.desc.serialize_to_string() | ||
|
@@ -323,9 +324,12 @@ def idx(self): | |
return self.desc.id | ||
|
||
def var(self, name): | ||
if name not in self.vars: | ||
if not isinstance(name, basestring): | ||
raise TypeError() | ||
v = self.vars.get(name, None) | ||
if v is None: | ||
raise ValueError("var %s not in this block" % name) | ||
return self.vars[name] | ||
return v | ||
|
||
def all_parameters(self): | ||
return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import unittest | ||
from paddle.v2.framework.layers import mul, data_layer | ||
import paddle.v2.framework.core as core | ||
from paddle.v2.framework.executor import Executor | ||
from paddle.v2.framework.framework import g_program | ||
import numpy | ||
|
||
|
||
class TestExecutor(unittest.TestCase): | ||
def test_mul(self): | ||
a = data_layer(name='a', shape=[784], data_type='float32') | ||
b = data_layer( | ||
name='b', | ||
shape=[784, 100], | ||
data_type='float32', | ||
append_batch_size=False) | ||
out = mul(x=a, y=b) | ||
place = core.CPUPlace() | ||
a_np = numpy.random.random((100, 784)).astype('float32') | ||
tensor_a = core.LoDTensor() | ||
tensor_a.set(a_np, place) | ||
b_np = numpy.random.random((784, 100)).astype('float32') | ||
tensor_b = core.LoDTensor() | ||
tensor_b.set(b_np, place) | ||
# del input_tensor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is this comment for? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo |
||
exe = Executor(place) | ||
outs = list( | ||
exe.run(g_program, | ||
feed={'a': tensor_a, | ||
'b': tensor_b}, | ||
fetch_list=[out])) | ||
out = numpy.array(outs[0]) | ||
self.assertEqual((100, 100), out.shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to check the result
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
column of feed --> column of fetch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.