-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify fluid api for fit a line #10301
Merged
daming-lu
merged 39 commits into
PaddlePaddle:develop
from
daming-lu:simplify_fluid_api
May 15, 2018
Merged
Changes from 7 commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
fd36e61
use the style here to refactor code:
daming-lu 5484022
Merge branch 'develop' into simplify_fluid_api
daming-lu a065d11
add notest_fit_a_line following the new pattern here:
daming-lu 527c1c7
newline
daming-lu 80eada2
Merge develop
daming-lu ea754d1
Merge remote-tracking branch 'upstream/develop' into simplify_fluid_api
daming-lu d6c8f4c
use new trainer.train() and old infer, but not working
daming-lu 00f68c7
fix code style
daming-lu 67584c6
some fixes
JiayiFeng 2d1341e
some fixes
JiayiFeng 3c4a2e9
Merge pull request #1 from JiayiFeng/for_daming
daming-lu b11cd59
fix
daming-lu 11b58e0
Merge branch 'simplify_fluid_api' of https://github.com/daming-lu/Pad…
daming-lu 26860b4
remove notest_ prefix
daming-lu 8bdea24
add prefix as removing it will make TeamCity fail
daming-lu 8b07b01
Merge remote-tracking branch 'upstream/develop' into simplify_fluid_api
daming-lu 791765a
remove notest_ prefix. In this case, the test should pass.
daming-lu 78bbf31
using both train() and test()
daming-lu 7f47ecd
Merge branch 'develop' into simplify_fluid_api
daming-lu 89ed947
Merge branch 'develop' into simplify_fluid_api
daming-lu 3e6510a
follow new pattern
daming-lu e229d81
fix style
daming-lu d6b32c5
add inference_program
daming-lu 0beac7c
Merge branch 'develop' into simplify_fluid_api
daming-lu 835b5d9
rm comment
daming-lu 0c89ca7
style
daming-lu 19bd2b8
fluid data
daming-lu 305f482
paddle -> fluid
daming-lu 841e256
layers
daming-lu 2c073b4
act=None
daming-lu 54e1274
add inferencer.infer()
daming-lu 8ac261e
style
daming-lu 2e22982
Merge branch 'develop' into simplify_fluid_api
daming-lu 4b8537b
save inference model in the new way
daming-lu 757952b
style
daming-lu 4c221c9
add PR link in comment for ref
daming-lu fab651c
remove unused is_local
daming-lu 410ec9a
added tony-yang's fix
daming-lu 35c78af
no need to change old book test
daming-lu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
135 changes: 135 additions & 0 deletions
135
python/paddle/fluid/tests/book/fit_a_line/notest_fit_a_line.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.fluid as fluid | ||
import contextlib | ||
import numpy | ||
import unittest | ||
|
||
|
||
# train reader | ||
BATCH_SIZE = 20 | ||
|
||
train_reader = paddle.batch( | ||
paddle.reader.shuffle( | ||
paddle.dataset.uci_housing.train(), buf_size=500), | ||
batch_size=BATCH_SIZE) | ||
|
||
|
||
# train | ||
def linear(): | ||
x = fluid.layers.data(name='x', shape=[13], dtype='float32') | ||
y = fluid.layers.data(name='y', shape=[1], dtype='float32') | ||
y_predict = fluid.layers.fc(input=x, size=1, act=None) | ||
|
||
loss = fluid.layers.cross_entropy(y_predict, y) | ||
avg_loss = fluid.layers.mean(loss) | ||
accuracy = fluid.layers.accuracy(input=y_predict, label=y) | ||
|
||
return avg_loss | ||
|
||
|
||
def train(use_cuda, save_dirname, is_local): | ||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
|
||
trainer = fluid.Trainer( | ||
linear, | ||
optimizer=fluid.optimizer.SGD(learning_rate=0.001), | ||
place=place | ||
) | ||
|
||
def event_handler(event): | ||
if isinstance(event, fluid.EndStepEvent): | ||
print event.metrics | ||
|
||
elif isinstance(event, fluid.EndEpochEvent): | ||
test_metrics = trainer.test(reader=paddle.dataset.uci_housing.test()) | ||
print test_metrics | ||
|
||
if test_metrics[0] < 10.0: | ||
if save_dirname is not None: | ||
# fluid.io.save_inference_model(save_dirname, ['x'], [y_predict]) | ||
trainer.save_params(save_dirname) | ||
return | ||
|
||
trainer.train( | ||
reader=train_reader, | ||
num_epochs=100, | ||
event_handler=event_handler, | ||
feed_order={'x': 0, 'y': 1}) | ||
|
||
|
||
# infer | ||
def infer(use_cuda, save_dirname=None): | ||
if save_dirname is None: | ||
return | ||
|
||
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() | ||
exe = fluid.Executor(place) | ||
|
||
inference_scope = fluid.core.Scope() | ||
with fluid.scope_guard(inference_scope): | ||
# Use fluid.io.load_inference_model to obtain the inference program desc, | ||
# the feed_target_names (the names of variables that will be feeded | ||
# data using feed operators), and the fetch_targets (variables that | ||
# we want to obtain data from using fetch operators). | ||
[inference_program, feed_target_names, | ||
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) | ||
|
||
# The input's dimension should be 2-D and the second dim is 13 | ||
# The input data should be >= 0 | ||
batch_size = 10 | ||
tensor_x = numpy.random.uniform(0, 10, | ||
[batch_size, 13]).astype("float32") | ||
assert feed_target_names[0] == 'x' | ||
results = exe.run(inference_program, | ||
feed={feed_target_names[0]: tensor_x}, | ||
fetch_list=fetch_targets) | ||
print("infer shape: ", results[0].shape) | ||
print("infer results: ", results[0]) | ||
|
||
|
||
def main(use_cuda, is_local=True): | ||
if use_cuda and not fluid.core.is_compiled_with_cuda(): | ||
return | ||
|
||
# Directory for saving the trained model | ||
save_dirname = "fit_a_line.inference.model" | ||
|
||
train(use_cuda, save_dirname, is_local) | ||
infer(use_cuda, save_dirname) | ||
|
||
|
||
class TestFitALine(unittest.TestCase): | ||
def test_cpu(self): | ||
with self.program_scope_guard(): | ||
main(use_cuda=False) | ||
|
||
def test_cuda(self): | ||
with self.program_scope_guard(): | ||
main(use_cuda=True) | ||
|
||
@contextlib.contextmanager | ||
def program_scope_guard(self): | ||
prog = fluid.Program() | ||
startup_prog = fluid.Program() | ||
scope = fluid.core.Scope() | ||
with fluid.scope_guard(scope): | ||
with fluid.program_guard(prog, startup_prog): | ||
yield | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the purpose of this API Simplification effort is to hide details like Executor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. The
infer
is in old format for a reason. #10426I will change it to new API once it is implemented. The current PR only changed
trainer.train()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@daming-lu : You could take a look at Jeff's PR: https://github.com/PaddlePaddle/Paddle/pull/10308/files#diff-07faf3656ed1409403fcb2bb7e00f455R93
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sidgoyal78 Thanks Sid. I kept the old infer() on purpose. I want the test to be OK with new trainer.train and old infer.