# Simplify mxnet.gluon Block APIs (apache#18413)
## Motivations

Currently the implementation of mxnet.gluon.Block is not very pythonic and contains many redundancies.

### 1. Overlap between Block._params and Block._reg_params

To define a custom model today, we need code like the following:

```
class Net(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            self.hidden1 = nn.Dense(256, activation='relu')
            self.a = self.params.get('a', shape=(1,))
```

This form of registration has several shortcomings:

a. Adding parameter `a` records it in both `self._params` and `self._reg_params`, which is redundant. There is also a discrepancy inside Block:
   i. the method `collect_params` uses `_params` to gather all parameters,
   ii. while the method `_collect_params_with_prefix` (and therefore `load_parameters`) uses `_reg_params`.

b. If we do not wrap children blocks in `with self.name_scope():`, they end up in the wrong name scope. In the following example, we actually cannot get the parameters of `self.hidden1` from the result of `collect_params`:

```
class HybridNet(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(HybridNet, self).__init__(**kwargs)
        self.hidden1 = nn.Dense(256, activation='relu')
        with self.name_scope():
            self.hidden2 = nn.Dense(10, activation='relu')

    def hybrid_forward(self, F, x):
        x = self.hidden2(self.hidden1(x))
        return x

>>> net = HybridNet()
>>> net.initialize()
>>> print(net.collect_params())
hybridnet0_ (
  Parameter dense0_weight (shape=(256, -1), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter hybridnet0_dense0_weight (shape=(10, -1), dtype=float32)
  Parameter hybridnet0_dense0_bias (shape=(10,), dtype=float32)
)
```

The example above also shows that parameter names are unrelated to the attribute names, which is not straightforward.

In short, using name_scope and ParameterDict is not user-friendly. We therefore plan to remove such redundancies and simplify the definition of children blocks and parameters, like:

```
class Net(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        self.hidden1 = nn.Dense(256, activation='relu')
        self.a = gluon.parameter.Parameter(name="a", shape=(1,))
```

### 2. Parameter sharing

Currently, we pass the argument `params` to the constructor of Block for parameter sharing, which means shared parameters are already recorded in `self._params.shared` before Block's `__init__` runs. Block also forbids overriding parameter attributes. We think this is not convenient. The most common way to share a parameter is what PyTorch does:

```
self.hidden1.weight = self.hidden2.weight
```

Note, however, that if a HybridBlock has already been hybridized, we should not allow overriding its parameters, but instead ask the user to un-hybridize the block first.

To further allow sharing parameters recursively, we plan to add an API (see the sketch after the list below):

```
def share_parameters(self, params : Dict):
```

We plan to use the structure-based form (as used in `_collect_params_with_prefix()`) to refer to each parameter recursively; for example, we denote `self.hidden1.weight` as `hidden1.weight`.

In all, we plan to make the following improvements:
1. remove the parameters `prefix` and `params` from the `__init__` function;
2. remove the use of `self._params` (ParameterDict) in Block;
3. allow parameter attribute overriding in the non-hybridized case;
4. add the method `share_parameters` to recursively share parameters in children blocks.
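To make the proposed workflow concrete, here is a minimal sketch of how parameter sharing might look under the new API. It assumes the simplified registration described above and the proposed `share_parameters` signature; the class `Net`, its layers, and the dictionary key are illustrative only, and the final implementation may differ.

```
from mxnet import gluon
from mxnet.gluon import nn

class Net(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        # children blocks are registered by plain attribute assignment,
        # without name_scope()
        self.hidden1 = nn.Dense(256, activation='relu')
        self.hidden2 = nn.Dense(256, activation='relu')
        self.output = nn.Dense(10)

    def hybrid_forward(self, F, x):
        return self.output(self.hidden2(self.hidden1(x)))

net = Net()

# 1. PyTorch-style sharing by overriding the attribute
#    (only allowed while the block has not been hybridized)
net.hidden2.weight = net.hidden1.weight

# 2. recursive sharing via the proposed share_parameters(),
#    keyed by structural names such as 'hidden1.weight' (hypothetical usage)
net2 = Net()
net2.share_parameters({'hidden1.weight': net.hidden1.weight})
```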
## Parameter naming

Once a parameter is created, `param.name` does not change in subsequent operations. It has the form `param_{uuid4}_{name}`, where `name` comes from the `__init__` argument. `name` is optional and defaults to `weight`; it is mainly used to indicate which default initialization should be applied. We use `param.name` as the name of the parameter's symbol representation.

## collect_params()

It returns a `dict` whose keys are the structural names of the parameters, like `{'hidden1.weight': Parameter (shape=(3, -1), dtype=float32), 'hidden1.bias': Parameter (shape=(3,), dtype=float32)}`. Note that we can use `.` as the linking character again because the structure-based naming scheme is no longer used in the symbol representation.

## Save and Load

For `HybridBlock`, there are two ways to save and load parameters (see the sketch at the end of this description):

### save_parameters() and load_parameters()

`save_parameters()` saves parameters under their structural names, and they should be loaded with `load_parameters()`, which loads parameters based on the model's structure.

### HybridBlock.export and SymbolBlock.imports

`export` saves parameters under `param.name` only, without structural names. The resulting param file should be loaded with `SymbolBlock.imports`.

## SymbolBlock

When using `SymbolBlock.imports`, the keys in `self.params` are the loaded parameters' names (`param.name`). With `SymbolBlock(outputs, inputs, params=None)`, if you pass something like `params=net.collect_params()`, the keys in `self.params` are the structural names of `net`'s parameters (the keys of `net.collect_params()`); this is typically used when a `SymbolBlock` is a child block of another `HybridBlock`. Otherwise, the keys in `self.params` are the loaded parameters' names (`param.name`).
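Below is a minimal sketch of the two save/load paths and the SymbolBlock import described above. It assumes a hybridizable `Net` like the one sketched earlier; the file names, the input shape, and the `'data'` input name are placeholders, and the calls follow the existing Gluon APIs rather than a confirmed final form.

```
import mxnet as mx
from mxnet import gluon

net = Net()
net.initialize()

# Path 1: structural-name based save/load, tied to the Python class definition.
net.save_parameters('net.params')   # keys are structural names, e.g. 'hidden1.weight'
net2 = Net()                        # same architecture
net2.load_parameters('net.params')

# Path 2: export after hybridization; the param file is keyed by param.name.
net.hybridize()
net(mx.np.ones((1, 20)))            # run once so the graph can be traced
net.export('net')                   # writes net-symbol.json and net-0000.params

# Re-import for deployment; keys in deployed.params are the param.name values.
deployed = gluon.SymbolBlock.imports('net-symbol.json', ['data'], 'net-0000.params')
```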
Showing 54 changed files with 1,746 additions and 2,482 deletions.