diff --git a/ding/model/template/qvac.py b/ding/model/template/qvac.py new file mode 100644 index 0000000000..50cdca8e7d --- /dev/null +++ b/ding/model/template/qvac.py @@ -0,0 +1,363 @@ +from typing import Union, Dict, Optional +from easydict import EasyDict +import numpy as np +import torch +import torch.nn as nn + +from ding.utils import SequenceType, squeeze, MODEL_REGISTRY +from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \ + FCEncoder, ConvEncoder + + +@MODEL_REGISTRY.register('continuous_qvac') +class ContinuousQVAC(nn.Module): + """ + Overview: + The neural network and computation graph of algorithms related to Actor-Critic that have both Q-value and V-value critic, such as \ + IQL. This model now supports continuous and hybrid action space. The ContinuousQVAC is composed of \ + four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \ + extract the feature from various observation. Heads are used to predict corresponding Q-value and V-value or action logit. \ + In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \ + and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders. + Interfaces: + ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` + """ + mode = ['compute_actor', 'compute_critic'] + + def __init__( + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType, EasyDict], + action_space: str, + twin_critic: bool = False, + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 1, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + activation: Optional[nn.Module] = nn.SiLU(), #nn.ReLU(), + norm_type: Optional[str] = None, + encoder_hidden_size_list: Optional[SequenceType] = None, + share_encoder: Optional[bool] = False, + ) -> None: + """ + Overview: + Initailize the ContinuousQVAC Model according to input arguments. + Arguments: + - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ). + - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \ + EasyDict({'action_type_shape': 3, 'action_args_shape': 4}). + - action_space (:obj:`str`): The type of action space, including [``regression``, ``reparameterization``, \ + ``hybrid``], ``regression`` is used for DDPG/TD3, ``reparameterization`` is used for SAC and \ + ``hybrid`` for PADDPG. + - twin_critic (:obj:`bool`): Whether to use twin critic, one of tricks in TD3. + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head. + - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute action. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head. + - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute Q-value. + - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \ + after each FC layer, if ``None`` then default set to ``nn.ReLU()``. + - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \ + see ``ding.torch_utils.network`` for more details. + - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ + the last element must match ``head_hidden_size``, this argument is only used in image observation. + - share_encoder (:obj:`Optional[bool]`): Whether to share encoder between actor and critic. + """ + super(ContinuousQVAC, self).__init__() + obs_shape: int = squeeze(obs_shape) + action_shape = squeeze(action_shape) + self.action_shape = action_shape + self.action_space = action_space + assert self.action_space in ['regression', 'reparameterization', 'hybrid'], self.action_space + + # encoder + self.share_encoder = share_encoder + if np.isscalar(obs_shape) or len(obs_shape) == 1: + assert not self.share_encoder, "Vector observation doesn't need share encoder." + assert encoder_hidden_size_list is None, "Vector obs encoder only uses one layer nn.Linear" + # Because there is already a layer nn.Linear in the head, so we use nn.Identity here to keep + # compatible with the image observation and avoid adding an extra layer nn.Linear. + self.actor_encoder = nn.Identity() + self.critic_encoder = nn.Identity() + encoder_output_size = obs_shape + elif len(obs_shape) == 3: + + def setup_conv_encoder(): + kernel_size = [3 for _ in range(len(encoder_hidden_size_list))] + stride = [2] + [1 for _ in range(len(encoder_hidden_size_list) - 1)] + return ConvEncoder( + obs_shape, + encoder_hidden_size_list, + activation=activation, + norm_type=norm_type, + kernel_size=kernel_size, + stride=stride + ) + + if self.share_encoder: + encoder = setup_conv_encoder() + self.actor_encoder = self.critic_encoder = encoder + else: + self.actor_encoder = setup_conv_encoder() + self.critic_encoder = setup_conv_encoder() + encoder_output_size = self.actor_encoder.output_size + else: + raise RuntimeError("not support observation shape: {}".format(obs_shape)) + # head + if self.action_space == 'regression': # DDPG, TD3 + self.actor_head = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + RegressionHead( + actor_head_hidden_size, + action_shape, + actor_head_layer_num, + final_tanh=True, + activation=activation, + norm_type=norm_type + ) + ) + elif self.action_space == 'reparameterization': # SAC + self.actor_head = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + ReparameterizationHead( + actor_head_hidden_size, + action_shape, + actor_head_layer_num, + sigma_type='conditioned', + activation=activation, + norm_type=norm_type + ) + ) + elif self.action_space == 'hybrid': # PADDPG + # hybrid action space: action_type(discrete) + action_args(continuous), + # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])} + action_shape.action_args_shape = squeeze(action_shape.action_args_shape) + action_shape.action_type_shape = squeeze(action_shape.action_type_shape) + actor_action_args = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + RegressionHead( + actor_head_hidden_size, + action_shape.action_args_shape, + actor_head_layer_num, + final_tanh=True, + activation=activation, + norm_type=norm_type + ) + ) + actor_action_type = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + DiscreteHead( + actor_head_hidden_size, + action_shape.action_type_shape, + actor_head_layer_num, + activation=activation, + norm_type=norm_type, + ) + ) + self.actor_head = nn.ModuleList([actor_action_type, actor_action_args]) + + self.twin_critic = twin_critic + if self.action_space == 'hybrid': + critic_q_input_size = encoder_output_size + action_shape.action_type_shape + action_shape.action_args_shape + critic_v_input_size = encoder_output_size + else: + critic_q_input_size = encoder_output_size + action_shape + critic_v_input_size = encoder_output_size + if self.twin_critic: + self.critic_q_head = nn.ModuleList() + self.critic_v_head = nn.ModuleList() + for _ in range(2): + self.critic_q_head.append( + nn.Sequential( + nn.Linear(critic_q_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + ) + self.critic_v_head.append( + nn.Sequential( + nn.Linear(critic_v_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + ) + else: + self.critic_q_head = nn.Sequential( + nn.Linear(critic_q_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + self.critic_v_head = nn.Sequential( + nn.Linear(critic_v_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + + # Convenient for calling some apis (e.g. self.critic.parameters()), + # but may cause misunderstanding when `print(self)` + self.actor = nn.ModuleList([self.actor_encoder, self.actor_head]) + self.critic = nn.ModuleList([self.critic_encoder, self.critic_q_head, self.critic_v_head]) + + def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]: + """ + Overview: + QVAC forward computation graph, input observation tensor to predict Q-value or action logit. Different \ + ``mode`` will forward with different network modules to get different outputs and save computation. + Arguments: + - inputs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The input data for forward computation \ + graph, for ``compute_actor``, it is the observation tensor, for ``compute_critic``, it is the \ + dict data including obs and action tensor. + - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class. + Returns: + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph, whose \ + key-values vary in different forward modes. + Examples (Actor): + >>> # Regression mode + >>> model = ContinuousQVAC(64, 6, 'regression') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['action'].shape == torch.Size([4, 6]) + >>> # Reparameterization Mode + >>> model = ContinuousQVAC(64, 6, 'reparameterization') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6]) # mu + >>> actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma + + Examples (Critic): + >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)} + >>> model = ContinuousQVAC(obs_shape=(8, ),action_shape=1, action_space='regression') + >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, ) # q value + """ + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(inputs) + + def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + """ + Overview: + QVAC forward computation graph for actor part, input observation tensor to predict action or action logit. + Arguments: + - x (:obj:`torch.Tensor`): The input observation tensor data. + Returns: + - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output dict varying \ + from action_space: ``regression``, ``reparameterization``, ``hybrid``. + ReturnsKeys (regression): + - action (:obj:`torch.Tensor`): Continuous action with same size as ``action_shape``, usually in DDPG/TD3. + ReturnsKeys (reparameterization): + - logit (:obj:`Dict[str, torch.Tensor]`): The predictd reparameterization action logit, usually in SAC. \ + It is a list containing two tensors: ``mu`` and ``sigma``. The former is the mean of the gaussian \ + distribution, the latter is the standard deviation of the gaussian distribution. + ReturnsKeys (hybrid): + - logit (:obj:`torch.Tensor`): The predicted discrete action type logit, it will be the same dimension \ + as ``action_type_shape``, i.e., all the possible discrete action types. + - action_args (:obj:`torch.Tensor`): Continuous action arguments with same size as ``action_args_shape``. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``. + - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``. + - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``. + - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size. + - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \ + ``action_shape.action_type_shape``. + - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \ + ``action_shape.action_args_shape``. + Examples: + >>> # Regression mode + >>> model = ContinuousQVAC(64, 6, 'regression') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['action'].shape == torch.Size([4, 6]) + >>> # Reparameterization Mode + >>> model = ContinuousQVAC(64, 6, 'reparameterization') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6]) # mu + >>> actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma + """ + obs = self.actor_encoder(obs) + if self.action_space == 'regression': + x = self.actor_head(obs) + return {'action': x['pred']} + elif self.action_space == 'reparameterization': + x = self.actor_head(obs) + return {'logit': [x['mu'], x['sigma']]} + elif self.action_space == 'hybrid': + logit = self.actor_head[0](obs) + action_args = self.actor_head[1](obs) + return {'logit': logit['logit'], 'action_args': action_args['pred']} + + def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Overview: + QVAC forward computation graph for critic part, input observation and action tensor to predict Q-value. + Arguments: + - inputs (:obj:`Dict[str, torch.Tensor]`): The dict of input data, including ``obs`` and ``action`` \ + tensor, also contains ``logit`` and ``action_args`` tensor in hybrid action_space. + ArgumentsKeys: + - obs: (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data. + - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``. + - logit (:obj:`torch.Tensor`): Discrete action logit, only in hybrid action_space. + - action_args (:obj:`torch.Tensor`): Continuous action arguments, only in hybrid action_space. + Returns: + - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC's forward computation graph for critic, \ + including ``q_value``. + ReturnKeys: + - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``. + - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \ + ``action_shape.action_type_shape``. + - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \ + ``action_shape.action_args_shape``. + - action (:obj:`torch.Tensor`): :math:`(B, N4)`, where B is batch size and N4 is ``action_shape``. + - q_value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size. + + Examples: + >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)} + >>> model = ContinuousQVAC(obs_shape=(8, ),action_shape=1, action_space='regression') + >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, ) # q value + """ + + obs, action = inputs['obs'], inputs['action'] + obs = self.critic_encoder(obs) + assert len(obs.shape) == 2 + if self.action_space == 'hybrid': + action_type_logit = inputs['logit'] + action_type_logit = torch.softmax(action_type_logit, dim=-1) + action_args = action['action_args'] + if len(action_args.shape) == 1: + action_args = action_args.unsqueeze(1) + x = torch.cat([obs, action_type_logit, action_args], dim=1) + else: + if len(action.shape) == 1: # (B, ) -> (B, 1) + action = action.unsqueeze(1) + x = torch.cat([obs, action], dim=1) + if self.twin_critic: + x = [m(x)['pred'] for m in self.critic_q_head] + y = [m(obs)['pred'] for m in self.critic_v_head] + else: + x = self.critic_q_head(x)['pred'] + y = self.critic_v_head(obs)['pred'] + return {'q_value': x, 'v_value': y} diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 2e817ead4b..1ed500dbbd 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -43,6 +43,7 @@ from .d4pg import D4PGPolicy from .cql import CQLPolicy, DiscreteCQLPolicy +from .iql import IQLPolicy from .dt import DTPolicy from .pdqn import PDQNPolicy from .madqn import MADQNPolicy @@ -321,6 +322,11 @@ class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy): pass +@POLICY_REGISTRY.register('iql_command') +class IQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy): + pass + + @POLICY_REGISTRY.register('discrete_cql_command') class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy): pass diff --git a/ding/policy/iql.py b/ding/policy/iql.py new file mode 100644 index 0000000000..adc8483891 --- /dev/null +++ b/ding/policy/iql.py @@ -0,0 +1,654 @@ +from typing import List, Dict, Any, Tuple, Union +import copy +from collections import namedtuple +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions import Normal, Independent, TransformedDistribution +from torch.distributions.transforms import TanhTransform, AffineTransform + +from ding.torch_utils import Adam, to_device +from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ + qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_collate, default_decollate +from .base_policy import Policy +from .common_utils import default_preprocess_learn + + +def asymmetric_l2_loss(u, tau): + return torch.mean(torch.abs(tau - (u < 0).float()) * u ** 2) + + +@POLICY_REGISTRY.register('iql') +class IQLPolicy(Policy): + """ + Overview: + Policy class of Implicit Q-Learning (IQL) algorithm for continuous control. Paper link: https://arxiv.org/abs/2110.06169. + + Config: + == ==================== ======== ============= ================================= ======================= + ID Symbol Type Default Value Description Other(Shape) + == ==================== ======== ============= ================================= ======================= + 1 ``type`` str iql | RL policy register name, refer | this arg is optional, + | to registry ``POLICY_REGISTRY`` | a placeholder + 2 ``cuda`` bool True | Whether to use cuda for network | + 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for + | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ + | | buffer when training starts. | TD3. + 4 | ``model.policy_`` int 256 | Linear layer size for policy | + | ``embedding_size`` | network. | + 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q | + | ``embedding_size`` | network. | + 6 | ``model.value_`` int 256 | Linear layer size for value | Defalut to None when + | ``embedding_size`` | network. | model.value_network + | | | is False. + 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3, when + | ``_rate_q`` | network. | model.value_network + | | | is True. + 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3, when + | ``_rate_policy`` | network. | model.value_network + | | | is True. + 9 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to None when + | ``_rate_value`` | network. | model.value_network + | | | is False. + 10 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- + | | coefficient. | zation for auto + | | | `alpha`, when + | | | auto_alpha is True + 11 | ``learn.repara_`` bool True | Determine whether to use | + | ``meterization`` | reparameterization trick. | + 12 | ``learn.`` bool False | Determine whether to use | Temperature parameter + | ``auto_alpha`` | auto temperature parameter | determines the + | | `alpha`. | relative importance + | | | of the entropy term + | | | against the reward. + 13 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only + | ``ignore_done`` | done flag. | in halfcheetah env. + 14 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation + | ``target_theta`` | target network. | factor in polyak aver + | | | aging for target + | | | networks. + == ==================== ======== ============= ================================= ======================= + """ + + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='iql', + # (bool) Whether to use cuda for policy. + cuda=False, + # (bool) on_policy: Determine whether on-policy or off-policy. + # on-policy setting influences the behaviour of buffer. + on_policy=False, + # (bool) priority: Determine whether to use priority in buffer sample. + priority=False, + # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + # (int) Number of training samples(randomly collected) in replay buffer when training starts. + random_collect_size=10000, + model=dict( + # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation. + # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one . + # Default to True. + twin_critic=True, + # (str type) action_space: Use reparameterization trick for continous action + action_space='reparameterization', + # (int) Hidden size for actor network head. + actor_head_hidden_size=512, + actor_head_layer_num=3, + # (int) Hidden size for critic network head. + critic_head_hidden_size=512, + critic_head_layer_num=2, + ), + # learn_mode config + learn=dict( + # (int) How many updates (iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + update_per_collect=1, + # (int) Minibatch size for gradient descent. + batch_size=256, + # (float) learning_rate_q: Learning rate for soft q network. + learning_rate_q=3e-4, + # (float) learning_rate_policy: Learning rate for policy network. + learning_rate_policy=3e-4, + # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``. + learning_rate_alpha=3e-4, + # (float) target_theta: Used for soft update of the target network, + # aka. Interpolation factor in polyak averaging for target networks. + target_theta=0.005, + # (float) discount factor for the discounted sum of rewards, aka. gamma. + discount_factor=0.99, + # (float) alpha: Entropy regularization coefficient. + # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. + # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`. + # Default to 0.2. + alpha=0.2, + # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha` . + # Temperature parameter determines the relative importance of the entropy term against the reward. + # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. + # Default to False. + # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`. + auto_alpha=True, + # (bool) log_space: Determine whether to use auto `\alpha` in log space. + log_space=True, + # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. + ignore_done=False, + # (float) Weight uniform initialization range in the last output layer. + init_w=3e-3, + # (int) The numbers of action sample each at every state s from a uniform-at-random. + num_actions=10, + # (bool) Whether use lagrange multiplier in q value loss. + with_lagrange=False, + # (float) The threshold for difference in Q-values. + lagrange_thresh=-1, + # (float) Loss weight for conservative item. + min_q_weight=1.0, + # (float) coefficient for the asymmetric loss, range from [0.5, 1.0], default to 0.70. + tau=0.7, + # (float) temperature coefficient for Advantage Weighted Regression loss, default to 1.0. + beta=1.0, + ), + eval=dict(), # for compatibility + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ + + return 'continuous_qvac', ['ding.model.template.qvac'] + + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange, \ + main and target model. Especially, the ``auto_alpha`` mechanism for balancing max entropy \ + target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + self._twin_critic = self._cfg.model.twin_critic + self._num_actions = self._cfg.learn.num_actions + + self._min_q_version = 3 + self._min_q_weight = self._cfg.learn.min_q_weight + self._with_lagrange = self._cfg.learn.with_lagrange and (self._lagrange_thresh > 0) + self._lagrange_thresh = self._cfg.learn.lagrange_thresh + if self._with_lagrange: + self.target_action_gap = self._lagrange_thresh + self.log_alpha_prime = torch.tensor(0.).to(self._device).requires_grad_() + self.alpha_prime_optimizer = Adam( + [self.log_alpha_prime], + lr=self._cfg.learn.learning_rate_q, + ) + + # Weight Init + init_w = self._cfg.learn.init_w + self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w) + self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w) + # self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w) + # self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w) + if self._twin_critic: + self._model.critic_q_head[0][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_q_head[0][-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_q_head[1][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_q_head[1][-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_v_head[0][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_v_head[0][-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_v_head[1][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_v_head[1][-1].last.bias.data.uniform_(-init_w, init_w) + else: + self._model.critic_q_head[2].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_q_head[-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_v_head[2].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_v_head[-1].last.bias.data.uniform_(-init_w, init_w) + + # Optimizers + self._optimizer_q = Adam( + self._model.critic.parameters(), + lr=self._cfg.learn.learning_rate_q, + ) + self._optimizer_policy = Adam( + self._model.actor.parameters(), + lr=self._cfg.learn.learning_rate_policy, + ) + + # Algorithm config + self._gamma = self._cfg.learn.discount_factor + + self._learn_model = model_wrap(self._model, wrapper_name='base') + self._learn_model.reset() + + self._forward_learn_cnt = 0 + + self._tau = self._cfg.learn.tau + self._beta = self._cfg.learn.beta + self._policy_start_training_counter = 10000 #300000 + + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the offline dataset and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For IQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ + loss_dict = {} + data = default_preprocess_learn( + data, + use_priority=self._priority, + use_priority_IS_weight=self._cfg.priority_IS_weight, + ignore_done=self._cfg.learn.ignore_done, + use_nstep=False + ) + if len(data.get('action').shape) == 1: + data['action'] = data['action'].reshape(-1, 1) + + if self._cuda: + data = to_device(data, self._device) + + self._learn_model.train() + obs = data['obs'] + next_obs = data['next_obs'] + reward = data['reward'] + done = data['done'] + + # 1. predict q and v value + value = self._learn_model.forward(data, mode='compute_critic') + q_value, v_value = value['q_value'], value['v_value'] + + # 2. predict target value + with torch.no_grad(): + (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit'] + + next_obs_dist = TransformedDistribution( + Independent(Normal(mu, sigma), 1), + transforms=[TanhTransform(cache_size=1), + AffineTransform(loc=0.0, scale=1.05)] + ) + next_action = next_obs_dist.rsample() + next_log_prob = next_obs_dist.log_prob(next_action) + + next_data = {'obs': next_obs, 'action': next_action} + next_value = self._learn_model.forward(next_data, mode='compute_critic') + next_q_value, next_v_value = next_value['q_value'], next_value['v_value'] + + # the value of a policy according to the maximum entropy objective + if self._twin_critic: + next_q_value = torch.min(next_q_value[0], next_q_value[1]) + + # 3. compute v loss + if self._twin_critic: + q_value_min = torch.min(q_value[0], q_value[1]).detach() + v_loss_0 = asymmetric_l2_loss(q_value_min - v_value[0], self._tau) + v_loss_1 = asymmetric_l2_loss(q_value_min - v_value[1], self._tau) + v_loss = (v_loss_0 + v_loss_1) / 2 + else: + advantage = q_value.detach() - v_value + v_loss = asymmetric_l2_loss(advantage, self._tau) + + # 4. compute q loss + if self._twin_critic: + next_v_value = torch.min(next_v_value[0], next_v_value[1]) + q_data0 = v_1step_td_data(q_value[0], next_v_value, reward, done, data['weight']) + loss_dict['critic_q_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma) + q_data1 = v_1step_td_data(q_value[1], next_v_value, reward, done, data['weight']) + loss_dict['twin_critic_q_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma) + q_loss = (loss_dict['critic_q_loss'] + loss_dict['twin_critic_q_loss']) / 2 + else: + q_data = v_1step_td_data(q_value, next_v_value, reward, done, data['weight']) + loss_dict['critic_q_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma) + q_loss = loss_dict['critic_q_loss'] + + # 5. update q and v network + self._optimizer_q.zero_grad() + v_loss.backward() + q_loss.backward() + self._optimizer_q.step() + + # 6. evaluate to get action distribution + (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit'] + + dist = TransformedDistribution( + Independent(Normal(mu, sigma), 1), + transforms=[TanhTransform(cache_size=1), AffineTransform(loc=0.0, scale=1.05)] + ) + action = data['action'] + log_prob = dist.log_prob(action) + + eval_data = {'obs': obs, 'action': action} + new_value = self._learn_model.forward(eval_data, mode='compute_critic') + new_q_value, new_v_value = new_value['q_value'], new_value['v_value'] + if self._twin_critic: + new_q_value = torch.min(new_q_value[0], new_q_value[1]) + new_v_value = torch.min(new_v_value[0], new_v_value[1]) + new_advantage = new_q_value - new_v_value + + # 8. compute policy loss + policy_loss = (-log_prob * torch.exp(new_advantage.detach() / self._beta).clamp(max=20.0)).mean() + self._policy_start_training_counter -= 1 + + loss_dict['policy_loss'] = policy_loss + + # 9. update policy network + self._optimizer_policy.zero_grad() + policy_loss.backward() + policy_grad_norm = torch.nn.utils.clip_grad_norm_(self._model.actor.parameters(), 1) + self._optimizer_policy.step() + + loss_dict['total_loss'] = sum(loss_dict.values()) + + # ============= + # after update + # ============= + self._forward_learn_cnt += 1 + + return { + 'cur_lr_q': self._optimizer_q.defaults['lr'], + 'cur_lr_p': self._optimizer_policy.defaults['lr'], + 'priority': q_loss.abs().tolist(), + 'q_loss': q_loss.detach().mean().item(), + 'v_loss': v_loss.detach().mean().item(), + 'log_prob': log_prob.detach().mean().item(), + 'next_q_value': next_q_value.detach().mean().item(), + 'next_v_value': next_v_value.detach().mean().item(), + 'policy_loss': policy_loss.detach().mean().item(), + 'total_loss': loss_dict['total_loss'].detach().item(), + 'advantage_max': new_advantage.max().detach().item(), + 'new_q_value': new_q_value.detach().mean().item(), + 'new_v_value': new_v_value.detach().mean().item(), + 'policy_grad_norm': policy_grad_norm, + } + + def _get_policy_actions(self, data: Dict, num_actions: int = 10, epsilon: float = 1e-6) -> List: + # evaluate to get action distribution + obs = data['obs'] + obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1]) + (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit'] + dist = Independent(Normal(mu, sigma), 1) + pred = dist.rsample() + action = torch.tanh(pred) + + # evaluate action log prob depending on Jacobi determinant. + y = 1 - action.pow(2) + epsilon + log_prob = dist.log_prob(pred).unsqueeze(-1) + log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) + + return action, log_prob.view(-1, num_actions, 1) + + def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor: + new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value'] + if self._twin_critic: + new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value] + else: + new_q_value = new_q_value.view(-1, self._num_actions, 1) + if self._twin_critic and not keep: + new_q_value = torch.min(new_q_value[0], new_q_value[1]) + return new_q_value + + def _get_v_value(self, data: Dict, keep: bool = True) -> torch.Tensor: + new_v_value = self._learn_model.forward(data, mode='compute_critic')['v_value'] + if self._twin_critic: + new_v_value = [value.view(-1, self._num_actions, 1) for value in new_v_value] + else: + new_v_value = new_v_value.view(-1, self._num_actions, 1) + if self._twin_critic and not keep: + new_v_value = torch.min(new_v_value[0], new_v_value[1]) + return new_v_value + + def _init_collect(self) -> None: + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \ + collect_model other algorithm-specific arguments such as unroll_len. \ + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + """ + self._unroll_len = self._cfg.collect.unroll_len + self._collect_model = model_wrap(self._model, wrapper_name='base') + self._collect_model.reset() + + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + ``logit`` in SAC means the mu and sigma of Gaussioan distribution. Here we use this name for consistency. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._collect_model.eval() + with torch.no_grad(): + (mu, sigma) = self._collect_model.forward(data, mode='compute_actor')['logit'] + dist = Independent(Normal(mu, sigma), 1) + action = torch.tanh(dist.rsample()) + output = {'logit': (mu, sigma), 'action': action} + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ + Overview: + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For continuous SAC, it contains obs, next_obs, action, reward, done. The logit \ + will be also added when ``collector_logit`` is True. + Arguments: + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For continuous SAC, it contains the action and the logit (mu and sigma) of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + """ + if self._cfg.collect.collector_logit: + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'logit': policy_output['logit'], + 'action': policy_output['action'], + 'reward': timestep.reward, + 'done': timestep.done, + } + else: + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'action': policy_output['action'], + 'reward': timestep.reward, + 'done': timestep.done, + } + return transition + + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In continuous SAC, a train sample is a processed transition \ + (unroll_len=1). + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training. + """ + return get_train_sample(transitions, self._unroll_len) + + def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For SAC, it contains the \ + eval model, which is equipped with ``base`` model wrapper to ensure compability. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ + self._eval_model = model_wrap(self._model, wrapper_name='base') + self._eval_model.reset() + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + ``logit`` in SAC means the mu and sigma of Gaussioan distribution. Here we use this name for consistency. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._eval_model.eval() + with torch.no_grad(): + (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit'] + action = torch.tanh(mu) / 1.05 # deterministic_eval + output = {'action': action} + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ + twin_critic = ['twin_critic_loss'] if self._twin_critic else [] + return [ + 'cur_lr_q', + 'cur_lr_p', + 'value_loss' + 'policy_loss', + 'q_loss', + 'v_loss', + 'policy_loss', + 'log_prob', + 'total_loss', + 'advantage_max', + 'next_q_value', + 'next_v_value', + 'new_q_value', + 'new_v_value', + 'policy_grad_norm', + ] + twin_critic + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'optimizer_q': self._optimizer_q.state_dict(), + 'optimizer_policy': self._optimizer_policy.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._optimizer_q.load_state_dict(state_dict['optimizer_q']) + self._optimizer_policy.load_state_dict(state_dict['optimizer_policy']) diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index f29dd3335a..bd24e2eaf8 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -114,6 +114,38 @@ def __init__(self, cfg: dict) -> None: except (KeyError, AttributeError): # do not normalize pass + if hasattr(cfg.env, "reward_norm"): + if cfg.env.reward_norm == "normalize": + dataset['rewards'] = (dataset['rewards'] - dataset['rewards'].mean()) / dataset['rewards'].std() + elif cfg.env.reward_norm == "iql_antmaze": + dataset['rewards'] = dataset['rewards'] - 1.0 + elif cfg.env.reward_norm == "iql_locomotion": + + def return_range(dataset, max_episode_steps): + returns, lengths = [], [] + ep_ret, ep_len = 0.0, 0 + for r, d in zip(dataset["rewards"], dataset["terminals"]): + ep_ret += float(r) + ep_len += 1 + if d or ep_len == max_episode_steps: + returns.append(ep_ret) + lengths.append(ep_len) + ep_ret, ep_len = 0.0, 0 + # returns.append(ep_ret) # incomplete trajectory + lengths.append(ep_len) # but still keep track of number of steps + assert sum(lengths) == len(dataset["rewards"]) + return min(returns), max(returns) + + min_ret, max_ret = return_range(dataset, 1000) + dataset['rewards'] /= max_ret - min_ret + dataset['rewards'] *= 1000 + elif cfg.env.reward_norm == "cql_antmaze": + dataset['rewards'] = (dataset['rewards'] - 0.5) * 4.0 + elif cfg.env.reward_norm == "antmaze": + dataset['rewards'] = (dataset['rewards'] - 0.25) * 2.0 + else: + raise NotImplementedError + self._data = [] self._load_d4rl(dataset) diff --git a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py new file mode 100644 index 0000000000..545ecf970b --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py @@ -0,0 +1,54 @@ +# You can conduct Experiments on D4RL with this config file through the following command: +# cd ../entry && python d4rl_iql_main.py +from easydict import EasyDict + +main_config = dict( + exp_name="halfcheetah_medium_iql_seed0", + env=dict( + env_id='halfcheetah-medium-v2', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + reward_norm="iql_locomotion", + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=17, + action_shape=6, + + ), + learn=dict( + data_path=None, + train_epoch=30000, + batch_size=4096, + learning_rate_q=3e-4, + learning_rate_policy=1e-4, + beta=0.05, + tau=0.7, + ), + collect=dict(data_type='d4rl', ), + eval=dict(evaluator=dict(eval_freq=5000, )), + other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + type='d4rl', + import_names=['dizoo.d4rl.envs.d4rl_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='iql', + import_names=['ding.policy.iql'], + ), + replay_buffer=dict(type='naive', ), +) +create_config = EasyDict(create_config) +create_config = create_config diff --git a/dizoo/d4rl/entry/d4rl_iql_main.py b/dizoo/d4rl/entry/d4rl_iql_main.py new file mode 100644 index 0000000000..ded097ee42 --- /dev/null +++ b/dizoo/d4rl/entry/d4rl_iql_main.py @@ -0,0 +1,21 @@ +from ding.entry import serial_pipeline_offline +from ding.config import read_config +from pathlib import Path + + +def train(args): + # launch from anywhere + config = Path(__file__).absolute().parent.parent / 'config' / args.config + config = read_config(str(config)) + config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) + serial_pipeline_offline(config, seed=args.seed) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--seed', '-s', type=int, default=10) + parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_iql_config.py') + args = parser.parse_args() + train(args)