From cdcc7c87ca66962438761777adce6f03ec17675b Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Tue, 10 Apr 2018 16:45:58 -0700 Subject: [PATCH 1/7] add float16 support to save op --- paddle/fluid/operators/save_load_op_test.cc | 32 +++++++++++++++++++++ paddle/fluid/operators/save_op.cc | 22 +++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/save_load_op_test.cc b/paddle/fluid/operators/save_load_op_test.cc index a7ba1e0ae1d22..aff4dd00322ec 100644 --- a/paddle/fluid/operators/save_load_op_test.cc +++ b/paddle/fluid/operators/save_load_op_test.cc @@ -61,3 +61,35 @@ TEST(SaveLoadOp, CPU) { } } } + +TEST(SaveLoadFP16Op, CPU) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + + auto var = scope.Var("test_var"); + auto tensor = var->GetMutable(); + tensor->Resize({3, 10}); + + float* expect = tensor->mutable_data(place); + for (int64_t i = 0; i < tensor->numel(); ++i) { + expect[i] = static_cast(i); + } + + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", std::string("tensor.save")}); + attrs.insert({"save_as_fp16_dtype", true}); + + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save", {{"X", {"test_var"}}}, {}, attrs); + save_op->Run(scope, place); + + auto load_var = scope.Var("out_var"); + auto target = load_var->GetMutable(); + auto load_op = paddle::framework::OpRegistry::CreateOp( + "load", {}, {{"Out", {"out_var"}}}, attrs); + load_op->Run(scope, place); + paddle::platform::float16* actual = target->data(); + for (int64_t i = 0; i < tensor->numel(); ++i) { + EXPECT_EQ(expect[i], static_cast(actual[i])); + } +} diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 4a715c4baab2d..a2506fc17978c 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -68,6 +69,7 @@ class SaveOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); + auto save_as_fp16 = Attr("save_as_fp16_dtype"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -96,7 +98,19 @@ class SaveOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::SerializeToStream(fout, tensor, dev_ctx); + auto in_dtype = framework::ToDataType(tensor.type()); + auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + + if (in_dtype != out_dtype) { + auto in_kernel_type = framework::OpKernelType(in_dtype, place); + auto out_kernel_type = framework::OpKernelType(out_dtype, place); + framework::LoDTensor out; + framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); + std::cout << "var " << iname << " is converted to float16" << std::endl; + framework::SerializeToStream(fout, out, dev_ctx); + } else { + framework::SerializeToStream(fout, tensor, dev_ctx); + } } }; @@ -114,6 +128,12 @@ This operator will serialize and write a tensor variable to file on disk. "(boolean, default true)" "Overwrite the output file if exist") .SetDefault(true); + AddAttr("save_as_fp16_dtype", + "(boolean, default false)" + "If true, the tensor will be converted to float16 data " + "type and then saved. Otherwise, the tensor will be " + "directly saved without data type conversion.") + .SetDefault(false); AddAttr("file_path", "(string)" "The \"file_path\" where the variable will be saved.") From 033a61ca15e0da806462d10c4c6451e6c0cc78d0 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Wed, 11 Apr 2018 16:33:42 -0700 Subject: [PATCH 2/7] add float16 image class example --- paddle/fluid/operators/save_load_op_test.cc | 2 +- paddle/fluid/operators/save_op.cc | 4 +- python/paddle/fluid/framework.py | 35 ++- python/paddle/fluid/io.py | 72 ++++-- python/paddle/fluid/tests/CMakeLists.txt | 1 + .../fluid/tests/book_float16/.gitignore | 1 + .../fluid/tests/book_float16/CMakeLists.txt | 7 + .../test_float16_image_classification.py | 238 ++++++++++++++++++ 8 files changed, 332 insertions(+), 28 deletions(-) create mode 100644 python/paddle/fluid/tests/book_float16/.gitignore create mode 100644 python/paddle/fluid/tests/book_float16/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/book_float16/test_float16_image_classification.py diff --git a/paddle/fluid/operators/save_load_op_test.cc b/paddle/fluid/operators/save_load_op_test.cc index aff4dd00322ec..0cfb7fb730587 100644 --- a/paddle/fluid/operators/save_load_op_test.cc +++ b/paddle/fluid/operators/save_load_op_test.cc @@ -77,7 +77,7 @@ TEST(SaveLoadFP16Op, CPU) { paddle::framework::AttributeMap attrs; attrs.insert({"file_path", std::string("tensor.save")}); - attrs.insert({"save_as_fp16_dtype", true}); + attrs.insert({"save_as_fp16", true}); auto save_op = paddle::framework::OpRegistry::CreateOp( "save", {{"X", {"test_var"}}}, {}, attrs); diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index a2506fc17978c..45cbf26bdd77e 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -69,7 +69,7 @@ class SaveOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); - auto save_as_fp16 = Attr("save_as_fp16_dtype"); + auto save_as_fp16 = Attr("save_as_fp16"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -128,7 +128,7 @@ This operator will serialize and write a tensor variable to file on disk. "(boolean, default true)" "Overwrite the output file if exist") .SetDefault(true); - AddAttr("save_as_fp16_dtype", + AddAttr("save_as_fp16", "(boolean, default false)" "If true, the tensor will be converted to float16 data " "type and then saved. Otherwise, the tensor will be " diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 33cf6918178ff..7622c4947c6d2 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -23,16 +23,9 @@ import unique_name __all__ = [ - 'Block', - 'Variable', - 'Program', - 'Operator', - 'default_startup_program', - 'default_main_program', - 'program_guard', - 'switch_startup_program', - 'switch_main_program', - 'get_var', + 'Block', 'Variable', 'Program', 'Operator', 'default_startup_program', + 'default_main_program', 'program_guard', 'switch_startup_program', + 'switch_main_program', 'get_var', 'np_dtype_to_fluid_dtype' ] EMPTY_VAR_NAME = core.kEmptyVarName() @@ -41,6 +34,28 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix() +def np_dtype_to_fluid_dtype(input): + """Change the dtype of float16 numpy array + + numpy float16 is binded to paddle::platform::float16 + in tensor_py.h via the help of uint16 data type since + the internal memory representation of float16 is + uint16_t in paddle and np.uint16 in numpy, which are + themselves binded together by pybind. + + Args: + input: input numpy array + + Returns: + input: The dtype of input will be changed to np.uint16 if + it is originally np.float16, such that the internal memory + of input will be reinterpreted as of dtype np.uint16. + """ + if input.dtype == np.float16: + input.dtype = np.uint16 + return input + + def grad_var_name(var_name): """ return gradient name for a certain var name diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 1c0f1f6eb415b..4dcbf4bf7f02e 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -63,12 +63,25 @@ def _clone_var_in_block_(block, var): persistable=True) +def _get_no_coversion_var_names_(program): + op_names = {'batch_norm'} + var_names = set() + for block in program.blocks: + for op in block.ops: + if op.type in op_names: + input_names = op.input_arg_names + for in_name in input_names: + var_names.add(in_name) + return var_names + + def save_vars(executor, dirname, main_program=None, vars=None, predicate=None, - filename=None): + filename=None, + use_float16=False): """ Save variables to directory by executor. @@ -85,33 +98,46 @@ def save_vars(executor, :return: None """ - if vars is None: - if main_program is None: - main_program = default_main_program() - if not isinstance(main_program, Program): - raise TypeError("program should be as Program type or None") + if main_program is None: + main_program = default_main_program() + if not isinstance(main_program, Program): + raise TypeError("program should be as Program type or None") + if vars is None: save_vars( executor, dirname=dirname, vars=filter(predicate, main_program.list_vars()), - filename=filename) + filename=filename, + use_float16=use_float16) else: save_program = Program() save_block = save_program.global_block() - save_var_map = {} + + # Get the names of variables that shouldn't be converted to float16 in + # float16 saving mode, right now it is limited to batch norm input weights. + no_conversion_var_names = _get_no_coversion_var_names_(main_program) + print no_conversion_var_names for each_var in vars: # NOTE: don't save the variable which type is RAW if each_var.type == core.VarDesc.VarType.RAW: continue + new_var = _clone_var_in_block_(save_block, each_var) + # Determine if a variable needed to be converted to float16 before saving + save_as_fp16 = use_float16 and new_var.name not in no_conversion_var_names + print new_var.name, use_float16, save_as_fp16 + if filename is None: save_block.append_op( type='save', inputs={'X': [new_var]}, outputs={}, - attrs={'file_path': os.path.join(dirname, new_var.name)}) + attrs={ + 'file_path': os.path.join(dirname, new_var.name), + 'save_as_fp16': save_as_fp16 + }) else: save_var_map[new_var.name] = new_var @@ -129,7 +155,11 @@ def save_vars(executor, executor.run(save_program) -def save_params(executor, dirname, main_program=None, filename=None): +def save_params(executor, + dirname, + main_program=None, + filename=None, + use_float16=False): """ Save all parameters to directory with executor. """ @@ -139,10 +169,15 @@ def save_params(executor, dirname, main_program=None, filename=None): main_program=main_program, vars=None, predicate=is_parameter, - filename=filename) + filename=filename, + use_float16=use_float16) -def save_persistables(executor, dirname, main_program=None, filename=None): +def save_persistables(executor, + dirname, + main_program=None, + filename=None, + use_float16=False): """ Save all persistables to directory with executor. """ @@ -152,7 +187,8 @@ def save_persistables(executor, dirname, main_program=None, filename=None): main_program=main_program, vars=None, predicate=is_persistable, - filename=filename) + filename=filename, + use_float16=use_float16) def load_vars(executor, @@ -301,7 +337,8 @@ def save_inference_model(dirname, executor, main_program=None, model_filename=None, - params_filename=None): + params_filename=None, + use_float16=False): """ Build a model especially for inference, and save it to directory by the executor. @@ -359,7 +396,12 @@ def save_inference_model(dirname, with open(model_filename, "wb") as f: f.write(inference_program.desc.serialize_to_string()) - save_persistables(executor, dirname, inference_program, params_filename) + save_persistables( + executor, + dirname, + inference_program, + params_filename, + use_float16=use_float16) def get_feed_targets_names(program): diff --git a/python/paddle/fluid/tests/CMakeLists.txt b/python/paddle/fluid/tests/CMakeLists.txt index d24417bbacb50..6a80da41b24ff 100644 --- a/python/paddle/fluid/tests/CMakeLists.txt +++ b/python/paddle/fluid/tests/CMakeLists.txt @@ -8,3 +8,4 @@ endforeach() add_subdirectory(unittests) add_subdirectory(book) add_subdirectory(book_memory_optimization) +add_subdirectory(book_float16) diff --git a/python/paddle/fluid/tests/book_float16/.gitignore b/python/paddle/fluid/tests/book_float16/.gitignore new file mode 100644 index 0000000000000..dd28d354f4160 --- /dev/null +++ b/python/paddle/fluid/tests/book_float16/.gitignore @@ -0,0 +1 @@ +*.inference.model diff --git a/python/paddle/fluid/tests/book_float16/CMakeLists.txt b/python/paddle/fluid/tests/book_float16/CMakeLists.txt new file mode 100644 index 0000000000000..673c965b662a0 --- /dev/null +++ b/python/paddle/fluid/tests/book_float16/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +# default test +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py b/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py new file mode 100644 index 0000000000000..2d4344144068d --- /dev/null +++ b/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py @@ -0,0 +1,238 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import contextlib +import math +import sys +import numpy as np +import unittest +import os + + +def resnet_cifar10(input, depth=32): + def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=False) + return fluid.layers.batch_norm(input=tmp, act=act) + + def shortcut(input, ch_in, ch_out, stride): + if ch_in != ch_out: + return conv_bn_layer(input, ch_out, 1, stride, 0, None) + else: + return input + + def basicblock(input, ch_in, ch_out, stride): + tmp = conv_bn_layer(input, ch_out, 3, stride, 1) + tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None) + short = shortcut(input, ch_in, ch_out, stride) + return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') + + def layer_warp(block_func, input, ch_in, ch_out, count, stride): + tmp = block_func(input, ch_in, ch_out, stride) + for i in range(1, count): + tmp = block_func(tmp, ch_out, ch_out, 1) + return tmp + + assert (depth - 2) % 6 == 0 + n = (depth - 2) / 6 + conv1 = conv_bn_layer( + input=input, ch_out=16, filter_size=3, stride=1, padding=1) + res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) + res2 = layer_warp(basicblock, res1, 16, 32, n, 2) + res3 = layer_warp(basicblock, res2, 32, 64, n, 2) + pool = fluid.layers.pool2d( + input=res3, pool_size=8, pool_type='avg', pool_stride=1) + return pool + + +def vgg16_bn_drop(input): + def conv_block(input, num_filter, groups, dropouts): + return fluid.nets.img_conv_group( + input=input, + pool_size=2, + pool_stride=2, + conv_num_filter=[num_filter] * groups, + conv_filter_size=3, + conv_act='relu', + conv_with_batchnorm=True, + conv_batchnorm_drop_rate=dropouts, + pool_type='max') + + conv1 = conv_block(input, 64, 2, [0.3, 0]) + conv2 = conv_block(conv1, 128, 2, [0.4, 0]) + conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0]) + conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) + conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) + + drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5) + fc1 = fluid.layers.fc(input=drop, size=4096, act=None) + bn = fluid.layers.batch_norm(input=fc1, act='relu') + drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5) + fc2 = fluid.layers.fc(input=drop2, size=4096, act=None) + return fc2 + + +def train(net_type, save_dirname): + classdim = 10 + data_shape = [3, 32, 32] + + images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + if net_type == "vgg": + print("train vgg net") + net = vgg16_bn_drop(images) + elif net_type == "resnet": + print("train resnet") + net = resnet_cifar10(images, 32) + else: + raise ValueError("%s network is not supported" % net_type) + + predict = fluid.layers.fc(input=net, size=classdim, act='softmax') + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(cost) + acc = fluid.layers.accuracy(input=predict, label=label) + + # Test program + test_program = fluid.default_main_program().clone(for_test=True) + + optimizer = fluid.optimizer.Adam(learning_rate=0.001) + optimize_ops, params_grads = optimizer.minimize(avg_cost) + + BATCH_SIZE = 128 + PASS_NUM = 1 + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.cifar.train10(), buf_size=128 * 10), + batch_size=BATCH_SIZE) + + test_reader = paddle.batch( + paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) + + main_program = fluid.default_main_program() + exe.run(fluid.default_startup_program()) + loss = 0.0 + for pass_id in range(PASS_NUM): + for batch_id, data in enumerate(train_reader()): + exe.run(main_program, feed=feeder.feed(data)) + + if (batch_id % 10) == 0: + acc_list = [] + avg_loss_list = [] + for tid, test_data in enumerate(test_reader()): + loss_t, acc_t = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[avg_cost, acc]) + if math.isnan(float(loss_t)): + sys.exit("got NaN loss, training failed.") + acc_list.append(float(acc_t)) + avg_loss_list.append(float(loss_t)) + break # Use 1 segment for speeding up CI + + acc_value = np.array(acc_list).mean() + avg_loss_value = np.array(avg_loss_list).mean() + + print( + 'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'. + format(pass_id, batch_id + 1, + float(avg_loss_value), float(acc_value))) + + if acc_value > 0.01: # Low threshold for speeding up CI + fluid.io.save_inference_model( + save_dirname, ["pixel"], [predict], + exe, + use_float16=True) + return + + +def infer(save_dirname): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + # The input's dimension of conv should be 4-D or 5-D. + # Use normilized image pixels as input data, which should be in the range [0, 1.0]. + batch_size = 1 + # The input is of numpy float16 data type + tensor_img = np.random.rand(batch_size, 3, 32, 32).astype(np.float16) + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + # Use np_dtype_to_fluid_dtype to change the data type of tensor_img so that + # it can be successfully binded to fluid float16 data type, which can invoke + # the float16 inference mode to run. + results = exe.run(inference_program, + feed={ + feed_target_names[0]: + fluid.np_dtype_to_fluid_dtype(tensor_img) + }, + fetch_list=fetch_targets) + print("infer results: ", results[0]) + + +def main(net_type): + # float16 mode is currently supported only on cuda GPU + if not fluid.core.is_compiled_with_cuda(): + return + + # Directory for saving the trained model + save_dirname = "float16_image_classification_" + net_type + ".inference.model" + + train(net_type, save_dirname) + infer(save_dirname) + + +class TestFP16ImageClassification(unittest.TestCase): + def test_vgg(self): + with self.scope_prog_guard(): + main('vgg') + + def test_resnet(self): + with self.scope_prog_guard(): + main('resnet') + + @contextlib.contextmanager + def scope_prog_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main() From 049ba79eb057ed0a2c9778bb4ed76b8a8d2e4302 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Wed, 11 Apr 2018 18:27:51 -0700 Subject: [PATCH 3/7] add float16 inference test --- paddle/fluid/inference/CMakeLists.txt | 1 + .../tests/book_float16/CMakeLists.txt | 27 +++++++++ ..._inference_float16_image_classification.cc | 58 +++++++++++++++++++ paddle/fluid/inference/tests/test_helper.h | 4 +- .../test_float16_image_classification.py | 18 +++--- 5 files changed, 99 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/inference/tests/book_float16/CMakeLists.txt create mode 100644 paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index f417f62f3f753..050d444bbd0db 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -21,4 +21,5 @@ endif() if(WITH_TESTING) add_subdirectory(tests/book) + add_subdirectory(tests/book_float16) endif() diff --git a/paddle/fluid/inference/tests/book_float16/CMakeLists.txt b/paddle/fluid/inference/tests/book_float16/CMakeLists.txt new file mode 100644 index 0000000000000..e9d22d8b4c958 --- /dev/null +++ b/paddle/fluid/inference/tests/book_float16/CMakeLists.txt @@ -0,0 +1,27 @@ +function(inference_test TARGET_NAME) + set(options "") + set(oneValueArgs "") + set(multiValueArgs ARGS) + cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) + set(arg_list "") + if(inference_test_ARGS) + foreach(arg ${inference_test_ARGS}) + list(APPEND arg_list "_${arg}") + endforeach() + else() + list(APPEND arg_list "_") + endif() + foreach(arg ${arg_list}) + string(REGEX REPLACE "^_$" "" arg "${arg}") + cc_test(test_inference_${TARGET_NAME}${arg} + SRCS test_inference_${TARGET_NAME}.cc + DEPS ARCHIVE_START paddle_fluid ARCHIVE_END + ARGS --dirname=${PYTHON_TESTS_DIR}/book_float16/${TARGET_NAME}${arg}.inference.model) + set_tests_properties(test_inference_${TARGET_NAME}${arg} + PROPERTIES DEPENDS test_${TARGET_NAME}) + endforeach() +endfunction(inference_test) + +inference_test(float16_image_classification ARGS vgg resnet) diff --git a/paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc b/paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc new file mode 100644 index 0000000000000..1a9e4f9a3da70 --- /dev/null +++ b/paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "gflags/gflags.h" +#include "gtest/gtest.h" +#include "paddle/fluid/inference/tests/test_helper.h" +#include "paddle/fluid/platform/float16.h" + +DEFINE_string(dirname, "", "Directory of the inference model."); +DEFINE_int32(batch_size, 1, "Batch size of input data"); +DEFINE_int32(repeat, 1, "Running the inference program repeat times"); + +TEST(inference, float16_image_classification) { +// float16 inference is currently only supported on CUDA GPU +#ifdef PADDLE_WITH_CUDA + using float16 = paddle::platform::float16; + + if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model " + "--batch_size=1 --repeat=1"; + } + + LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + std::string dirname = FLAGS_dirname; + + // 0. Call `paddle::framework::InitDevices()` initialize all the devices + // In unittests, this is done in paddle/testing/paddle_gtest_main.cc + + paddle::framework::LoDTensor input; + // Use normilized image pixels as input data, + // which should be in the range [0.0, 1.0]. + SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32}, + static_cast(0), static_cast(1)); + std::vector cpu_feeds; + cpu_feeds.push_back(&input); + + paddle::framework::LoDTensor output; + std::vector cpu_fetchs; + cpu_fetchs.push_back(&output); + + // Run inference on CUDA GPU + LOG(INFO) << "--- GPU Runs: ---"; + TestInference(dirname, cpu_feeds, + cpu_fetchs, FLAGS_repeat); + LOG(INFO) << output.dims(); +#endif +} diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 064e400f0c750..ffb1282fb660a 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -31,7 +31,9 @@ void SetupTensor(paddle::framework::LoDTensor* input, T* input_ptr = input->mutable_data(dims, paddle::platform::CPUPlace()); for (int i = 0; i < input->numel(); ++i) { - input_ptr[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); + input_ptr[i] = + static_cast(uniform_dist(rng) * static_cast(upper - lower) + + static_cast(lower)); } } diff --git a/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py b/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py index 2d4344144068d..b8f250a7f5b27 100644 --- a/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py +++ b/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py @@ -131,7 +131,9 @@ def train(net_type, save_dirname): test_reader = paddle.batch( paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) - place = fluid.CUDAPlace(0) + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) exe = fluid.Executor(place) feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) @@ -172,6 +174,10 @@ def train(net_type, save_dirname): def infer(save_dirname): + # float16 inference is currently only supported on CUDA GPU + if not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) exe = fluid.Executor(place) @@ -191,9 +197,9 @@ def infer(save_dirname): tensor_img = np.random.rand(batch_size, 3, 32, 32).astype(np.float16) # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. - # Use np_dtype_to_fluid_dtype to change the data type of tensor_img so that - # it can be successfully binded to fluid float16 data type, which can invoke - # the float16 inference mode to run. + # Use np_dtype_to_fluid_dtype to bind tensor_img of numpy float16 data type + # with fluid float16 data type so that it will invoke the inference to run + # in float16 mode. results = exe.run(inference_program, feed={ feed_target_names[0]: @@ -204,10 +210,6 @@ def infer(save_dirname): def main(net_type): - # float16 mode is currently supported only on cuda GPU - if not fluid.core.is_compiled_with_cuda(): - return - # Directory for saving the trained model save_dirname = "float16_image_classification_" + net_type + ".inference.model" From bbc60360d651a1947bc98d12dbfe49c9266aa8ca Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Wed, 11 Apr 2018 19:32:30 -0700 Subject: [PATCH 4/7] code clean up --- paddle/fluid/operators/save_op.cc | 1 - python/paddle/fluid/io.py | 23 +++++++++++++++---- .../test_float16_image_classification.py | 4 ++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 45cbf26bdd77e..f45d07ed90d52 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -106,7 +106,6 @@ class SaveOp : public framework::OperatorBase { auto out_kernel_type = framework::OpKernelType(out_dtype, place); framework::LoDTensor out; framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); - std::cout << "var " << iname << " is converted to float16" << std::endl; framework::SerializeToStream(fout, out, dev_ctx); } else { framework::SerializeToStream(fout, tensor, dev_ctx); diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 4dcbf4bf7f02e..87898aa190c5b 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -63,7 +63,22 @@ def _clone_var_in_block_(block, var): persistable=True) -def _get_no_coversion_var_names_(program): +def _get_no_fp16_coversion_var_names_(program): + """ + Get the set of input variable names that shouldn't be converted to float16. + + When we want to save the trained parameters for float16 inference, most + parameters need to be firstly converted to float16 and then saved by the + save op. However, there are some parameters that shouldn't be converted to + float16 because the corresponding operator requires float32 parameters even + in float16 mode (when the input data is of float16 data type). Currently, + the only operator that has this exclusion is the batch norm op. + + :param program: program to get the variable names + :type program: Program + :return: set of input variable names + :type var_names: set + """ op_names = {'batch_norm'} var_names = set() for block in program.blocks: @@ -117,8 +132,9 @@ def save_vars(executor, # Get the names of variables that shouldn't be converted to float16 in # float16 saving mode, right now it is limited to batch norm input weights. - no_conversion_var_names = _get_no_coversion_var_names_(main_program) - print no_conversion_var_names + no_conversion_var_names = _get_no_fp16_coversion_var_names_( + main_program) + for each_var in vars: # NOTE: don't save the variable which type is RAW if each_var.type == core.VarDesc.VarType.RAW: @@ -127,7 +143,6 @@ def save_vars(executor, new_var = _clone_var_in_block_(save_block, each_var) # Determine if a variable needed to be converted to float16 before saving save_as_fp16 = use_float16 and new_var.name not in no_conversion_var_names - print new_var.name, use_float16, save_as_fp16 if filename is None: save_block.append_op( diff --git a/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py b/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py index b8f250a7f5b27..f41a5a281ea2c 100644 --- a/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py +++ b/python/paddle/fluid/tests/book_float16/test_float16_image_classification.py @@ -198,8 +198,8 @@ def infer(save_dirname): # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. # Use np_dtype_to_fluid_dtype to bind tensor_img of numpy float16 data type - # with fluid float16 data type so that it will invoke the inference to run - # in float16 mode. + # with fluid float16 data type so that it will invoke the inference engine + # to run in float16 mode. results = exe.run(inference_program, feed={ feed_target_names[0]: From ac10b025f1615b6882cb6c99c3e8f39e26e613a3 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 12 Apr 2018 15:10:23 -0700 Subject: [PATCH 5/7] simplify float16 inference code --- paddle/fluid/inference/CMakeLists.txt | 1 - .../fluid/inference/tests/book/CMakeLists.txt | 29 +++++++++- .../test_inference_image_classification.cc | 33 +++++++---- .../tests/book_float16/CMakeLists.txt | 27 --------- ..._inference_float16_image_classification.cc | 58 ------------------- 5 files changed, 51 insertions(+), 97 deletions(-) delete mode 100644 paddle/fluid/inference/tests/book_float16/CMakeLists.txt delete mode 100644 paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 050d444bbd0db..f417f62f3f753 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -21,5 +21,4 @@ endif() if(WITH_TESTING) add_subdirectory(tests/book) - add_subdirectory(tests/book_float16) endif() diff --git a/paddle/fluid/inference/tests/book/CMakeLists.txt b/paddle/fluid/inference/tests/book/CMakeLists.txt index 6ed77adb9d891..cbb78a3509e13 100644 --- a/paddle/fluid/inference/tests/book/CMakeLists.txt +++ b/paddle/fluid/inference/tests/book/CMakeLists.txt @@ -18,14 +18,41 @@ function(inference_test TARGET_NAME) cc_test(test_inference_${TARGET_NAME}${arg} SRCS test_inference_${TARGET_NAME}.cc DEPS ARCHIVE_START paddle_fluid ARCHIVE_END - ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.inference.model) + ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.inference.model --use_float16=false) set_tests_properties(test_inference_${TARGET_NAME}${arg} PROPERTIES DEPENDS test_${TARGET_NAME}) endforeach() endfunction(inference_test) +function(inference_float16_test TARGET_NAME) + set(options "") + set(oneValueArgs "") + set(multiValueArgs ARGS) + cmake_parse_arguments(inference_float16_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) + set(arg_list "") + if(inference_float16_test_ARGS) + foreach(arg ${inference_float16_test_ARGS}) + list(APPEND arg_list "_${arg}") + endforeach() + else() + list(APPEND arg_list "_") + endif() + foreach(arg ${arg_list}) + string(REGEX REPLACE "^_$" "" arg "${arg}") + cc_test(test_inference_float16_${TARGET_NAME}${arg} + SRCS test_inference_${TARGET_NAME}.cc + DEPS ARCHIVE_START paddle_fluid ARCHIVE_END + ARGS --dirname=${PYTHON_TESTS_DIR}/book_float16/float16_${TARGET_NAME}${arg}.inference.model --use_float16=true) + set_tests_properties(test_inference_float16_${TARGET_NAME}${arg} + PROPERTIES DEPENDS test_float16_${TARGET_NAME}) + endforeach() +endfunction(inference_float16_test) + inference_test(fit_a_line) inference_test(image_classification ARGS vgg resnet) +inference_float16_test(image_classification ARGS vgg resnet) inference_test(label_semantic_roles) inference_test(recognize_digits ARGS mlp conv) inference_test(recommender_system) diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index ca2077d07411d..164096960a6a2 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -19,8 +19,11 @@ limitations under the License. */ DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_int32(batch_size, 1, "Batch size of input data"); DEFINE_int32(repeat, 1, "Running the inference program repeat times"); +DEFINE_bool(use_float16, false, "Running inference in float16 mode or not"); TEST(inference, image_classification) { + using float16 = paddle::platform::float16; + if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model " "--batch_size=1 --repeat=1"; @@ -35,20 +38,28 @@ TEST(inference, image_classification) { paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [0.0, 1.0]. - SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32}, - static_cast(0), static_cast(1)); + if (!FLAGS_use_float16) { + SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32}, + static_cast(0), static_cast(1)); + } else { + SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32}, + static_cast(0), static_cast(1)); + } std::vector cpu_feeds; cpu_feeds.push_back(&input); + // float16 inference is currently not supported on CPU paddle::framework::LoDTensor output1; - std::vector cpu_fetchs1; - cpu_fetchs1.push_back(&output1); + if (!FLAGS_use_float16) { + std::vector cpu_fetchs1; + cpu_fetchs1.push_back(&output1); - // Run inference on CPU - LOG(INFO) << "--- CPU Runs: ---"; - TestInference(dirname, cpu_feeds, - cpu_fetchs1, FLAGS_repeat); - LOG(INFO) << output1.dims(); + // Run inference on CPU + LOG(INFO) << "--- CPU Runs: ---"; + TestInference(dirname, cpu_feeds, + cpu_fetchs1, FLAGS_repeat); + LOG(INFO) << output1.dims(); + } #ifdef PADDLE_WITH_CUDA paddle::framework::LoDTensor output2; @@ -61,6 +72,8 @@ TEST(inference, image_classification) { cpu_fetchs2, FLAGS_repeat); LOG(INFO) << output2.dims(); - CheckError(output1, output2); + if (!FLAGS_use_float16) { + CheckError(output1, output2); + } #endif } diff --git a/paddle/fluid/inference/tests/book_float16/CMakeLists.txt b/paddle/fluid/inference/tests/book_float16/CMakeLists.txt deleted file mode 100644 index e9d22d8b4c958..0000000000000 --- a/paddle/fluid/inference/tests/book_float16/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -function(inference_test TARGET_NAME) - set(options "") - set(oneValueArgs "") - set(multiValueArgs ARGS) - cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) - set(arg_list "") - if(inference_test_ARGS) - foreach(arg ${inference_test_ARGS}) - list(APPEND arg_list "_${arg}") - endforeach() - else() - list(APPEND arg_list "_") - endif() - foreach(arg ${arg_list}) - string(REGEX REPLACE "^_$" "" arg "${arg}") - cc_test(test_inference_${TARGET_NAME}${arg} - SRCS test_inference_${TARGET_NAME}.cc - DEPS ARCHIVE_START paddle_fluid ARCHIVE_END - ARGS --dirname=${PYTHON_TESTS_DIR}/book_float16/${TARGET_NAME}${arg}.inference.model) - set_tests_properties(test_inference_${TARGET_NAME}${arg} - PROPERTIES DEPENDS test_${TARGET_NAME}) - endforeach() -endfunction(inference_test) - -inference_test(float16_image_classification ARGS vgg resnet) diff --git a/paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc b/paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc deleted file mode 100644 index 1a9e4f9a3da70..0000000000000 --- a/paddle/fluid/inference/tests/book_float16/test_inference_float16_image_classification.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "gflags/gflags.h" -#include "gtest/gtest.h" -#include "paddle/fluid/inference/tests/test_helper.h" -#include "paddle/fluid/platform/float16.h" - -DEFINE_string(dirname, "", "Directory of the inference model."); -DEFINE_int32(batch_size, 1, "Batch size of input data"); -DEFINE_int32(repeat, 1, "Running the inference program repeat times"); - -TEST(inference, float16_image_classification) { -// float16 inference is currently only supported on CUDA GPU -#ifdef PADDLE_WITH_CUDA - using float16 = paddle::platform::float16; - - if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model " - "--batch_size=1 --repeat=1"; - } - - LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; - std::string dirname = FLAGS_dirname; - - // 0. Call `paddle::framework::InitDevices()` initialize all the devices - // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - - paddle::framework::LoDTensor input; - // Use normilized image pixels as input data, - // which should be in the range [0.0, 1.0]. - SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32}, - static_cast(0), static_cast(1)); - std::vector cpu_feeds; - cpu_feeds.push_back(&input); - - paddle::framework::LoDTensor output; - std::vector cpu_fetchs; - cpu_fetchs.push_back(&output); - - // Run inference on CUDA GPU - LOG(INFO) << "--- GPU Runs: ---"; - TestInference(dirname, cpu_feeds, - cpu_fetchs, FLAGS_repeat); - LOG(INFO) << output.dims(); -#endif -} From 7b540af516df4515cddb2b01c293152689aee882 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 12 Apr 2018 17:19:52 -0700 Subject: [PATCH 6/7] modify cmake inference_test function --- .../fluid/inference/tests/book/CMakeLists.txt | 45 +++++++------------ 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/inference/tests/book/CMakeLists.txt b/paddle/fluid/inference/tests/book/CMakeLists.txt index cbb78a3509e13..c684971b6f223 100644 --- a/paddle/fluid/inference/tests/book/CMakeLists.txt +++ b/paddle/fluid/inference/tests/book/CMakeLists.txt @@ -13,46 +13,33 @@ function(inference_test TARGET_NAME) else() list(APPEND arg_list "_") endif() + + set(use_float16 "") + if(${TARGET_NAME} MATCHES "^float16") + if(${TARGET_NAME} MATCHES "image_classification") + set(use_float16 "--use_float16=true") + endif() + set(book_dir "book_float16") + else() + set(book_dir "book") + endif() + set(SOURCE_NAME "") + string(REGEX REPLACE "^float16_" "" SOURCE_NAME "${TARGET_NAME}") + foreach(arg ${arg_list}) string(REGEX REPLACE "^_$" "" arg "${arg}") cc_test(test_inference_${TARGET_NAME}${arg} - SRCS test_inference_${TARGET_NAME}.cc + SRCS test_inference_${SOURCE_NAME}.cc DEPS ARCHIVE_START paddle_fluid ARCHIVE_END - ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.inference.model --use_float16=false) + ARGS --dirname=${PYTHON_TESTS_DIR}/${book_dir}/${TARGET_NAME}${arg}.inference.model ${use_float16}) set_tests_properties(test_inference_${TARGET_NAME}${arg} PROPERTIES DEPENDS test_${TARGET_NAME}) endforeach() endfunction(inference_test) -function(inference_float16_test TARGET_NAME) - set(options "") - set(oneValueArgs "") - set(multiValueArgs ARGS) - cmake_parse_arguments(inference_float16_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) - set(arg_list "") - if(inference_float16_test_ARGS) - foreach(arg ${inference_float16_test_ARGS}) - list(APPEND arg_list "_${arg}") - endforeach() - else() - list(APPEND arg_list "_") - endif() - foreach(arg ${arg_list}) - string(REGEX REPLACE "^_$" "" arg "${arg}") - cc_test(test_inference_float16_${TARGET_NAME}${arg} - SRCS test_inference_${TARGET_NAME}.cc - DEPS ARCHIVE_START paddle_fluid ARCHIVE_END - ARGS --dirname=${PYTHON_TESTS_DIR}/book_float16/float16_${TARGET_NAME}${arg}.inference.model --use_float16=true) - set_tests_properties(test_inference_float16_${TARGET_NAME}${arg} - PROPERTIES DEPENDS test_float16_${TARGET_NAME}) - endforeach() -endfunction(inference_float16_test) - inference_test(fit_a_line) inference_test(image_classification ARGS vgg resnet) -inference_float16_test(image_classification ARGS vgg resnet) +inference_test(float16_image_classification ARGS vgg resnet) inference_test(label_semantic_roles) inference_test(recognize_digits ARGS mlp conv) inference_test(recommender_system) From db5d0a63ce3586649deb413d1542c32752719738 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sat, 14 Apr 2018 14:49:38 -0700 Subject: [PATCH 7/7] fix const issue --- paddle/fluid/framework/executor.cc | 18 ++++++++++-------- paddle/fluid/framework/executor.h | 11 ++++++----- paddle/fluid/pybind/pybind.cc | 6 +++--- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index a5af25368bb60..6b1fb36010999 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -94,7 +94,7 @@ static void CheckTensorNANOrInf(const std::string& name, } void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, - int block_id) { + int block_id) const { auto& global_block = pdesc.Block(block_id); const Scope* ancestor_scope = scope; @@ -131,7 +131,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, } void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, - bool create_local_scope, bool create_vars) { + bool create_local_scope, bool create_vars) const { platform::RecordBlock b(block_id); auto ctx = Prepare(pdesc, block_id); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); @@ -229,7 +229,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, const std::map& feed_targets, const std::map& fetch_targets, bool create_vars, const std::string& feed_holder_name, - const std::string& fetch_holder_name) { + const std::string& fetch_holder_name) const { platform::RecordBlock b(kProgramId); bool has_feed_ops = has_feed_operators(program.Block(0), feed_targets, feed_holder_name); @@ -321,7 +321,8 @@ std::vector> Executor::Prepare( } void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, - bool create_local_scope, bool create_vars) { + bool create_local_scope, + bool create_vars) const { Scope* local_scope = scope; if (create_vars) { if (create_local_scope) { @@ -363,7 +364,8 @@ void Executor::RunPreparedContext( ExecutorPrepareContext* ctx, Scope* scope, const std::map& feed_targets, const std::map& fetch_targets, bool create_vars, - const std::string& feed_holder_name, const std::string& fetch_holder_name) { + const std::string& feed_holder_name, + const std::string& fetch_holder_name) const { auto& global_block = ctx->prog_.Block(ctx->block_id_); PADDLE_ENFORCE( @@ -378,8 +380,8 @@ void Executor::RunPreparedContext( if (op->Type() == kFeedOpType) { std::string feed_target_name = op->Output("Out")[0]; int idx = boost::get(op->GetAttr("col")); - SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, - idx); + SetFeedVariable(scope, *feed_targets.at(feed_target_name), + feed_holder_name, idx); } } @@ -390,7 +392,7 @@ void Executor::RunPreparedContext( if (op->Type() == kFetchOpType) { std::string fetch_target_name = op->Input("X")[0]; int idx = boost::get(op->GetAttr("col")); - *fetch_targets[fetch_target_name] = + *fetch_targets.at(fetch_target_name) = GetFetchVariable(*scope, fetch_holder_name, idx); } } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 921131c196b29..d70306b0cc0b1 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -52,14 +52,14 @@ class Executor { * Scope */ void Run(const ProgramDesc& prog, Scope* scope, int block_id, - bool create_local_scope = true, bool create_vars = true); + bool create_local_scope = true, bool create_vars = true) const; void Run(const ProgramDesc& program, Scope* scope, const std::map& feed_targets, const std::map& fetch_targets, bool create_vars = true, const std::string& feed_holder_name = "feed", - const std::string& fetch_holder_name = "fetch"); + const std::string& fetch_holder_name = "fetch") const; static std::unique_ptr Prepare( const ProgramDesc& program, int block_id); @@ -67,18 +67,19 @@ class Executor { static std::vector> Prepare( const ProgramDesc& program, const std::vector& block_ids); - void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); + void CreateVariables(const ProgramDesc& pdesc, Scope* scope, + int block_id) const; void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, - bool create_vars = true); + bool create_vars = true) const; void RunPreparedContext( ExecutorPrepareContext* ctx, Scope* scope, const std::map& feed_targets, const std::map& fetch_targets, bool create_vars = true, const std::string& feed_holder_name = "feed", - const std::string& fetch_holder_name = "fetch"); + const std::string& fetch_holder_name = "fetch") const; private: const platform::Place place_; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a1e8ff6399f08..869b577bfde9b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -409,9 +409,9 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Executor") .def(py::init()) - .def("run", - (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) & - Executor::Run); + .def("run", (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, + bool) const) & + Executor::Run); m.def("init_gflags", framework::InitGflags); m.def("init_glog", framework::InitGLOG);