-
Notifications
You must be signed in to change notification settings - Fork 373
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(nyz): add policy gradient algo implementation (#544)
* feature(nyz): add policy gradient algo implementation * demo(nyz): add lunarlander pg demo * style(nyz): add pg link in readme * fix(nyz): fix config conflict with data generation * fix(nyz): fix action space error type in model
- Loading branch information
Showing
19 changed files
with
530 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \ | ||
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, head_cls_map | ||
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, head_cls_map, \ | ||
independent_normal_dist | ||
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder | ||
from .utils import create_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from typing import Union, Optional, Dict, Callable, List | ||
import torch | ||
import torch.nn as nn | ||
from easydict import EasyDict | ||
|
||
from ding.torch_utils import get_lstm | ||
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze | ||
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \ | ||
MultiHead, RegressionHead, ReparameterizationHead, independent_normal_dist | ||
|
||
|
||
@MODEL_REGISTRY.register('pg') | ||
class PG(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
obs_shape: Union[int, SequenceType], | ||
action_shape: Union[int, SequenceType], | ||
action_space: str = 'discrete', | ||
encoder_hidden_size_list: SequenceType = [128, 128, 64], | ||
head_hidden_size: Optional[int] = None, | ||
head_layer_num: int = 1, | ||
activation: Optional[nn.Module] = nn.ReLU(), | ||
norm_type: Optional[str] = None | ||
) -> None: | ||
super(PG, self).__init__() | ||
# For compatibility: 1, (1, ), [4, 32, 32] | ||
obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape) | ||
if head_hidden_size is None: | ||
head_hidden_size = encoder_hidden_size_list[-1] | ||
# FC Encoder | ||
if isinstance(obs_shape, int) or len(obs_shape) == 1: | ||
self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type) | ||
# Conv Encoder | ||
elif len(obs_shape) == 3: | ||
self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type) | ||
else: | ||
raise RuntimeError( | ||
"not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape) | ||
) | ||
self.action_space = action_space | ||
# Head | ||
if self.action_space == 'discrete': | ||
self.head = DiscreteHead( | ||
head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type | ||
) | ||
elif self.action_space == 'continuous': | ||
self.head = ReparameterizationHead( | ||
head_hidden_size, | ||
action_shape, | ||
head_layer_num, | ||
activation=activation, | ||
norm_type=norm_type, | ||
sigma_type='independent' | ||
) | ||
else: | ||
raise KeyError("not support action space: {}".format(self.action_space)) | ||
|
||
def forward(self, x: torch.Tensor) -> Dict: | ||
x = self.encoder(x) | ||
x = self.head(x) | ||
if self.action_space == 'discrete': | ||
x['dist'] = torch.distributions.Categorical(logits=x['logit']) | ||
elif self.action_space == 'continuous': | ||
x = {'logit': {'mu': x['mu'], 'sigma': x['sigma']}} | ||
x['dist'] = independent_normal_dist(x['logit']) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
import numpy as np | ||
import pytest | ||
from itertools import product | ||
|
||
from ding.model.template import PG | ||
from ding.torch_utils import is_differentiable | ||
from ding.utils import squeeze | ||
|
||
B = 4 | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestDiscretePG: | ||
|
||
def output_check(self, model, outputs): | ||
if isinstance(outputs, torch.Tensor): | ||
loss = outputs.sum() | ||
elif isinstance(outputs, list): | ||
loss = sum([t.sum() for t in outputs]) | ||
elif isinstance(outputs, dict): | ||
loss = sum([v.sum() for v in outputs.values()]) | ||
is_differentiable(loss, model) | ||
|
||
def test_discrete_pg(self): | ||
obs_shape = (4, 84, 84) | ||
action_shape = 5 | ||
model = PG( | ||
obs_shape, | ||
action_shape, | ||
) | ||
inputs = torch.randn(B, 4, 84, 84) | ||
|
||
outputs = model(inputs) | ||
assert isinstance(outputs, dict) | ||
assert outputs['logit'].shape == (B, action_shape) | ||
assert outputs['dist'].sample().shape == (B, ) | ||
self.output_check(model, outputs['logit']) | ||
|
||
def test_continuous_pg(self): | ||
N = 32 | ||
action_shape = (6, ) | ||
inputs = {'obs': torch.randn(B, N), 'action': torch.randn(B, squeeze(action_shape))} | ||
model = PG( | ||
obs_shape=(N, ), | ||
action_shape=action_shape, | ||
action_space='continuous', | ||
) | ||
# compute_action | ||
print(model) | ||
outputs = model(inputs['obs']) | ||
assert isinstance(outputs, dict) | ||
dist = outputs['dist'] | ||
action = dist.sample() | ||
assert action.shape == (B, *action_shape) | ||
|
||
logit = outputs['logit'] | ||
mu, sigma = logit['mu'], logit['sigma'] | ||
assert mu.shape == (B, *action_shape) | ||
assert sigma.shape == (B, *action_shape) | ||
is_differentiable(mu.sum() + sigma.sum(), model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.