polish(rjy): modify network para
nighood committed Sep 7, 2022
1 parent eb0c321 commit df351a2
Showing 3 changed files with 63 additions and 31 deletions.
16 changes: 11 additions & 5 deletions ding/model/common/head.py
@@ -907,12 +907,13 @@ class RegressionHead(nn.Module):
 
     def __init__(
         self,
-        hidden_size: int,
+        input_size: int,
         output_size: int,
         layer_num: int = 2,
         final_tanh: Optional[bool] = False,
         activation: Optional[nn.Module] = nn.ReLU(),
-        norm_type: Optional[str] = None
+        norm_type: Optional[str] = None,
+        hidden_size: int = None,
     ) -> None:
         """
         Overview:
@@ -928,7 +929,9 @@ def __init__(
                 for more details. Default ``None``.
         """
         super(RegressionHead, self).__init__()
-        self.main = MLP(hidden_size, hidden_size, hidden_size, layer_num, activation=activation, norm_type=norm_type)
+        if hidden_size is None:
+            hidden_size = input_size
+        self.main = MLP(input_size, hidden_size, hidden_size, layer_num, activation=activation, norm_type=norm_type)
         self.last = nn.Linear(hidden_size, output_size)  # for convenience of special initialization
         self.final_tanh = final_tanh
         if self.final_tanh:
@@ -977,14 +980,15 @@ class ReparameterizationHead(nn.Module):
 
     def __init__(
         self,
-        hidden_size: int,
+        input_size: int,
         output_size: int,
         layer_num: int = 2,
         sigma_type: Optional[str] = None,
         fixed_sigma_value: Optional[float] = 1.0,
         activation: Optional[nn.Module] = nn.ReLU(),
         norm_type: Optional[str] = None,
         bound_type: Optional[str] = None,
+        hidden_size: int = None
     ) -> None:
         """
         Overview:
@@ -1005,6 +1009,8 @@ def __init__(
                 Default is ``None``.
         """
         super(ReparameterizationHead, self).__init__()
+        if hidden_size is None:
+            hidden_size = input_size
         self.sigma_type = sigma_type
         assert sigma_type in self.default_sigma_type, "Please indicate sigma_type as one of {}".format(
             self.default_sigma_type
@@ -1013,7 +1019,7 @@ def __init__(
         assert bound_type in self.default_bound_type, "Please indicate bound_type as one of {}".format(
             self.default_bound_type
         )
-        self.main = MLP(hidden_size, hidden_size, hidden_size, layer_num, activation=activation, norm_type=norm_type)
+        self.main = MLP(input_size, hidden_size, hidden_size, layer_num, activation=activation, norm_type=norm_type)
         self.mu = nn.Linear(hidden_size, output_size)
         if self.sigma_type == 'fixed':
             self.sigma = torch.full((1, output_size), fixed_sigma_value)
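A minimal usage sketch of the reworked head signatures (not part of the diff): `input_size` is now the first positional argument, and `hidden_size` becomes an optional keyword that falls back to `input_size` when omitted. The batch and feature sizes below are illustrative only, and the snippet assumes the heads keep returning the `'pred'` and `'mu'`/`'sigma'` dicts defined elsewhere in head.py.

import torch
from ding.model.common.head import RegressionHead, ReparameterizationHead

feat = torch.randn(4, 64)  # stand-in for a batch of 4 encoder features

# hidden_size omitted -> the inner MLP width falls back to input_size (64)
reg = RegressionHead(64, 1, layer_num=2)
print(reg(feat)['pred'].shape)  # torch.Size([4, 1])

# hidden_size decoupled from input_size: 64-dim input, 256-wide MLP inside
rep = ReparameterizationHead(64, 6, layer_num=2, sigma_type='conditioned', hidden_size=256)
out = rep(feat)
print(out['mu'].shape, out['sigma'].shape)  # torch.Size([4, 6]) twice

This decoupling is the point of the change: a call site can keep a small encoder output while widening the head MLP independently, which the qac.py diff below relies on (50-dim encoder features feeding 1024-wide heads).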
60 changes: 43 additions & 17 deletions ding/model/template/qac.py
@@ -295,10 +295,10 @@ def __init__(
         action_space: str = 'reparameterization',
         encoder_hidden_size_list: SequenceType = [128, 128, 64],
         twin_critic: bool = False,
-        actor_head_hidden_size: int = 64,
-        actor_head_layer_num: int = 1,
-        critic_head_hidden_size: int = 64,
-        critic_head_layer_num: int = 1,
+        actor_head_hidden_size: int = 1024,
+        actor_head_layer_num: int = 2,
+        critic_head_hidden_size: int = 1024,
+        critic_head_layer_num: int = 2,
         activation: Optional[nn.Module] = nn.ReLU(),
         norm_type: Optional[str] = None,
         share_conv_encoder: bool = False,
@@ -347,18 +347,27 @@ def __init__(
         self.action_shape = action_shape
 
         # now only support action_space == 'reparameterization'
-        if actor_head_hidden_size is None:
-            actor_head_hidden_size = encoder_hidden_size_list[-1]
-        assert actor_head_hidden_size == encoder_hidden_size_list[-1], 'hidden size did not match'
+        # if actor_head_hidden_size is None:
+        #     actor_head_hidden_size = encoder_hidden_size_list[-1]
+        # assert actor_head_hidden_size == encoder_hidden_size_list[-1], 'hidden size did not match'
+        actor_head_input_size = encoder_hidden_size_list[-1]
         self.actor = nn.Sequential(
-            encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type),
+            encoder_cls(
+                obs_shape,
+                encoder_hidden_size_list,
+                activation=activation,
+                norm_type=norm_type,
+                kernel_size=[3, 3],
+                stride=[2, 1]
+            ),
             ReparameterizationHead(
-                actor_head_hidden_size,
+                actor_head_input_size,
                 action_shape,
                 actor_head_layer_num,
                 sigma_type='conditioned',
                 activation=activation,
-                norm_type=norm_type
+                norm_type=norm_type,
+                hidden_size=actor_head_hidden_size,
             )
         )
         if self.embed_action:
@@ -371,7 +380,12 @@ def __init__(
         if self.twin_critic:
             if self.share_conv_encoder:
                 self.critic_encoder = global_encoder_cls(
-                    obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+                    obs_shape,
+                    encoder_hidden_size_list,
+                    activation=activation,
+                    norm_type=norm_type,
+                    kernel_size=[3, 3],
+                    stride=[2, 1]
                 )
             else:
                 self.critic_encoder = nn.ModuleList()
@@ -381,30 +395,42 @@ def __init__(
                 if not self.share_conv_encoder:
                     self.critic_encoder.append(
                         global_encoder_cls(
-                            obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+                            obs_shape,
+                            encoder_hidden_size_list,
+                            activation=activation,
+                            norm_type=norm_type,
+                            kernel_size=[3, 3],
+                            stride=[2, 1]
                         )
                     )
                 self.critic_head.append(
                     RegressionHead(
                         critic_head_input_size,
                         1,
-                        action_shape,
+                        critic_head_layer_num,
                         final_tanh=False,
                         activation=activation,
-                        norm_type=norm_type
+                        norm_type=norm_type,
+                        hidden_size=critic_head_hidden_size,
                     )
                 )
         else:
             self.critic_encoder = global_encoder_cls(
-                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+                obs_shape,
+                encoder_hidden_size_list,
+                activation=activation,
+                norm_type=norm_type,
+                kernel_size=[3, 3],
+                stride=[2, 1]
             )
             self.critic_head = RegressionHead(
                 critic_head_input_size,
                 1,
-                action_shape,
+                critic_head_layer_num,
                 final_tanh=False,
                 activation=activation,
-                norm_type=norm_type
+                norm_type=norm_type,
+                hidden_size=critic_head_hidden_size,
             )
         if self.twin_critic:
             # if not share conv encoder, and not use embed_action
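A quick sanity check on the new `kernel_size=[3, 3]`, `stride=[2, 1]` encoder arguments (not part of the diff), assuming `encoder_cls` applies them as two unpadded Conv2d layers so each spatial dimension shrinks as out = (in - kernel) // stride + 1; the actual encoder implementation may differ:

def conv_out(size, kernels=(3, 3), strides=(2, 1)):
    # out = (in - kernel) // stride + 1 per layer, no padding assumed
    for k, s in zip(kernels, strides):
        size = (size - k) // s + 1
    return size

print(conv_out(84))  # 84 -> 41 -> 39: a 39x39 map per channel for the 84x84 input below

In other words, these settings downsample far less aggressively than a DQN-style 8/4/3 stack, leaving a large feature map to flatten before the heads.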
18 changes: 9 additions & 9 deletions dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py
@@ -7,7 +7,7 @@
         env_id='dmc2gym-v0',
         domain_name="cartpole",
         task_name="swingup",
-        frame_skip=8,
+        frame_skip=2,
         warp_frame=True,
         scale=True,
         clip_rewards=False,
@@ -32,30 +32,30 @@
             obs_shape=(3, 84, 84),
             action_shape=1,
             twin_critic=True,
-            encoder_hidden_size_list=[256, 256, 128],
-            actor_head_hidden_size=128,
-            critic_head_hidden_size=128,
+            encoder_hidden_size_list=[32, 32, 50],
+            actor_head_hidden_size=1024,
+            critic_head_hidden_size=1024,
 
             # different option about whether to share_conv_encoder in two Q networks
             # and whether to use embed_action
 
-            # share_conv_encoder=False,
-            # embed_action=False,
+            share_conv_encoder=False,
+            embed_action=False,
 
             # share_conv_encoder=True,
             # embed_action=False,
 
             # share_conv_encoder=False,
             # embed_action=True,
 
-            share_conv_encoder=True,
-            embed_action=True,
+            # share_conv_encoder=True,
+            # embed_action=True,
             embed_action_density=0.1,
         ),
         learn=dict(
             ignore_done=True,
             update_per_collect=1,
-            batch_size=256,
+            batch_size=128,
             # batch_size=4,  # debug
             learning_rate_q=1e-3,
             learning_rate_policy=1e-3,
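With the new sizes in place, the config runs like any other dizoo config. A launch sketch (not part of the diff), assuming this file defines the usual `main_config` / `create_config` pair that dizoo config modules export:

if __name__ == "__main__":
    # serial_pipeline is DI-engine's standard single-process training entry
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), seed=0)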
