polish(zym): optimize ppo continuous act #801

Merged 2 commits on Jun 13, 2024
29 changes: 17 additions & 12 deletions ding/model/template/vac.py
@@ -54,12 +54,12 @@ def __init__(
``ReparameterizationHead``, and hybrid heads.
- share_encoder (:obj:`bool`): Whether to share observation encoders between actor and critic.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
the last element must match ``head_hidden_size``.
the last element is used as the input size of ``actor_head`` and ``critic_head``.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
to 64, it must match the last element of ``encoder_hidden_size_list``.
to 64, it is the hidden size of the last layer of the ``actor_head`` network.
- actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
to 64, it must match the last element of ``encoder_hidden_size_list``.
to 64, it is the hidden size of the last layer of the ``critic_head`` network.
- critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network.
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None``, it defaults to ``nn.ReLU()``.
@@ -108,15 +108,13 @@ def new_encoder(outsize, activation):
)

if self.share_encoder:
assert actor_head_hidden_size == critic_head_hidden_size, \
"actor and critic network head should have same size."
if encoder:
if isinstance(encoder, torch.nn.Module):
self.encoder = encoder
else:
raise ValueError("illegal encoder instance.")
else:
self.encoder = new_encoder(actor_head_hidden_size, activation)
self.encoder = new_encoder(encoder_hidden_size_list[-1], activation)
else:
if encoder:
if isinstance(encoder, torch.nn.Module):
@@ -125,25 +123,31 @@ def new_encoder(outsize, activation):
else:
raise ValueError("illegal encoder instance.")
else:
self.actor_encoder = new_encoder(actor_head_hidden_size, activation)
self.critic_encoder = new_encoder(critic_head_hidden_size, activation)
self.actor_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
self.critic_encoder = new_encoder(encoder_hidden_size_list[-1], activation)

# Head Type
self.critic_head = RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
encoder_hidden_size_list[-1],
1,
critic_head_layer_num,
activation=activation,
norm_type=norm_type,
hidden_size=critic_head_hidden_size
)
self.action_space = action_space
assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
if self.action_space == 'continuous':
self.multi_head = False
self.actor_head = ReparameterizationHead(
actor_head_hidden_size,
encoder_hidden_size_list[-1],
action_shape,
actor_head_layer_num,
sigma_type=sigma_type,
activation=activation,
norm_type=norm_type,
bound_type=bound_type
bound_type=bound_type,
hidden_size=actor_head_hidden_size,
)
elif self.action_space == 'discrete':
actor_head_cls = DiscreteHead
@@ -172,14 +176,15 @@ def new_encoder(outsize, activation):
action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
actor_action_args = ReparameterizationHead(
actor_head_hidden_size,
encoder_hidden_size_list[-1],
action_shape.action_args_shape,
actor_head_layer_num,
sigma_type=sigma_type,
fixed_sigma_value=fixed_sigma_value,
activation=activation,
norm_type=norm_type,
bound_type=bound_type,
hidden_size=actor_head_hidden_size,
)
actor_action_type = DiscreteHead(
actor_head_hidden_size,
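
A minimal usage sketch (not part of this PR) of the new head sizing in VAC: encoder_hidden_size_list[-1] is now the input size of both heads, while actor_head_hidden_size / critic_head_hidden_size only set the hidden width inside each head. The import path and the forward-mode name are assumed from DI-engine's existing API; the values mirror the hopper config below.

import torch
import torch.nn as nn
from ding.model.template import VAC  # assumed import path for the model edited above

# encoder_hidden_size_list[-1] feeds both heads; the *_head_hidden_size arguments
# only control the hidden width inside the corresponding head.
model = VAC(
    obs_shape=11,
    action_shape=3,
    action_space='continuous',
    share_encoder=False,
    encoder_hidden_size_list=[128, 128],
    actor_head_hidden_size=128,
    actor_head_layer_num=0,
    critic_head_hidden_size=256,
    critic_head_layer_num=2,
    activation=nn.Tanh(),
    bound_type='tanh',
)
out = model(torch.randn(4, 11), mode='compute_actor_critic')
print(out['logit']['mu'].shape, out['value'].shape)  # mu: (4, 3); value shape depends on the head's squeeze behavior
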
26 changes: 25 additions & 1 deletion ding/policy/ppo.py
@@ -52,6 +52,11 @@ class PPOPolicy(Policy):
batch_size=64,
# (float) The step size of gradient descent.
learning_rate=3e-4,
# (dict or None) The learning rate decay configuration.
# If not None, it should contain the keys 'epoch_num' and 'min_lr_lambda',
# where 'epoch_num' is the number of scheduler steps over which the learning rate decays linearly
# and 'min_lr_lambda' is the final multiplier applied to the initial learning rate.
lr_scheduler=None,
# (float) The loss weight of value network, policy network weight is set to 1.
value_weight=0.5,
# (float) The loss weight of entropy regularization, policy network weight is set to 1.
@@ -169,6 +174,16 @@ def _init_learn(self) -> None:
clip_value=self._cfg.learn.grad_clip_value
)

# Define linear lr scheduler
if self._cfg.learn.lr_scheduler is not None:
epoch_num = self._cfg.learn.lr_scheduler['epoch_num']
min_lr_lambda = self._cfg.learn.lr_scheduler['min_lr_lambda']

self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self._optimizer,
lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
)

self._learn_model = model_wrap(self._model, wrapper_name='base')

# Algorithm config
@@ -314,8 +329,13 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
total_loss.backward()
self._optimizer.step()

if self._cfg.learn.lr_scheduler is not None:
cur_lr = sum(self._lr_scheduler.get_last_lr()) / len(self._lr_scheduler.get_last_lr())
else:
cur_lr = self._optimizer.defaults['lr']

return_info = {
'cur_lr': self._optimizer.defaults['lr'],
'cur_lr': cur_lr,
'total_loss': total_loss.item(),
'policy_loss': ppo_loss.policy_loss.item(),
'value_loss': ppo_loss.value_loss.item(),
@@ -336,6 +356,10 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
}
)
return_infos.append(return_info)

if self._cfg.learn.lr_scheduler is not None:
self._lr_scheduler.step()

return return_infos

def _init_collect(self) -> None:
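
Worth noting: min_lr_lambda is a multiplier on the base learning rate, not an absolute learning rate, so the schedule defined in _init_learn decays linearly from learning_rate to learning_rate * min_lr_lambda over epoch_num scheduler steps. A self-contained sketch (plain PyTorch, dummy parameter, values taken from the configs below; not code from this PR):

import torch

learning_rate, epoch_num, min_lr_lambda = 3e-4, 1500, 0.0
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=learning_rate)
sched = torch.optim.lr_scheduler.LambdaLR(
    opt,
    lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
)
for step in range(epoch_num + 1):
    if step in (0, 750, 1500):
        # linear decay: 3e-4 at step 0, 1.5e-4 halfway, 0.0 from step 1500 onward
        print(step, sched.get_last_lr()[0])
    opt.step()
    sched.step()
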
14 changes: 12 additions & 2 deletions dizoo/mujoco/config/ant_onppo_config.py
@@ -1,4 +1,5 @@
from easydict import EasyDict
import torch.nn as nn

ant_ppo_config = dict(
exp_name="ant_onppo_seed0",
@@ -17,15 +18,24 @@
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
action_space='continuous',
obs_shape=111,
action_shape=8,
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -39,7 +49,7 @@
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
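
Rough arithmetic implied by the updated hyperparameters, as an illustrative sketch only (it assumes _forward_learn, and therefore the scheduler step above, runs once per collection; the variable names are not from the repo):

n_sample, batch_size, epoch_per_collect = 2048, 128, 10
grad_steps_per_collect = epoch_per_collect * (n_sample // batch_size)  # 10 * 16 = 160 minibatch updates per collection
epoch_num = 1500                                                       # scheduler steps before the lr floor is reached
env_steps_until_min_lr = epoch_num * n_sample                          # 3,072,000 environment steps
print(grad_steps_per_collect, env_steps_until_min_lr)
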
22 changes: 16 additions & 6 deletions dizoo/mujoco/config/halfcheetah_onppo_config.py
@@ -1,7 +1,8 @@
from easydict import EasyDict
import torch.nn as nn

collector_env_num = 1
evaluator_env_num = 1
collector_env_num = 8
evaluator_env_num = 8
halfcheetah_ppo_config = dict(
exp_name='halfcheetah_onppo_seed0',
env=dict(
@@ -10,23 +11,32 @@
norm_reward=dict(use_norm=False, ),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=1,
n_evaluator_episode=8,
stop_value=12000,
),
policy=dict(
cuda=True,
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
action_space='continuous',
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
obs_shape=17,
action_shape=6,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -42,12 +52,12 @@
),
collect=dict(
collector_env_num=collector_env_num,
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=500, )),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
halfcheetah_ppo_config = EasyDict(halfcheetah_ppo_config)
18 changes: 14 additions & 4 deletions dizoo/mujoco/config/hopper_onppo_config.py
@@ -1,4 +1,5 @@
from easydict import EasyDict
import torch.nn as nn

hopper_onppo_config = dict(
exp_name='hopper_onppo_seed0',
@@ -12,19 +13,28 @@
stop_value=4000,
),
policy=dict(
cuda=True,
cuda=False,
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
obs_shape=11,
action_shape=3,
action_space='continuous',
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -39,12 +49,12 @@
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=500, )),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
hopper_onppo_config = EasyDict(hopper_onppo_config)
22 changes: 16 additions & 6 deletions dizoo/mujoco/config/walker2d_onppo_config.py
@@ -1,7 +1,8 @@
from easydict import EasyDict
import torch.nn as nn

collector_env_num = 1
evaluator_env_num = 1
collector_env_num = 8
evaluator_env_num = 8
walker2d_onppo_config = dict(
exp_name='walker2d_onppo_seed0',
env=dict(
@@ -10,23 +11,32 @@
norm_reward=dict(use_norm=False, ),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=10,
n_evaluator_episode=8,
stop_value=6000,
),
policy=dict(
cuda=True,
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
action_space='continuous',
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
obs_shape=17,
action_shape=6,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -43,12 +53,12 @@
),
collect=dict(
collector_env_num=collector_env_num,
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=500, )),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
walker2d_onppo_config = EasyDict(walker2d_onppo_config)