-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revise python save load api using new load/save op #7995
Conversation
python/paddle/v2/fluid/io.py
Outdated
main_program=None, | ||
vars=None, | ||
predicate=None, | ||
save_file_name='__parameters__'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest setting the default value of save_filename
to None
. If None
, then all variables will be saved into separate files as before. I am not sure if it is suitable to change the storing format of training results. So maybe it is better to enable two formats?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. Will do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/v2/fluid/io.py
Outdated
main_program=None, | ||
vars=None, | ||
predicate=None, | ||
load_file_name='__parameters__'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same as save_vars
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/v2/fluid/io.py
Outdated
vars=None, | ||
predicate=_is_presistable_and_exist_) | ||
parameter_list = get_parameters(inference_program) | ||
save_vars(executor, dirname, inference_program, parameter_list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#7874 is merged. Directly change of the save_inference_model
will fail the CI of the develop branch. Please update the develop branch first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Thanks!
python/paddle/v2/fluid/io.py
Outdated
|
||
for var in program.list_vars(): | ||
if is_persistable(var) and var.name in input_args: | ||
parameter_list.append(var) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is actually not a parameter list but a persistable variable list. Normally, the program should not contain unreferenced variables, so if var.name in input_args
should be removed. When loading, if a persistable variable is absent, there should be some error message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation. Then I guess we don't need to define this function. There is already a save_persistable method, so I will use that one instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I think we still need to exclude 'feed' and 'fetch' variables right (because they have been added to the program desc)? They are also persistable and we don't want to store them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to remove var.name in input_args
, then 'feed' and 'fetch' variables are also included, which incurs an error "The type of var fetch is unsupported" since the type of feed/fetch is vector<lodTensor>
and is not supported by load / save op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this implementation can potentially solve the problem described in PR #8020
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will try including these changes to actually verify.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I think we still need to exclude 'feed' and 'fetch' variables right
@kexinzhao Can we change the function is_persistable()
or rewrite load/save_persistables()
to exclude feed
and fetch
variables?
I think this implementation can potentially solve the problem described in PR #8020
@sidgoyal78 I think the problem in #8020 is that, means
of bn is not parameter but persistable variable. In fact, we should save all persistable variables in save_inference_model
, not only parameters. I think about this for a long time, and there are some issues for this: #7931 #7163
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah. 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we change the function is_persistable() or rewrite load/save_persistables() to exclude feed and fetch variables?
I think of three ways to redefine load/save_persistables():
- One is like this:
def load_persitables(xxx):
parameter_list = get_parameters(program)
load_var(xxx, parameter_list)
which basically moved the usage of get_parameters
inside save_load_persistables
.
You can also get rid of the get_parameters
method, move its code directly into save/load_persistable, but we prefer not to do this because of duplicate code in two functions.
- If we don't want to use the code in
get_parameters()
method, then basically we want to modify theis_persistable(var)
predicate function so that it can exclude 'feed' and 'fetch' vars.
Note that since this is a predicate, we only have the variable as input.
Although as show in framework.py thatclass Variable
has a data memberop
, thisop
will only be set to the operator that output this variable. Meaning that if a variable is not an output of any operator (e.g.,feed
and weights parameters), this var.op == None.
So we cannot use code like below to exclude 'feed' and 'fetch'
def is_persistable(var):
if var.op.desc.type() == 'feed' or var.op.desc.type() == 'fetch':
return false
return var.persistable
- Just like we define the feed/fetch operator type to be fixed as 'feed'/'fetch', we can also fix the name of the feed/fetch variable to be 'feed'/'fetch' (or some better names). This means that we can get rid of the API in the current Inference Design that optionally allows user to provide its own 'feed_holder_name' and 'fetch_holder_name'. For this design, we can simply modify the
is_persistable
as follows:
def is_persistable(var):
if var.desc.name() == 'feed' or var.desc.name() == 'fetch':
return false
return var.persistable
If we want to go with this option, we can firstly do a quick fix in this pr using the code above. Then fix the feed/fetch var name, modify API accordingly, set some global const kFeedVarName in C++ and pybind it to python, etc in the future PR.
@Xreki @luotao1 @sidgoyal78, which option do your prefer or do you have other suggestions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like the 2nd method. But I am not sure whether it is suitable to use var.op.desc.type()
. In the C++ definition of VarDesc
, there is not a member to record the belonged op. Also, a variable may be shared among multiple op.
For the 3rd method. I think it is not suitable to use the name, but may be we can use the type, which should be FEED_MINIBATCH
.
python/paddle/v2/fluid/io.py
Outdated
load_var_map = {} | ||
for each_var in vars: | ||
assert isinstance(each_var, Variable) | ||
new_var = _clone_var_in_block_(load_block, each_var) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the common codes line 202 - 204
out of the if
statement?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/v2/fluid/io.py
Outdated
|
||
for var in program.list_vars(): | ||
if is_persistable(var) and var.name in input_args: | ||
parameter_list.append(var) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I think we still need to exclude 'feed' and 'fetch' variables right
@kexinzhao Can we change the function is_persistable()
or rewrite load/save_persistables()
to exclude feed
and fetch
variables?
I think this implementation can potentially solve the problem described in PR #8020
@sidgoyal78 I think the problem in #8020 is that, means
of bn is not parameter but persistable variable. In fact, we should save all persistable variables in save_inference_model
, not only parameters. I think about this for a long time, and there are some issues for this: #7931 #7163
python/paddle/v2/fluid/io.py
Outdated
|
||
load_vars( | ||
parameter_list = get_parameters(inference_program) | ||
save_vars( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can try to call save_persistables
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/v2/fluid/io.py
Outdated
@@ -342,7 +410,13 @@ def load_inference_model(dirname, executor): | |||
program_desc_str = f.read() | |||
|
|||
program = Program.parse_from_string(program_desc_str) | |||
load_persistables_if_exist(executor, dirname, program) | |||
parameter_list = get_parameters(program) | |||
load_vars( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can also try to call load_persistables
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -46,6 +46,9 @@ def is_parameter(var): | |||
|
|||
|
|||
def is_persistable(var): | |||
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ | |||
var.desc.type() == core.VarDesc.VarType.FETCH_LIST: | |||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Xreki I have changed the code and go with option 3 using your suggestion. For option 2, there is problem. Because in the python side of the code, the operator op
field of var will only be associated with the operator that have this variable as its output. So for feed variable, since it is not the output of any operator. Its op
data member will be None
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. But I am thinking about the what the interface load_inference_model
should be. I mean, the argument list.
fix #7959
op's info is not well synchronized on the python side. So I have to use the OpDesc info. input_arg_names() is defined in protobuf.cc as a binding to InputArguments() method in the c++ OpDesc class.