diff --git a/CHANGELOG.md b/CHANGELOG.md
index 07abeea5cc..604b513890 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,13 +4,16 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-
 ## [unReleased] - 2021-MM-DD
 
 ### Added
 
 - Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676))
+
+- Added Advantage Actor-Critic (A2C) Model ([#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))
+
+
 ### Changed
 
 - Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
 
diff --git a/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg b/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg
new file mode 100644
index 0000000000..15be42c746
Binary files /dev/null and b/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg differ
diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst
index 663d5b27b2..ecdb3a1c9c 100644
--- a/docs/source/reinforce_learn.rst
+++ b/docs/source/reinforce_learn.rst
@@ -25,6 +25,7 @@ Contributions by: `Donal Byrne `_
 RL models currently only support CPU and single GPU training with `distributed_backend=dp`.
 Full GPU support will be added in later updates.
 
+------------
 
 DQN Models
 ----------
@@ -86,7 +87,7 @@ Example::
     trainer = Trainer()
     trainer.fit(dqn)
 
-.. autoclass:: pl_bolts.models.rl.dqn_model.DQN
+.. autoclass:: pl_bolts.models.rl.DQN
     :noindex:
 
 ---------------
@@ -150,7 +151,7 @@ Example::
     trainer = Trainer()
     trainer.fit(ddqn)
 
-.. autoclass:: pl_bolts.models.rl.double_dqn_model.DoubleDQN
+.. autoclass:: pl_bolts.models.rl.DoubleDQN
     :noindex:
 
 ---------------
@@ -240,7 +241,7 @@ Example::
     trainer = Trainer()
     trainer.fit(dueling_dqn)
 
-.. autoclass:: pl_bolts.models.rl.dueling_dqn_model.DuelingDQN
+.. autoclass:: pl_bolts.models.rl.DuelingDQN
     :noindex:
 
 --------------
@@ -326,7 +327,7 @@ Example::
     trainer = Trainer()
     trainer.fit(noisy_dqn)
 
-.. autoclass:: pl_bolts.models.rl.noisy_dqn_model.NoisyDQN
+.. autoclass:: pl_bolts.models.rl.NoisyDQN
     :noindex:
 
 --------------
@@ -519,7 +520,7 @@ Example::
     trainer = Trainer()
     trainer.fit(per_dqn)
 
-.. autoclass:: pl_bolts.models.rl.per_dqn_model.PERDQN
+.. autoclass:: pl_bolts.models.rl.PERDQN
     :noindex:
 
 
@@ -611,7 +612,7 @@ Example::
     trainer = Trainer()
     trainer.fit(reinforce)
 
-.. autoclass:: pl_bolts.models.rl.reinforce_model.Reinforce
+.. autoclass:: pl_bolts.models.rl.Reinforce
     :noindex:
 
 --------------
@@ -664,5 +665,102 @@ Example::
     trainer = Trainer()
     trainer.fit(vpg)
 
-.. autoclass:: pl_bolts.models.rl.vanilla_policy_gradient_model.VanillaPolicyGradient
+.. autoclass:: pl_bolts.models.rl.VanillaPolicyGradient
+    :noindex:
+
+--------------
+
+Actor-Critic Models
+-------------------
+The following models are based on Actor Critic. Actor Critic combines the approaches of value-based learning (the DQN family)
+and policy-based learning (the PG family) by learning the value function as well as the policy distribution. This approach
+updates the policy network according to the policy gradient, and updates the value network to fit the discounted rewards.
+
+Actor Critic Key Points:
+    - Actor outputs a distribution of actions for controlling the agent
+    - Critic outputs a value of the current state, which is used to guide policy updates
+    - The addition of the critic allows the model to train on n-step rollouts instead of generating an entire trajectory
+
+Advantage Actor Critic (A2C)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+(Asynchronous) Advantage Actor Critic model introduced in `Asynchronous Methods for Deep Reinforcement Learning `_
+Paper authors: Volodymyr Mnih, Adrià Puigdomènech Badia, Mehdi Mirza, Alex Graves, Timothy P. Lillicrap, Tim Harley, David Silver, Koray Kavukcuoglu
+
+Original implementation by: `Jason Wang `_
+
+Advantage Actor Critic (A2C) is the classic actor-critic approach in reinforcement learning. The underlying neural
+network has an actor head and a critic head that output the action distribution and the value of the current state,
+respectively. The first few layers are usually shared by the two heads to avoid learning the same features twice.
+A2C builds on the idea, used in VPG, of subtracting a baseline such as the average reward to reduce variance, but
+uses the critic's value estimate as the baseline instead, which can perform better.
+
+The algorithm can train on n-step rollouts instead of generating an entire trajectory. It proceeds as follows:
+
+1. Initialize our network.
+2. Roll out n steps and save the transitions (states, actions, rewards, values, dones).
+3. Calculate the n-step (discounted) return by bootstrapping the last value.
+
+.. math::
+
+    G_{n+1} = V_{n+1}, G_t = r_t + \gamma G_{t+1} \ \forall t \in [0,n]
+
+4. Calculate the actor loss using the values as a baseline.
+
+.. math::
+
+    L_{actor} = - \frac1n \sum_t (G_t - V_t) \log \pi (a_t | s_t)
+
+5. Calculate the critic loss using the returns as targets.
+
+.. math::
+
+    L_{critic} = \frac1n \sum_t (V_t - G_t)^2
+
+6. Calculate the entropy bonus to encourage exploration.
+
+.. math::
+
+    H_\pi = - \frac1n \sum_t \pi (a_t | s_t) \log \pi (a_t | s_t)
+
+7. Calculate the total loss as a weighted sum of the three components above.
+
+.. math::
+
+    L = L_{actor} + \beta_{critic} L_{critic} - \beta_{entropy} H_\pi
+
+8. Perform gradient descent to update our network.
+
+.. note::
+    The current implementation only supports discrete action spaces, and has only been tested on the CartPole environment.
+
+A2C Benefits
+~~~~~~~~~~~~~~~
+
+- Combines the benefits of value-based and policy-based learning
+
+- Further reduces variance by using the critic as a value estimator
+
+A2C Results
+~~~~~~~~~~~~~~~~
+
+Hyperparameters:
+
+- Batch Size: 32
+- Learning Rate: 0.001
+- Entropy Beta: 0.01
+- Critic Beta: 0.5
+- Gamma: 0.99
+
+.. image:: _images/rl_benchmark/cartpole_a2c_results.jpg
+    :width: 300
+    :alt: A2C Results
+
+Example::
+
+    from pl_bolts.models.rl import AdvantageActorCritic
+    a2c = AdvantageActorCritic("CartPole-v0")
+    trainer = Trainer()
+    trainer.fit(a2c)
+
+.. autoclass:: pl_bolts.models.rl.AdvantageActorCritic
+    :noindex:
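Editor's note: the eight-step recipe documented above maps onto a handful of tensor operations. The following is a minimal, self-contained sketch of steps 4-7; the `a2c_loss` helper and its argument names are illustrative only and are not part of the pl_bolts API, and the returns are assumed to have been bootstrapped already as in step 3::

    import torch
    from torch.nn import functional as F

    def a2c_loss(logprobs, values, actions, returns, critic_beta=0.5, entropy_beta=0.01):
        """logprobs: (batch, n_actions) log pi(a|s); values, integer actions and returns: (batch,)."""
        # step 4: policy-gradient loss with the critic's value as a baseline
        advantages = (returns - values).detach()
        chosen = logprobs[torch.arange(len(actions)), actions]
        actor_loss = -(chosen * advantages).mean()

        # step 5: critic regression towards the bootstrapped returns
        critic_loss = F.mse_loss(values, returns)

        # step 6: entropy bonus to keep the policy exploring
        entropy = -(logprobs.exp() * logprobs).sum(dim=1).mean()

        # step 7: weighted sum of the three components
        return actor_loss + critic_beta * critic_loss - entropy_beta * entropy

Called with `logprobs` of shape `(batch, n_actions)` and the remaining tensors of shape `(batch,)`, this returns a scalar loss ready for backpropagation.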
diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py
index 070ec666be..a84b51dec6 100644
--- a/pl_bolts/models/rl/__init__.py
+++ b/pl_bolts/models/rl/__init__.py
@@ -1,12 +1,14 @@
-from pl_bolts.models.rl.double_dqn_model import DoubleDQN  # noqa: F401
-from pl_bolts.models.rl.dqn_model import DQN  # noqa: F401
-from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN  # noqa: F401
-from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN  # noqa: F401
-from pl_bolts.models.rl.per_dqn_model import PERDQN  # noqa: F401
-from pl_bolts.models.rl.reinforce_model import Reinforce  # noqa: F401
-from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient  # noqa: F401
+from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
+from pl_bolts.models.rl.double_dqn_model import DoubleDQN
+from pl_bolts.models.rl.dqn_model import DQN
+from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
+from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
+from pl_bolts.models.rl.per_dqn_model import PERDQN
+from pl_bolts.models.rl.reinforce_model import Reinforce
+from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
 
 __all__ = [
+    "AdvantageActorCritic",
     "DoubleDQN",
     "DQN",
     "DuelingDQN",
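Editor's note: with the re-export above, the model is importable directly from `pl_bolts.models.rl`. A quick usage sketch, mirroring the documentation example; it assumes `gym` is installed and adds a step limit only so it terminates quickly::

    from pytorch_lightning import Trainer

    from pl_bolts.models.rl import AdvantageActorCritic

    model = AdvantageActorCritic("CartPole-v0", gamma=0.99, lr=1e-3, batch_size=32)
    trainer = Trainer(max_steps=1000)
    trainer.fit(model)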
diff --git a/pl_bolts/models/rl/advantage_actor_critic_model.py b/pl_bolts/models/rl/advantage_actor_critic_model.py
new file mode 100644
index 0000000000..9f2a835a0d
--- /dev/null
+++ b/pl_bolts/models/rl/advantage_actor_critic_model.py
@@ -0,0 +1,327 @@
+"""
+Advantage Actor Critic (A2C)
+"""
+from argparse import ArgumentParser
+from collections import OrderedDict
+from typing import Any, Iterator, List, Tuple
+
+import numpy as np
+import torch
+from pytorch_lightning import LightningModule, seed_everything, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from torch import optim as optim
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+from torch.utils.data import DataLoader
+
+from pl_bolts.datamodules import ExperienceSourceDataset
+from pl_bolts.models.rl.common.agents import ActorCriticAgent
+from pl_bolts.models.rl.common.networks import ActorCriticMLP
+from pl_bolts.utils import _GYM_AVAILABLE
+from pl_bolts.utils.warnings import warn_missing_pkg
+
+if _GYM_AVAILABLE:
+    import gym
+else:  # pragma: no cover
+    warn_missing_pkg("gym")
+
+
+class AdvantageActorCritic(LightningModule):
+    """
+    PyTorch Lightning implementation of `Advantage Actor Critic
+    `_
+
+    Paper Authors: Volodymyr Mnih, Adrià Puigdomènech Badia, et al.
+
+    Model implemented by:
+
+        - `Jason Wang `_
+
+    Example:
+        >>> from pl_bolts.models.rl import AdvantageActorCritic
+        ...
+        >>> model = AdvantageActorCritic("CartPole-v0")
+    """
+
+    def __init__(
+        self,
+        env: str,
+        gamma: float = 0.99,
+        lr: float = 0.001,
+        batch_size: int = 32,
+        avg_reward_len: int = 100,
+        entropy_beta: float = 0.01,
+        critic_beta: float = 0.5,
+        epoch_len: int = 1000,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Args:
+            env: gym environment tag
+            gamma: discount factor
+            lr: learning rate
+            batch_size: size of minibatch pulled from the DataLoader
+            avg_reward_len: how many episodes to take into account when calculating the avg reward
+            entropy_beta: weight of the entropy bonus in the total loss
+            critic_beta: weight of the critic loss in the total loss
+            epoch_len: how many batches make up one pseudo epoch
+        """
+        super().__init__()
+
+        if not _GYM_AVAILABLE:  # pragma: no cover
+            raise ModuleNotFoundError("This Module requires gym environment which is not installed yet.")
+
+        # Hyperparameters
+        self.save_hyperparameters()
+        self.batches_per_epoch = batch_size * epoch_len
+
+        # Model components
+        self.env = gym.make(env)
+        self.net = ActorCriticMLP(self.env.observation_space.shape, self.env.action_space.n)
+        self.agent = ActorCriticAgent(self.net)
+
+        # Tracking metrics
+        self.total_rewards = [0]
+        self.episode_reward = 0
+        self.done_episodes = 0
+        self.avg_rewards = 0.0
+        self.avg_reward_len = avg_reward_len
+        self.eps = np.finfo(np.float32).eps.item()
+        self.batch_states: List = []
+        self.batch_actions: List = []
+        self.batch_rewards: List = []
+        self.batch_masks: List = []
+
+        self.state = self.env.reset()
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Passes in a state x through the network and gets the log prob of each action
+        and the value for the state as an output
+
+        Args:
+            x: environment state
+
+        Returns:
+            action log probabilities, values
+        """
+        if not isinstance(x, list):
+            x = [x]
+
+        if not isinstance(x, Tensor):
+            x = torch.tensor(x, device=self.device)
+
+        logprobs, values = self.net(x)
+        return logprobs, values
+
+    def train_batch(self) -> Iterator[Tuple[np.ndarray, int, Tensor]]:
+        """
+        Contains the logic for generating a new batch of data to be passed to the DataLoader
+
+        Returns:
+            yields a tuple of Lists containing tensors for
+            states, actions, and returns of the batch.
+
+        Note:
+            This is what's taken by the dataloader:
+            states: a list of numpy array
+            actions: a list of list of int
+            returns: a torch tensor
+        """
+        while True:
+            for _ in range(self.hparams.batch_size):
+                action = self.agent(self.state, self.device)[0]
+
+                next_state, reward, done, _ = self.env.step(action)
+
+                self.batch_rewards.append(reward)
+                self.batch_actions.append(action)
+                self.batch_states.append(self.state)
+                self.batch_masks.append(done)
+                self.state = next_state
+                self.episode_reward += reward
+
+                if done:
+                    self.done_episodes += 1
+                    self.state = self.env.reset()
+                    self.total_rewards.append(self.episode_reward)
+                    self.episode_reward = 0
+                    self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:]))
+
+            _, last_value = self.forward(self.state)
+
+            returns = self.compute_returns(self.batch_rewards, self.batch_masks, last_value)
+            for idx in range(self.hparams.batch_size):
+                yield self.batch_states[idx], self.batch_actions[idx], returns[idx]
+
+            self.batch_states = []
+            self.batch_actions = []
+            self.batch_rewards = []
+            self.batch_masks = []
+
+    def compute_returns(
+        self,
+        rewards: List[float],
+        dones: List[bool],
+        last_value: Tensor,
+    ) -> Tensor:
+        """
+        Calculate the discounted rewards of the batched rewards
+
+        Args:
+            rewards: list of rewards
+            dones: list of done masks
+            last_value: the predicted value for the last state (for bootstrap)
+
+        Returns:
+            tensor of discounted rewards
+        """
+        g = last_value
+        returns = []
+
+        for r, d in zip(rewards[::-1], dones[::-1]):
+            g = r + self.hparams.gamma * g * (1 - d)
+            returns.append(g)
+
+        # reverse list and stop the gradients
+        returns = torch.tensor(returns[::-1])
+
+        return returns
+
+    def loss(
+        self,
+        states: Tensor,
+        actions: Tensor,
+        returns: Tensor,
+    ) -> Tensor:
+        """
+        Calculates the loss for A2C which is a weighted sum of
+        actor loss (policy gradient), critic loss (MSE), and entropy (for exploration)
+
+        Args:
+            states: tensor of shape (batch_size, state dimension)
+            actions: tensor of shape (batch_size, )
+            returns: tensor of shape (batch_size, )
+        """
+
+        logprobs, values = self.net(states)
+
+        # calculates (normalized) advantage
+        with torch.no_grad():
+            # critic is trained with normalized returns, so we need to scale the values here
+            advs = returns - values * returns.std() + returns.mean()
+            # normalize advantages to train actor
+            advs = (advs - advs.mean()) / (advs.std() + self.eps)
+            # normalize returns to train critic
+            targets = (returns - returns.mean()) / (returns.std() + self.eps)
+
+        # entropy loss
+        entropy = -logprobs.exp() * logprobs
+        entropy = self.hparams.entropy_beta * entropy.sum(1).mean()
+
+        # actor loss
+        logprobs = logprobs[range(self.hparams.batch_size), actions]
+        actor_loss = -(logprobs * advs).mean()
+
+        # critic loss
+        critic_loss = self.hparams.critic_beta * torch.square(targets - values).mean()
+
+        # total loss (weighted sum)
+        total_loss = actor_loss + critic_loss - entropy
+        return total_loss
+
+    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> OrderedDict:
+        """
+        Perform one actor-critic update using a batch of data
+
+        Args:
+            batch: a batch of (states, actions, returns)
+        """
+        states, actions, returns = batch
+        loss = self.loss(states, actions, returns)
+
+        log = {
+            "episodes": self.done_episodes,
+            "reward": self.total_rewards[-1],
+            "avg_reward": self.avg_rewards,
+        }
+        return OrderedDict({
+            "loss": loss,
+            "avg_reward": self.avg_rewards,
+            "log": log,
+            "progress_bar": log,
+        })
+
+    def configure_optimizers(self) -> List[Optimizer]:
"""Initialize Adam optimizer""" + optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr) + return [optimizer] + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + dataset = ExperienceSourceDataset(self.train_batch) + dataloader = DataLoader(dataset=dataset, batch_size=self.hparams.batch_size) + return dataloader + + def train_dataloader(self) -> DataLoader: + """Get train loader""" + return self._dataloader() + + def get_device(self, batch) -> str: + """Retrieve device currently being used by minibatch""" + return batch[0][0][0].device.index if self.on_gpu else "cpu" + + @staticmethod + def add_model_specific_args(arg_parser: ArgumentParser) -> ArgumentParser: + """ + Adds arguments for A2C model + + Args: + arg_parser: the current argument parser to add to + + Returns: + arg_parser with model specific cargs added + """ + + arg_parser.add_argument("--entropy_beta", type=float, default=0.01, help="entropy coefficient") + arg_parser.add_argument("--critic_beta", type=float, default=0.5, help="critic loss coefficient") + arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch") + arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") + arg_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") + arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") + arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run") + + arg_parser.add_argument( + "--avg_reward_len", + type=int, + default=100, + help="how many episodes to include in avg reward", + ) + + return arg_parser + + +def cli_main() -> None: + parser = ArgumentParser(add_help=False) + + # trainer args + parser = Trainer.add_argparse_args(parser) + + # model args + parser = AdvantageActorCritic.add_model_specific_args(parser) + args = parser.parse_args() + + model = AdvantageActorCritic(**args.__dict__) + + # save checkpoints based on avg_reward + checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True) + + seed_everything(123) + trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback) + trainer.fit(model) + + +if __name__ == "__main__": + cli_main() diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index e692c8becb..057108b702 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -137,3 +137,33 @@ def __call__(self, states: Tensor, device: str) -> List[int]: actions = [np.random.choice(len(prob), p=prob) for prob in prob_np] return actions + + +class ActorCriticAgent(Agent): + """Actor-Critic based agent that returns an action based on the networks policy""" + + def __call__(self, states: Tensor, device: str) -> List[int]: + """ + Takes in the current state and returns the action based on the agents policy + + Args: + states: current state of the environment + device: the device used for the current batch + + Returns: + action defined by policy + """ + if not isinstance(states, list): + states = [states] + + if not isinstance(states, Tensor): + states = torch.tensor(states, device=device) + + logprobs, _ = self.net(states) + probabilities = logprobs.exp().squeeze(dim=-1) + prob_np = probabilities.data.cpu().numpy() + + # take the numpy values 
diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py
index e692c8becb..057108b702 100644
--- a/pl_bolts/models/rl/common/agents.py
+++ b/pl_bolts/models/rl/common/agents.py
@@ -137,3 +137,33 @@ def __call__(self, states: Tensor, device: str) -> List[int]:
         actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]
 
         return actions
+
+
+class ActorCriticAgent(Agent):
+    """Actor-Critic based agent that returns an action based on the network's policy"""
+
+    def __call__(self, states: Tensor, device: str) -> List[int]:
+        """
+        Takes in the current state and returns the action based on the agent's policy
+
+        Args:
+            states: current state of the environment
+            device: the device used for the current batch
+
+        Returns:
+            action defined by policy
+        """
+        if not isinstance(states, list):
+            states = [states]
+
+        if not isinstance(states, Tensor):
+            states = torch.tensor(states, device=device)
+
+        logprobs, _ = self.net(states)
+        probabilities = logprobs.exp().squeeze(dim=-1)
+        prob_np = probabilities.data.cpu().numpy()
+
+        # take the numpy values and randomly select action based on prob distribution
+        actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]
+
+        return actions
diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py
index 48a9380f9c..3a88931398 100644
--- a/pl_bolts/models/rl/common/networks.py
+++ b/pl_bolts/models/rl/common/networks.py
@@ -93,6 +93,40 @@ def forward(self, input_x):
         return self.net(input_x.float())
 
 
+class ActorCriticMLP(nn.Module):
+    """
+    MLP network with heads for actor and critic
+    """
+
+    def __init__(self, input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
+        """
+        Args:
+            input_shape: observation shape of the environment
+            n_actions: number of discrete actions available in the environment
+            hidden_size: size of hidden layers
+        """
+        super().__init__()
+
+        self.fc1 = nn.Linear(input_shape[0], hidden_size)
+        self.actor_head = nn.Linear(hidden_size, n_actions)
+        self.critic_head = nn.Linear(hidden_size, 1)
+
+    def forward(self, x) -> Tuple[Tensor, Tensor]:
+        """
+        Forward pass through network. Calculates the action log probabilities and the value
+
+        Args:
+            x: input to network
+
+        Returns:
+            action log probabilities, value
+        """
+        x = F.relu(self.fc1(x.float()))
+        a = F.log_softmax(self.actor_head(x), dim=-1)
+        c = self.critic_head(x)
+        return a, c
+
+
 class DuelingMLP(nn.Module):
     """
     MLP network with duel heads for val and advantage
diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py
new file mode 100644
index 0000000000..dd21b4505a
--- /dev/null
+++ b/tests/models/rl/integration/test_actor_critic_models.py
@@ -0,0 +1,27 @@
+import argparse
+
+from pytorch_lightning import Trainer
+
+from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
+
+
+def test_a2c():
+    """Smoke test that the A2C model runs"""
+
+    parent_parser = argparse.ArgumentParser(add_help=False)
+    parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser)
+    args_list = [
+        "--env",
+        "CartPole-v0",
+    ]
+    hparams = parent_parser.parse_args(args_list)
+
+    trainer = Trainer(
+        gpus=0,
+        max_steps=100,
+        max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
+        val_check_interval=1,  # This just needs 'some' value, does not affect training right now
+        fast_dev_run=True
+    )
+    model = AdvantageActorCritic(hparams.env)
+    trainer.fit(model)
diff --git a/tests/models/rl/test_scripts.py b/tests/models/rl/test_scripts.py
index ee30206718..829049fc19 100644
--- a/tests/models/rl/test_scripts.py
+++ b/tests/models/rl/test_scripts.py
@@ -126,3 +126,18 @@ def test_cli_run_rl_vanilla_policy_gradient(cli_args):
     cli_args = cli_args.strip().split(' ') if cli_args else []
     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
         cli_main()
+
+
+@pytest.mark.parametrize('cli_args', [
+    ' --env CartPole-v0'
+    ' --max_steps 10'
+    ' --fast_dev_run 1'
+    ' --batch_size 10',
+])
+def test_cli_run_rl_advantage_actor_critic(cli_args):
+    """Test running CLI for an example with default params."""
+    from pl_bolts.models.rl.advantage_actor_critic_model import cli_main
+
+    cli_args = cli_args.strip().split(' ') if cli_args else []
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
+        cli_main()
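Editor's note: the `ActorCriticAgent` and `ActorCriticMLP` introduced above can be exercised in isolation. The sketch below only assumes a CartPole-like setup with 4-dimensional observations and 2 discrete actions, and shows the shapes the rest of the model relies on::

    import numpy as np
    import torch

    from pl_bolts.models.rl.common.agents import ActorCriticAgent
    from pl_bolts.models.rl.common.networks import ActorCriticMLP

    net = ActorCriticMLP(input_shape=(4,), n_actions=2)

    states = torch.rand(8, 4)            # batch of 8 CartPole-like observations
    logprobs, values = net(states)
    print(logprobs.shape, values.shape)  # torch.Size([8, 2]) torch.Size([8, 1])

    # the agent exponentiates the log probabilities and samples one action per state
    agent = ActorCriticAgent(net)
    actions = agent([np.random.rand(4).astype(np.float32)], device="cpu")
    print(actions)                       # e.g. [0]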
diff --git a/tests/models/rl/unit/test_a2c.py b/tests/models/rl/unit/test_a2c.py
new file mode 100644
index 0000000000..79805f1b62
--- /dev/null
+++ b/tests/models/rl/unit/test_a2c.py
@@ -0,0 +1,55 @@
+import argparse
+
+import torch
+from torch import Tensor
+
+from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
+
+
+def test_a2c_loss():
+    """Test the A2C loss function"""
+    parent_parser = argparse.ArgumentParser(add_help=False)
+    parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser)
+    args_list = [
+        "--env",
+        "CartPole-v0",
+        "--batch_size",
+        "32",
+    ]
+    hparams = parent_parser.parse_args(args_list)
+    model = AdvantageActorCritic(**vars(hparams))
+
+    batch_states = torch.rand(32, 4)
+    batch_actions = torch.rand(32).long()
+    batch_qvals = torch.rand(32)
+
+    loss = model.loss(batch_states, batch_actions, batch_qvals)
+
+    assert isinstance(loss, Tensor)
+
+
+def test_a2c_train_batch():
+    """Tests that a single batch generates correctly"""
+    parent_parser = argparse.ArgumentParser(add_help=False)
+    parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser)
+    args_list = [
+        "--env",
+        "CartPole-v0",
+        "--batch_size",
+        "32",
+    ]
+    hparams = parent_parser.parse_args(args_list)
+    model = AdvantageActorCritic(**vars(hparams))
+
+    model.n_steps = 4
+    model.hparams.batch_size = 1
+    xp_dataloader = model.train_dataloader()
+
+    batch = next(iter(xp_dataloader))
+
+    assert len(batch) == 3
+    assert len(batch[0]) == model.hparams.batch_size
+    assert isinstance(batch, list)
+    assert isinstance(batch[0], Tensor)
+    assert isinstance(batch[1], Tensor)
+    assert isinstance(batch[2], Tensor)
diff --git a/tests/models/rl/unit/test_agents.py b/tests/models/rl/unit/test_agents.py
index 2f25e54de4..ca37d864dd 100644
--- a/tests/models/rl/unit/test_agents.py
+++ b/tests/models/rl/unit/test_agents.py
@@ -7,7 +7,7 @@
 import torch
 from torch import Tensor
 
-from pl_bolts.models.rl.common.agents import Agent, PolicyAgent, ValueAgent
+from pl_bolts.models.rl.common.agents import ActorCriticAgent, Agent, PolicyAgent, ValueAgent
 
 
 class TestAgents(TestCase):
@@ -61,3 +61,15 @@ def test_policy_agent(self):
         action = policy_agent(self.states, self.device)
         self.assertIsInstance(action, list)
         self.assertEqual(action[0], 1)
+
+
+def test_a2c_agent():
+    env = gym.make("CartPole-v0")
+    logprobs = torch.nn.functional.log_softmax(Tensor([[0.0, 100.0]]))
+    net = Mock(return_value=(logprobs, Tensor([[1]])))
+    states = [env.reset()]
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    a2c_agent = ActorCriticAgent(net)
+    action = a2c_agent(states, device)
+    assert isinstance(action, list)
+    assert action[0] == 1
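Editor's note: the unit tests above cover the loss and batch generation through the full model. A complementary check, sketched here rather than part of the patch, could pin down the network contract on its own, namely that the actor head returns a valid per-state log-probability distribution::

    import torch

    from pl_bolts.models.rl.common.networks import ActorCriticMLP


    def test_actor_critic_mlp_shapes():
        """Actor head emits a valid log-probability distribution, critic emits one value per state."""
        net = ActorCriticMLP(input_shape=(4,), n_actions=2, hidden_size=16)
        states = torch.rand(10, 4)

        logprobs, values = net(states)

        assert logprobs.shape == (10, 2)
        assert values.shape == (10, 1)
        # exp of log_softmax must sum to one over the action dimension
        assert torch.allclose(logprobs.exp().sum(dim=-1), torch.ones(10), atol=1e-5)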