-
Notifications
You must be signed in to change notification settings - Fork 373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(zc): add bcq algorithm #640
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
9f9fb56
add bcq
Super1ce 985def7
modif policy_init
Super1ce 7b7a99e
modify bcq
Super1ce e26e0ef
polish
Super1ce 1b8454a
modify default config
Super1ce d8ac3f3
format
Super1ce 9c08f7c
format
Super1ce d594000
format
Super1ce 9162870
Merge branch 'temp' into bcq
Super1ce 5431ba8
solve conflicts
Super1ce 4d4c997
modify format
Super1ce File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import gym | ||
from ditk import logging | ||
from ding.model import BCQ | ||
from ding.policy import BCQPolicy | ||
from ding.envs import DingEnvWrapper, BaseEnvManagerV2 | ||
from ding.data import create_dataset | ||
from ding.config import compile_config | ||
from ding.framework import task, ding_init | ||
from ding.framework.context import OfflineRLContext | ||
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger | ||
from ding.utils import set_pkg_seed | ||
from dizoo.d4rl.envs import D4RLEnv | ||
from dizoo.d4rl.config.halfcheetah_medium_bcq_config import main_config, create_config | ||
|
||
|
||
def main():
    """
    Overview:
        Entry point of the offline BCQ example: compile the config, build the evaluator \
        environments, the offline dataset, the BCQ model and policy, then register the \
        middleware pipeline on ``task`` and run it.
    """
    # Offline data is required: prepare it first and point the config's data path at it.
    # For demonstration, one can also train an RL policy (e.g. SAC) and collect data with it.
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)

    with task.start(async_mode=False, ctx=OfflineRLContext()):
        # One environment factory per evaluator environment.
        env_factories = [lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)]
        evaluator_env = BaseEnvManagerV2(env_fn=env_factories, cfg=cfg.env.manager)

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        offline_dataset = create_dataset(cfg)
        policy = BCQPolicy(cfg.policy, model=BCQ(**cfg.policy.model))

        # Middleware is registered in this exact order: evaluate, fetch data, train, save, log.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(offline_data_fetcher(cfg, offline_dataset))
        task.use(trainer(cfg, policy.learn_mode))
        # NOTE(review): train_freq is very large, presumably so that periodic checkpoints
        # are effectively disabled -- confirm against CkptSaver.
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=10000000))
        task.use(offline_logger())
        task.run()


# Launch the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from typing import Union, Dict, Optional, List | ||
from easydict import EasyDict | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY | ||
from ..common import RegressionHead, ReparameterizationHead | ||
from .vae import VanillaVAE | ||
|
||
|
||
@MODEL_REGISTRY.register('bcq')
class BCQ(nn.Module):
    """
    Overview:
        Network bundle used by the BCQ algorithm: two critic MLPs, one perturbation actor MLP \
        and a ``VanillaVAE`` that generates candidate actions conditioned on observations.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_vae``, ``compute_eval``
    """

    # Names of the methods that ``forward`` is allowed to dispatch to.
    mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            actor_head_hidden_size: List = [400, 300],
            critic_head_hidden_size: List = [400, 300],
            activation: Optional[nn.Module] = nn.ReLU(),
            vae_hidden_dims: List = [750, 750],
            phi: float = 0.05
    ) -> None:
        """
        Overview:
            Initialize the neural networks, i.e. the two Q networks, the actor and the VAE.
        Arguments:
            - obs_shape (:obj:`int`): the dimension of the observation.
            - action_shape (:obj:`int`): the dimension of the action.
            - actor_head_hidden_size (:obj:`list`): hidden layer sizes of the actor MLP.
            - critic_head_hidden_size (:obj:`list`): hidden layer sizes of each critic MLP.
            - activation (:obj:`nn.Module`): activation inserted after every hidden linear layer, \
                defaults to ``nn.ReLU()``. If ``None``, no activation is inserted.
            - vae_hidden_dims (:obj:`list`): hidden layer sizes passed to ``VanillaVAE``.
            - phi (:obj:`float`): scale of the actor's perturbation; the tanh-squashed actor output \
                is multiplied by ``phi`` before it is added to the input action.
        """
        super(BCQ, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.input_size = obs_shape
        self.phi = phi

        # Critics and actor all consume the concatenation of observation and action.
        critic_input_size = self.input_size + action_shape
        # Construction order (critic 0, critic 1, actor, vae) is kept so that parameter
        # initialisation consumes the random generator in the same sequence as before.
        self.critic = nn.ModuleList(
            [self._build_mlp(critic_input_size, critic_head_hidden_size, 1, activation) for _ in range(2)]
        )

        # NOTE(review): the actor's last layer has width 1, so a single scalar perturbation is
        # broadcast over all action dimensions in ``compute_actor``. A per-dimension output
        # (width ``action_shape``) may have been intended -- confirm before changing, since it
        # would alter parameter shapes and invalidate existing checkpoints.
        self.actor = self._build_mlp(critic_input_size, actor_head_hidden_size, 1, activation)

        # The third argument (action_shape * 2) matches the last dimension of the noise sampled
        # in ``compute_eval``; presumably it is the VAE latent size -- verify against VanillaVAE.
        self.vae = VanillaVAE(action_shape, obs_shape, action_shape * 2, vae_hidden_dims)

    @staticmethod
    def _build_mlp(
            input_size: int, hidden_sizes: List, output_size: int, activation: Optional[nn.Module]
    ) -> nn.Sequential:
        """
        Overview:
            Stack ``Linear -> activation`` for every hidden size, followed by a final ``Linear``.
        Arguments:
            - input_size (:obj:`int`): width of the input features.
            - hidden_sizes (:obj:`list`): widths of the hidden layers, in order.
            - output_size (:obj:`int`): width of the final linear layer.
            - activation (:obj:`nn.Module`): module appended after each hidden layer; skipped if ``None``.
        Returns:
            - net (:obj:`nn.Sequential`): the assembled network.
        """
        layers = []
        width = input_size
        for hidden in hidden_sizes:
            layers.append(nn.Linear(width, hidden))
            if activation is not None:
                layers.append(activation)
            width = hidden
        layers.append(nn.Linear(width, output_size))
        return nn.Sequential(*layers)

    def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of BCQ. ``mode`` selects which computation \
            graph is executed; it must be one of the names listed in ``self.mode``.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including action tensor.
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.
        Mode compute_vae:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                    (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                    ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
                    ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Mode compute_eval:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including action tensor.

        .. note::
            For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Evaluate both critics on the concatenated observation and action.
        Arguments:
            - inputs (:obj:`Dict`): dict with keys ``obs`` and ``action``.
        Returns:
            - output (:obj:`Dict`): dict with key ``q_value``, a list holding one tensor per critic. \
                Only the trailing size-1 output axis is removed, so leading (batch) axes are kept \
                even when they have size 1.
        """
        obs, action = inputs['obs'], inputs['action']
        if action.dim() == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        critic_input = torch.cat([obs, action], dim=-1)
        # squeeze(-1) instead of a bare squeeze(): a bare squeeze also drops a batch axis of size 1.
        q_values = [head(critic_input).squeeze(-1) for head in self.critic]
        return {'q_value': q_values}

    def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            Perturb the given action: add ``phi * tanh(actor(obs, action))`` to it and clamp \
            the result to ``[-1, 1]``.
        Arguments:
            - inputs (:obj:`Dict`): dict with keys ``obs`` and ``action``.
        Returns:
            - output (:obj:`Dict`): dict with key ``action`` holding the perturbed, clamped action.
        """
        actor_input = torch.cat([inputs['obs'], inputs['action']], -1)
        perturbation = self.phi * torch.tanh(self.actor(actor_input))
        action = (perturbation + inputs['action']).clamp(-1, 1)
        return {'action': action}

    def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Run the VAE forward pass on ``inputs`` and return its output unchanged.
        Arguments:
            - inputs (:obj:`Dict`): dict forwarded as-is to ``self.vae.forward``.
        Returns:
            - outputs (:obj:`Dict`): whatever ``self.vae.forward`` returns.
        """
        return self.vae.forward(inputs)

    def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Select an action for evaluation: decode 100 candidate actions per observation from \
            clamped Gaussian noise, perturb them with the actor, score them with the first critic \
            and return, for every observation, the candidate with the highest Q value.
        Arguments:
            - inputs (:obj:`Dict`): dict with key ``obs``; the code indexes two leading axes of the \
                repeated observation, i.e. ``obs`` is treated as a 2-dim ``(B, obs_dim)`` tensor.
        Returns:
            - output (:obj:`Dict`): dict with key ``action`` of shape ``(B, action_dim)``; the batch \
                axis is preserved even when ``B == 1``.
        """
        obs = inputs['obs']
        num_candidates = 100
        # (B, obs_dim) -> (num_candidates, B, obs_dim)
        obs_rep = obs.clone().unsqueeze(0).repeat_interleave(num_candidates, dim=0)
        # Noise is clamped to [-0.5, 0.5] before decoding.
        z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5)
        sample_action = self.vae.decode_with_obs(z, obs_rep)['reconstruction_action']
        action = self.compute_actor({'obs': obs_rep, 'action': sample_action})['action']
        # Only the first critic is used for ranking; q has shape (num_candidates, B).
        q = self.compute_critic({'obs': obs_rep, 'action': action})['q_value'][0]
        # Build a (1, B, action_dim) index so gather picks the best candidate along dim 0.
        idx = q.argmax(dim=0).unsqueeze(0).unsqueeze(-1)
        idx = idx.repeat_interleave(action.shape[-1], dim=-1)
        # squeeze(0) removes only the candidate axis; a bare squeeze() would also drop a
        # size-1 batch or action axis.
        action = action.gather(0, idx).squeeze(0)
        return {'action': action}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add note for arguments.