
[new-exec] Refine standalone executor #37278

Merged: 5 commits, Nov 17, 2021
Changes from 2 commits
24 changes: 23 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.cc
@@ -77,6 +77,24 @@ paddle::framework::FetchList InterpreterCore::Run(
return *(fetch_var->GetMutable<framework::FetchList>());
}

paddle::framework::FetchList InterpreterCore::Run() {
  if (!is_build_) {
    paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
    paddle::framework::interpreter::build_op_func_list(
        place_, block_, &op_func_nodes, global_scope_);
    is_build_ = true;
    // convert vec func_list to graph
    Convert(&op_func_nodes);
  } else {
    ExecuteInstructionList(vec_instruction_);
  }

  // return Fetch Tensors
  auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
  return *(fetch_var->GetMutable<framework::FetchList>());
}

void InterpreterCore::BuildOperatorDependences() {
  // analyze the dependences between ops, set dependecy_count_, and call
  // Schedule
@@ -94,6 +112,7 @@ void InterpreterCore::BuildOperatorDependences() {
}
}


void InterpreterCore::Convert(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto& vec_meta_info = global_scope_->MutableVecMetaInfo();
@@ -505,6 +524,7 @@ void InterpreterCore::Prepare(
feed_names.size(), feed_tensors.size()));

  auto FeedInput = [&] {
    VLOG(4) << "Feed inputs";
    for (size_t i = 0; i < feed_names.size(); ++i) {
      auto* feed_var = global_scope_->FindVar(feed_names[i]);
      PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
@@ -529,7 +549,9 @@
  // NOTE: Because feed_tensor will be GC'd after
  // paddle::framework::build_op_func_list, we should
  // call FeedInput again.
  if (prepare_feed) FeedInput();
  if (prepare_feed) {
    FeedInput();
  }
}

interpreter::CostInfo InterpreterCore::DryRun(
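The new parameterless `InterpreterCore::Run()` follows a build-once, run-many pattern: the first call builds the variable scope and the op-func list (which executes the ops once as a side effect of `build_op_func_list`), then converts them into the cached instruction list; later calls replay that list directly. A minimal Python sketch of the same pattern — `LazyCore`, `build_and_execute`, and `execute` are hypothetical names, not part of Paddle's API:

```python
class LazyCore:
    """Toy stand-in for InterpreterCore, illustrating lazy build-then-run."""

    def __init__(self, ops):
        self.ops = ops
        self.is_built = False
        self.instructions = None

    def run(self):
        if not self.is_built:
            # First call: building the instruction list also executes each op
            # once, mirroring build_op_func_list's side effect.
            self.instructions = [self.build_and_execute(op) for op in self.ops]
            self.is_built = True
        else:
            # Later calls: replay the cached instruction list directly.
            for instr in self.instructions:
                self.execute(instr)

    def build_and_execute(self, op):
        print(f"build + run {op}")
        return op

    def execute(self, instr):
        print(f"run {instr}")
```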
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/interpretercore.h
@@ -49,6 +49,8 @@ class InterpreterCore {
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);

  paddle::framework::FetchList Run();

  interpreter::CostInfo DryRun(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);
8 changes: 8 additions & 0 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -52,6 +52,14 @@ paddle::framework::FetchList StandaloneExecutor::Run(
return core->Run(feed_names, feed_tensors);
}

paddle::framework::FetchList StandaloneExecutor::Run(
    const std::vector<std::string>& feed_names,
    const std::vector<std::string>& fetch_names) {
  auto core = GetInterpreterCore(feed_names, fetch_names);

  return core->Run();
}

framework::interpreter::CostInfo StandaloneExecutor::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
8 changes: 7 additions & 1 deletion paddle/fluid/framework/new_executor/standalone_executor.h
@@ -40,11 +40,17 @@ class StandaloneExecutor : public ExecutorBase {

~StandaloneExecutor() {}

  virtual paddle::framework::FetchList Run(
  paddle::framework::FetchList Run(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors,
      const std::vector<std::string>& fetch_names);

  // NOTE(zhiqiu): feed_names are only used for caching the interpretercore.
  // fetch_names are used both for caching the interpretercore and for
  // inserting fetch ops; the latter can be moved to the Python side.
  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                   const std::vector<std::string>& fetch_names);

framework::interpreter::CostInfo DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors);
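The NOTE above implies the interpreter core cache is keyed on the feed and fetch names, so repeated runs with the same I/O reuse one core. A hypothetical sketch of such a key — the real key format used by `GetInterpreterCore` is not shown in this diff:

```python
def make_core_cache_key(feed_names, fetch_names):
    # Hypothetical: join the I/O names so that runs with the same feeds and
    # fetches map to the same cached interpreter core.
    return ','.join(feed_names) + '|' + ','.join(fetch_names)

# e.g. make_core_cache_key(['x'], ['fc_0.tmp_1']) -> 'x|fc_0.tmp_1'
```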
10 changes: 10 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -2123,6 +2123,16 @@ All parameter, weight, gradient are variables in Paddle.
             }
             return py::cast(std::move(ret));
           })
      .def("run",
           [](StandaloneExecutor &self, std::vector<std::string> feed_names,
              std::vector<std::string> fetch_names) {
             paddle::framework::FetchList ret;
             {
               pybind11::gil_scoped_release release;
               ret = self.Run(feed_names, fetch_names);
             }
             return py::cast(std::move(ret));
           })
.def("dry_run",
[](StandaloneExecutor &self,
const std::unordered_map<std::string, py::array> &input_dict) {
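Note that the new binding releases the GIL around `Run`, so other Python threads can make progress while the C++ executor runs. A hedged sketch of the call pattern, mirroring `_StandaloneExecutor.run` in python/paddle/fluid/executor.py — `run_core` is a hypothetical helper, and `new_exe` stands for an already-constructed pybind StandaloneExecutor object:

```python
def run_core(new_exe, feed_names, fetch_names):
    # Only variable *names* cross the pybind boundary; the feed tensors must
    # already be in the scope. The GIL is released inside run().
    fetch_list = new_exe.run(feed_names, fetch_names)  # -> FetchList
    return fetch_list._move_to_list()                  # -> list of LoDTensor
```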
58 changes: 40 additions & 18 deletions python/paddle/fluid/executor.py
@@ -493,29 +493,20 @@ def __init__(self, place, main_program, scope):
        self._scope = scope
        self._new_exe = self._create_new_executor()

    def run(self, feed, fetch_list, return_numpy=True):
    def run(self, feed_names, fetch_list, return_numpy=True):
"""
Args:
feed(list|dict): This parameter represents the input Tensors of the model.
If it is single card training, the feed is dict type, and if it is multi-card
training, the parameter feed can be dict or list of Tensors. If the
parameter type is dict, the data in the feed will be split and sent to
multiple devices (CPU/GPU), that is to say, the input data will be evenly
sent to different devices, so you should make sure the number of samples of
the current mini-batch must be greater than the number of places;
if the parameter type is list, those data are copied directly to each device,
so the length of this list should be equal to the number of places.
The default is None.
feed_names(list): This parameter represents the input names of the model.
fetch_list(list): This parameter represents the Tensors that need to be returned
after the model runs. The default is None.
return_numpy(bool): This parameter indicates whether convert the fetched Tensors
(the Tensor specified in the fetch list) to numpy.ndarray. if it is False,
the type of the return value is a list of :code:`LoDTensor`. The default is True.
"""
        feed = self._update_feed(feed)
        # feed = self._update_feed(feed)
Review comment (Contributor): Consider removing the commented-out code in a follow-up.

Reply (Contributor Author): done, thx

        fetch_list = self._check_fetch(fetch_list)

        tensors = self._new_exe.run(feed, fetch_list)._move_to_list()
        tensors = self._new_exe.run(feed_names, fetch_list)._move_to_list()
        if return_numpy:
            return as_numpy(tensors, copy=True)
        else:
@@ -598,9 +589,9 @@ def _get_exe_from_cache(self, program, scope):
        assert isinstance(
            program, Program), "Required type(Program), but received {}".format(
                type(program).__name__)

        if str(program) not in self._cached_executors:
            new_program = program.clone()
            _prune_feed_ops(new_program)
            new_exe = _StandaloneExecutor(self._place, new_program, scope)
            self._cached_executors[str(program)] = new_exe

@@ -744,8 +735,13 @@ def _add_trainer_cache(self, trainer_cache_key, ctx):
    def _add_scope_cache(self, scope_cache_key, scope):
        self.scope_caches[scope_cache_key] = scope

    def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
                            fetch_var_name):
    def _add_feed_fetch_ops(self,
                            program,
                            feed,
                            fetch_list,
                            feed_var_name,
                            fetch_var_name,
                            skip_fetch=False):
        tmp_program = program.clone()

        global_block = tmp_program.global_block()
@@ -780,6 +776,9 @@ def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
            warnings.warn(
                "The variable %s is not found in program. It is not declared or is pruned."
                % name)
        if skip_fetch:
            return tmp_program

        # append fetch_operators
        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
            for i, var in enumerate(fetch_list):
@@ -1325,8 +1324,31 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name,
            program, compiler.CompiledProgram) else program
        assert isinstance(inner_program_, framework.Program)
        if not inner_program_._is_start_up_program_:
            return self._executor_cache.run(inner_program_, scope, feed,
                                            fetch_list, return_numpy)
            if feed is None:
                feed = {}
            program = self._add_feed_fetch_ops(
                program=inner_program_,
                feed=feed,
                fetch_list=fetch_list,
                feed_var_name=feed_var_name,
                fetch_var_name=fetch_var_name,
                skip_fetch=True)
            self._feed_data(program, feed, feed_var_name, scope)
            if hasattr(program, 'lr_sheduler'):
                assert isinstance(program.lr_sheduler,
                                  LRScheduler), "must be LRScheduler"
                lr_sheduler = program.lr_sheduler
                lr_value = lr_sheduler()
                lr_var = program.global_block().vars[lr_sheduler._var_name]
                data = np.array(
                    [lr_value]).astype(convert_dtype(lr_var.dtype))
                tensor = core.get_variable_tensor(scope,
                                                  lr_sheduler._var_name)
                tensor.set(data, self.place)

            return self._executor_cache.run(program, scope,
                                            list(feed.keys()), fetch_list,
                                            return_numpy)

        # use_prune can be overridden by putting optimize_ops in fetch_list
        _origin_fetch_list = fetch_list
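At the user level the new path stays behind the ordinary `Executor.run` API: Python appends feed ops (skipping fetch ops), writes the feed tensors into the scope, updates the learning-rate tensor if needed, and forwards only the feed names to the C++ core. A hedged end-to-end sketch — the `FLAGS_USE_STANDALONE_EXECUTOR` environment flag is an assumption about how the new executor was toggled at the time:

```python
import os
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'  # assumed opt-in flag

import numpy as np
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    out = paddle.static.nn.fc(x, size=2)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)

# Feed tensors are written into the scope on the Python side; only the feed
# *names* and fetch names reach StandaloneExecutor::Run(feed_names, fetch_names).
result, = exe.run(main,
                  feed={'x': np.random.rand(8, 4).astype('float32')},
                  fetch_list=[out])
print(result.shape)  # (8, 2)
```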
@@ -251,6 +251,7 @@ def test_with_error(self):
class TestException(unittest.TestCase):
    def setUp(self):
        self.place = paddle.CPUPlace()
        self.fetch_vars = None

    def build_program(self):
        main_program = paddle.static.Program()
@@ -276,6 +277,7 @@ def _run(self, feeds):
        for feed in feeds:
            out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
        print(main_program)
        self.fetch_vars = fetch_vars
        return out

    def run_new_executor(self, feed):
@@ -317,7 +319,7 @@ def test_scope(self):
        }]
        self.run_new_executor(feed)
        self.assertIsNotNone(paddle.static.global_scope().find_var(
            'embedding.tmp_2'))
            self.fetch_vars.name))


if __name__ == "__main__":