From 2811d98183c12a82c0ab050d6b0370d2fd11ba39 Mon Sep 17 00:00:00 2001
From: nighood
Date: Wed, 18 Oct 2023 16:30:37 +0800
Subject: [PATCH 1/4] polish(rjy): polish comments in wqmix model

---
 ding/model/template/wqmix.py | 120 ++++++++++++++++++-----------------
 1 file changed, 61 insertions(+), 59 deletions(-)

diff --git a/ding/model/template/wqmix.py b/ding/model/template/wqmix.py
index c5307d9bbf..c63344caa7 100644
--- a/ding/model/template/wqmix.py
+++ b/ding/model/template/wqmix.py
@@ -13,21 +13,24 @@
 class MixerStar(nn.Module):
     """
     Overview:
-        mixer network for Q_star in WQMIX , which mix up the independent q_value of
-        each agent to a total q_value and is diffrent from the Qmix's mixer network,
-        here the mixing network is a feedforward network with 3 hidden layers of 256 dim.
+        Mixer network for Q_star in WQMIX (https://arxiv.org/abs/2006.10800), which mixes the independent q_value \
+        of each agent into a total q_value and is different from QMIX's mixer network; \
+        here the mixing network is a feedforward network with 3 hidden layers of 256 dim. \
+        This Q_star mixing network is not constrained to be monotonic: it takes the state and agent_q directly as \
+        inputs, as opposed to QMIX, where hypernetworks take the state as input and generate non-negative weights.
     Interface:
-        __init__, forward
+        ``__init__``, ``forward``.
     """

     def __init__(self, agent_num: int, state_dim: int, mixing_embed_dim: int) -> None:
         """
         Overview:
-            initialize the mixer network of Q_star in WQMIX.
+            Initialize the mixer network of Q_star in WQMIX.
         Arguments:
-            - agent_num (:obj:`int`): the number of agent
-            - state_dim(:obj:`int`): the dimension of global observation state
-            - mixing_embed_dim (:obj:`int`): the dimension of mixing state emdedding
+            - agent_num (:obj:`int`): The number of agents, e.g., 8.
+            - state_dim (:obj:`int`): The dimension of global observation state, e.g., 16.
+            - mixing_embed_dim (:obj:`int`): The dimension of mixing state embedding, e.g., 128.
         """
         super(MixerStar, self).__init__()
         self.agent_num = agent_num
@@ -46,17 +49,18 @@ def __init__(self, agent_num: int, state_dim: int, mixing_embed_dim: int) -> Non
     def forward(self, agent_qs: torch.FloatTensor, states: torch.FloatTensor) -> torch.FloatTensor:
         """
         Overview:
-            forward computation graph of the mixer network for Q_star in WQMIX.
+            Forward computation graph of the mixer network for Q_star in WQMIX. This mixer network \
+            is a feed-forward network that takes the state and the appropriate actions' utilities as input.
         Arguments:
-            - agent_qs (:obj:`torch.FloatTensor`): the independent q_value of each agent
-            - states (:obj:`torch.FloatTensor`): the emdedding vector of global state
+            - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
+            - states (:obj:`torch.FloatTensor`): The embedding vector of global state.
         Returns:
-            - q_tot (:obj:`torch.FloatTensor`): the total mixed q_value
+            - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value
         Shapes:
-            - agent_qs (:obj:`torch.FloatTensor`): :math:`(T,B, N)`, where T is timestep,
-                B is batch size, A is agent_num, N is obs_shape
-            - states (:obj:`torch.FloatTensor`): :math:`(T, B, M)`, where M is global_obs_shape
-            - q_tot (:obj:`torch.FloatTensor`): :math:`(T, B, )`
+            - agent_qs (:obj:`torch.FloatTensor`): :math:`(T, B, A)`, where T is timestep, \
+                B is batch size, A is agent_num.
+            - states (:obj:`torch.FloatTensor`): :math:`(T, B, M)`, where M is global_obs_shape.
+            - q_tot (:obj:`torch.FloatTensor`): :math:`(T, B)`.
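+        Examples:
+            >>> # A minimal usage sketch; the sizes below (T=10, B=32, agent_num=4, state_dim=16,
+            >>> # mixing_embed_dim=128) are illustrative assumptions, not values required by the network.
+            >>> mixer = MixerStar(4, 16, 128)
+            >>> agent_qs = torch.randn(10, 32, 4)    # (T, B, A)
+            >>> states = torch.randn(10, 32, 16)     # (T, B, M)
+            >>> q_tot = mixer(agent_qs, states)
+            >>> assert q_tot.shape == (10, 32)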
""" # in below annotations about the shape of the variables, T is timestep, # B is batch_size A is agent_num, N is obs_shape, for example, @@ -77,9 +81,13 @@ def forward(self, agent_qs: torch.FloatTensor, states: torch.FloatTensor) -> tor class WQMix(nn.Module): """ Overview: - WQMIX network, which is same as Qmix network + WQMIX (https://arxiv.org/abs/2006.10800) network, There are two components: \ + 1) Q_tot, which is same as QMIX network and composed of agent Q network and mixer network. \ + 2) An unrestricted joint action Q_star, which is composed of agent Q network and mixer_star network. \ + The QMIX paper mentions that all agents share local Q network parameters, so only one Q network is initialized \ + in Q_tot or Q_star. Interface: - __init__, forward, _setup_global_encoder + ``__init__``, ``forward``. """ def __init__( @@ -94,15 +102,19 @@ def __init__( ) -> None: """ Overview: - initialize Qmix network + Initialize WQMIX neural network according to arguments, i.e. agent Q network and mixer, \ + Q_star network and mixer_star. Arguments: - - agent_num (:obj:`int`): the number of agent - - obs_shape (:obj:`int`): the dimension of each agent's observation state - - global_obs_shape (:obj:`int`): the dimension of global observation state - - action_shape (:obj:`int`): the dimension of action shape - - hidden_size_list (:obj:`list`): the list of hidden size - - lstm_type (:obj:`str`): use lstm or gru, default to gru - - dueling (:obj:`bool`): use dueling head or not, default to False. + - agent_num (:obj:`int`): The number of agent, such as 8. + - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8 or [4, 84, 84]. + - global_obs_shape (:obj:`int`): The dimension of global observation state, such as 8 or [4, 84, 84]. + - action_shape (:obj:`int`): The dimension of action shape, such as 6 or [2, 3, 3]. + - hidden_size_list (:obj:`list`): The list of hidden size for ``q_network``, \ + the last element must match mixer's ``mixing_embed_dim``. + - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now support \ + ['normal', 'pytorch', 'gru'], default to gru. + - dueling (:obj:`bool`): Whether choose ``DuelingHead`` (True) or ``DiscreteHead (False)``, \ + default to False. """ super(WQMix, self).__init__() self._act = nn.ReLU() @@ -118,34 +130,36 @@ def __init__( def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> dict: """ Overview: - forward computation graph of qmix network + Forward computation graph of qmix network. Input dict including time series observation and \ + related data to predict total q_value and each agent q_value. Determine whether to calculate \ + Q_tot or Q_star based on the ``q_star`` parameter Arguments: - - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] - - agent_state (:obj:`torch.Tensor`): each agent local state(obs) - - global_state (:obj:`torch.Tensor`): global state(obs) - - prev_state (:obj:`list`): previous rnn state - - action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to\ - calculate ``agent_q_act`` - - single_step (:obj:`bool`): whether single_step forward, if so, add timestep dim before forward and\ - remove it after forward - - Q_star (:obj:`bool`): whether Q_star network forward. If True, using the Q_star network, where the\ + - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action']. + - agent_state (:obj:`torch.Tensor`): Time series local observation data of each agents. 
         """
         if q_star:  # forward using Q_star network
             agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
@@ -239,15 +253,3 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
             'next_state': next_state,
             'action_mask': data['obs']['action_mask']
         }
-
-    def _setup_global_encoder(self, global_obs_shape: int, embedding_size: int) -> torch.nn.Module:
-        """
-        Overview:
-            Used to encoder global observation.
-        Arguments:
-            - global_obs_shape (:obj:`int`): the dimension of global observation state
-            - embedding_size (:obj:`int`): the dimension of state emdedding
-        Return:
-            - outputs (:obj:`torch.nn.Module`): Global observation encoding network
-        """
-        return MLP(global_obs_shape, embedding_size, embedding_size, 2, activation=self._act)

From 875e9b5326947293e39832dc020bcab04bb9fc9a Mon Sep 17 00:00:00 2001
From: nighood
Date: Thu, 19 Oct 2023 15:22:26 +0800
Subject: [PATCH 2/4] polish(rjy): polish comments in ngu model

---
 ding/model/template/ngu.py   | 64 ++++++++++++++++++++----------------
 ding/model/template/wqmix.py |  6 ++--
 2 files changed, 38 insertions(+), 32 deletions(-)

diff --git a/ding/model/template/ngu.py b/ding/model/template/ngu.py
index eb605982bd..8595d83857 100644
--- a/ding/model/template/ngu.py
+++ b/ding/model/template/ngu.py
@@ -10,9 +10,9 @@


 def parallel_wrapper(forward_fn: Callable) -> Callable:
-    r"""
+    """
     Overview:
-        Process timestep T and batch_size B at the same time, in other words, treat different timestep data as
+        Process timestep T and batch_size B at the same time, in other words, treat different timestep data as \
         different trajectories in a batch.
     Arguments:
         - forward_fn (:obj:`Callable`): Normal ``nn.Module`` 's forward function.
@@ -44,9 +44,12 @@ def reshape(d):
 class NGU(nn.Module):
     """
     Overview:
-        The recurrent Q model for NGU policy, modified from the class DRQN in q_leaning.py
-        input: x_t, a_{t-1}, r_e_{t-1}, r_i_{t-1}, beta
-        output:
+        The recurrent Q model for NGU (https://arxiv.org/pdf/2002.06038.pdf) policy, modified from the class DRQN in \
+        q_learning.py. The implementation mentioned in the original paper is 'adapt the R2D2 agent that uses the \
+        dueling network architecture with an LSTM layer after a convolutional neural network'. The NGU network \
+        includes encoder, LSTM core (rnn) and head.
+    Interface:
+        ``__init__``, ``forward``.
     """

     def __init__(
@@ -62,20 +65,26 @@ def __init__(
         activation: Optional[nn.Module] = nn.ReLU(),
         norm_type: Optional[str] = None
     ) -> None:
-        r"""
+        """
         Overview:
-            Init the DRQN Model according to arguments.
+            Init the DRQN Model for NGU according to arguments.
         Arguments:
-            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
-            - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
-            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
-            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``.
-            - lstm_type (:obj:`Optional[str]`): Version of rnn cell, now support ['normal', 'pytorch', 'hpc', 'gru']
+            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space, such as 8 or [4, 84, 84].
+            - action_shape (:obj:`Union[int, SequenceType]`): Action's space, such as 6 or [2, 3, 3].
+            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
+            - collector_env_num (:obj:`Optional[int]`): The number of environments used to collect data simultaneously.
+            - dueling (:obj:`bool`): Whether to choose ``DuelingHead`` (True) or ``DiscreteHead`` (False), \
+                defaults to True.
+            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``, should match the \
+                last element of ``encoder_hidden_size_list``.
+            - head_layer_num (:obj:`int`): The number of layers in head network.
+            - lstm_type (:obj:`Optional[str]`): Version of rnn cell, now supports ['normal', 'pytorch', 'hpc', 'gru'], \
+                default is 'normal'.
             - activation (:obj:`Optional[nn.Module]`):
-                The type of activation function to use in ``MLP`` the after ``layer_fn``,
-                if ``None`` then default set to ``nn.ReLU()``
+                The type of activation function to use in ``MLP`` after ``layer_fn``, \
+                if ``None`` then default set to ``nn.ReLU()``.
             - norm_type (:obj:`Optional[str]`):
-                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details`
+                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
         """
         super(NGU, self).__init__()
         # For compatibility: 1, (1, ), [4, H, H]
@@ -122,32 +131,29 @@ def __init__(

     def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
         r"""
         Overview:
-            Use observation, prev_action prev_reward_extrinsic to predict NGU Q output.
-            Parameter updates with NGU's MLPs forward setup.
+            Forward computation graph of NGU R2D2 network. Input observation, prev_action and prev_reward_extrinsic \
+            to predict NGU Q output. Parameter updates with NGU's MLPs forward setup.
         Arguments:
             - inputs (:obj:`Dict`):
+            - obs (:obj:`torch.Tensor`): Encoded observation.
+            - prev_state (:obj:`list`): Previous state's tensor of size ``(B, N)``.
-            - inference: (:obj:'bool'): if inference is True, we unroll the one timestep transition,
+            - inference (:obj:`bool`): If inference is True, we unroll the one timestep transition, \
                 if inference is False, we unroll the sequence transitions.
-            - saved_state_timesteps: (:obj:'Optional[list]'): when inference is False,
+            - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, \
                 we unroll the sequence transitions, then we would save rnn hidden states at timesteps \
                 that are listed in the list ``saved_state_timesteps``.
-
-        ArgumentsKeys:
-            - obs (:obj:`torch.Tensor`): Encoded observation
-            - prev_state (:obj:`list`): Previous state's tensor of size ``(B, N)``
-
         Returns:
             - outputs (:obj:`Dict`): Run ``MLP`` with ``DRQN`` setups and return the result prediction dictionary.
         ReturnsKeys:
             - logit (:obj:`torch.Tensor`): Logit tensor with same size as input ``obs``.
-            - next_state (:obj:`list`): Next state's tensor of size ``(B, N)``
+            - next_state (:obj:`list`): Next state's tensor of size ``(B, N)``.
         Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N=obs_space)`, where B is batch size.
-            - prev_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`
-            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`
-            - next_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`
+            - prev_state (:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
+            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`.
+            - next_state (:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
         """
         x, prev_state = inputs['obs'], inputs['prev_state']
         if 'prev_action' in inputs.keys():

diff --git a/ding/model/template/wqmix.py b/ding/model/template/wqmix.py
index c63344caa7..15838138c4 100644
--- a/ding/model/template/wqmix.py
+++ b/ding/model/template/wqmix.py
@@ -106,9 +106,9 @@ def __init__(
             Q_star network and mixer_star.
         Arguments:
             - agent_num (:obj:`int`): The number of agents, such as 8.
-            - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8 or [4, 84, 84].
-            - global_obs_shape (:obj:`int`): The dimension of global observation state, such as 8 or [4, 84, 84].
-            - action_shape (:obj:`int`): The dimension of action shape, such as 6 or [2, 3, 3].
+            - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8.
+            - global_obs_shape (:obj:`int`): The dimension of global observation state, such as 8.
+            - action_shape (:obj:`int`): The dimension of action shape, such as 6.
             - hidden_size_list (:obj:`list`): The list of hidden sizes for ``q_network``, \
                 the last element must match mixer's ``mixing_embed_dim``.
             - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now supports \

From a9981fae7f206e505bced4e8347620f54fdbe46d Mon Sep 17 00:00:00 2001
From: nighood
Date: Thu, 19 Oct 2023 16:42:29 +0800
Subject: [PATCH 3/4] polish(rjy): polish comments in pg model

---
 ding/model/template/ngu.py |  2 +-
 ding/model/template/pg.py  | 44 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/ding/model/template/ngu.py b/ding/model/template/ngu.py
index 8595d83857..d1c9d9fb99 100644
--- a/ding/model/template/ngu.py
+++ b/ding/model/template/ngu.py
@@ -48,7 +48,7 @@ class NGU(nn.Module):
         q_learning.py. The implementation mentioned in the original paper is 'adapt the R2D2 agent that uses the \
         dueling network architecture with an LSTM layer after a convolutional neural network'. The NGU network \
         includes encoder, LSTM core (rnn) and head.
-    Interface:
+    Interface:
         ``__init__``, ``forward``.
     """

diff --git a/ding/model/template/pg.py b/ding/model/template/pg.py
index 8a1f3f889d..e4297aa805 100644
--- a/ding/model/template/pg.py
+++ b/ding/model/template/pg.py
@@ -11,6 +11,15 @@

 @MODEL_REGISTRY.register('pg')
 class PG(nn.Module):
+    """
+    Overview:
+        The neural network and computation graph of algorithms related to Policy Gradient (PG) \
+        (https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf). \
+        The PG model is composed of two parts: encoder and head. Encoders are used to extract features \
+        from various observations. Heads are used to predict the corresponding action logit. \
+    Interface:
+        ``__init__``, ``forward``.
+    """

     def __init__(
         self,
@@ -23,6 +32,31 @@ def __init__(
         activation: Optional[nn.Module] = nn.ReLU(),
         norm_type: Optional[str] = None
     ) -> None:
+        """
+        Overview:
+            Initialize the PG model according to corresponding input arguments.
+        Arguments:
+            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+            - action_space (:obj:`str`): The type of action space, including ['discrete', 'continuous'], \
+                which decides whether to instantiate ``DiscreteHead`` or ``ReparameterizationHead``.
+            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+                the last element must match ``head_hidden_size``.
+            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``head`` network, defaults \
+                to None; it must match the last element of ``encoder_hidden_size_list``.
+            - head_layer_num (:obj:`int`): The number of layers used in the ``head`` network to compute action.
+            - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks, \
+                if ``None`` then default set to ``nn.ReLU()``.
+            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+                ``ding.torch_utils.fc_block`` for more details. \
+                You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
+        Examples:
+            >>> model = PG((4, 84, 84), 5)
+            >>> inputs = torch.randn(8, 4, 84, 84)
+            >>> outputs = model(inputs)
+            >>> assert isinstance(outputs, dict)
+            >>> assert outputs['logit'].shape == (8, 5)
+            >>> assert outputs['dist'].sample().shape == (8, )
+        """
         super(PG, self).__init__()
         # For compatibility: 1, (1, ), [4, 32, 32]
         obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
@@ -57,6 +91,16 @@ def __init__(
         raise KeyError("not support action space: {}".format(self.action_space))

     def forward(self, x: torch.Tensor) -> Dict:
+        """
+        Overview:
+            Forward computation graph of PG, which takes the observation tensor as input and predicts the \
+            policy distribution.
+        Arguments:
+            - x (:obj:`torch.Tensor`): The input observation tensor data.
+        Returns:
+            - outputs (:obj:`Dict`): The output dict of PG's forward computation graph, including ``logit`` and \
+                ``dist``. If the action space is discrete, ``dist`` is a Categorical distribution; if it is \
+                continuous, ``dist`` is a Normal distribution.
+        """
         x = self.encoder(x)
         x = self.head(x)
         if self.action_space == 'discrete':

From c3853115c238209ec2117b51d33dfa1d30e96e64 Mon Sep 17 00:00:00 2001
From: nighood
Date: Tue, 31 Oct 2023 15:55:13 +0800
Subject: [PATCH 4/4] polish(rjy): polish according to comments

---
 ding/model/template/pg.py    | 2 +-
 ding/model/template/wqmix.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ding/model/template/pg.py b/ding/model/template/pg.py
index e4297aa805..6059642dd3 100644
--- a/ding/model/template/pg.py
+++ b/ding/model/template/pg.py
@@ -16,7 +16,7 @@ class PG(nn.Module):
         The neural network and computation graph of algorithms related to Policy Gradient (PG) \
         (https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf). \
         The PG model is composed of two parts: encoder and head. Encoders are used to extract features \
-        from various observations. Heads are used to predict the corresponding action logit. \
+        from various observations. Heads are used to predict the corresponding action logit.
     Interface:
         ``__init__``, ``forward``.
     """

diff --git a/ding/model/template/wqmix.py b/ding/model/template/wqmix.py
index 15838138c4..f80aa25d4a 100644
--- a/ding/model/template/wqmix.py
+++ b/ding/model/template/wqmix.py
@@ -55,7 +55,7 @@ def forward(self, agent_qs: torch.FloatTensor, states: torch.FloatTensor) -> tor
             - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
             - states (:obj:`torch.FloatTensor`): The embedding vector of global state.
         Returns:
-            - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value
+            - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.
         Shapes:
             - agent_qs (:obj:`torch.FloatTensor`): :math:`(T, B, A)`, where T is timestep, \
                 B is batch size, A is agent_num.
@@ -132,7 +132,7 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
         Overview:
             Forward computation graph of WQMIX network. Input dict includes time series observation and \
             related data to predict total q_value and each agent q_value. Determine whether to calculate \
-            Q_tot or Q_star based on the ``q_star`` parameter
+            Q_tot or Q_star based on the ``q_star`` parameter.
         Arguments:
             - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
             - agent_state (:obj:`torch.Tensor`): Time series local observation data of each agent.