From 16ec8adbd90b34b0d02761a7138042f32ac2ea8b Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 15 May 2024 20:49:25 +0200 Subject: [PATCH 1/8] wip Signed-off-by: sven1977 --- rllib/core/rl_module/rl_module.py | 15 +- rllib/core/rl_module/torch/torch_rl_module.py | 82 +++++------ .../examples/rl_modules/action_masking_rlm.py | 6 - .../rl_modules/classes/tiny_atari_cnn.py | 135 ++++++++++++++++++ rllib/examples/rl_modules/custom_rl_module.py | 43 ++++++ .../rl_modules/episode_env_aware_rlm.py | 6 - .../examples/rl_modules/frame_stacking_rlm.py | 12 -- rllib/examples/rl_modules/mobilenet_rlm.py | 6 - rllib/examples/rl_modules/random_rl_module.py | 6 - 9 files changed, 225 insertions(+), 86 deletions(-) delete mode 100644 rllib/examples/rl_modules/action_masking_rlm.py create mode 100644 rllib/examples/rl_modules/classes/tiny_atari_cnn.py create mode 100644 rllib/examples/rl_modules/custom_rl_module.py delete mode 100644 rllib/examples/rl_modules/episode_env_aware_rlm.py delete mode 100644 rllib/examples/rl_modules/frame_stacking_rlm.py delete mode 100644 rllib/examples/rl_modules/mobilenet_rlm.py delete mode 100644 rllib/examples/rl_modules/random_rl_module.py diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 691f9c688b5a..afca0c4bb829 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -2,7 +2,7 @@ import datetime import json import pathlib -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Mapping, Any, TYPE_CHECKING, Optional, Type, Dict, Union import gymnasium as gym @@ -203,7 +203,7 @@ class RLModuleConfig: observation_space: gym.Space = None action_space: gym.Space = None - model_config_dict: Dict[str, Any] = None + model_config_dict: Dict[str, Any] = field(default_factory=dict) catalog_class: Type["Catalog"] = None def get_catalog(self) -> "Catalog": @@ -456,7 +456,8 @@ def setup(self): This is called automatically during the __init__ method of this class, therefore, the subclass should call super.__init__() in its constructor. This - abstraction can be used to create any component that your RLModule needs. + abstraction can be used to create any components (e.g. NN layers) that your + RLModule needs. """ return None @@ -596,9 +597,7 @@ def output_specs_inference(self) -> SpecType: a dict that has `action_dist` key and its value is an instance of `Distribution`. """ - # TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this - # is what most algos will do. - return {"action_dist": Distribution} + return [Columns.ACTION_DIST_INPUTS] @OverrideToImplementCustomLogic_CallToSuperRecommended def output_specs_exploration(self) -> SpecType: @@ -609,9 +608,7 @@ def output_specs_exploration(self) -> SpecType: a dict that has `action_dist` key and its value is an instance of `Distribution`. """ - # TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this - # is what most algos will do. 
- return {"action_dist": Distribution} + return [Columns.ACTION_DIST_INPUTS] def output_specs_train(self) -> SpecType: """Returns the output specs of the forward_train method.""" diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 9cb4d2bda6c4..e2335ebb0b52 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -21,47 +21,6 @@ torch, nn = try_import_torch() -def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig): - """A wrapper that compiles the forward methods of a TorchRLModule.""" - - # TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0 - # Check if torch framework supports torch.compile. - if ( - torch is not None - and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION - ): - raise ValueError("torch.compile is only supported from torch 2.0.0") - - compiled_forward_train = torch.compile( - rl_module._forward_train, - backend=compile_config.torch_dynamo_backend, - mode=compile_config.torch_dynamo_mode, - **compile_config.kwargs - ) - - rl_module._forward_train = compiled_forward_train - - compiled_forward_inference = torch.compile( - rl_module._forward_inference, - backend=compile_config.torch_dynamo_backend, - mode=compile_config.torch_dynamo_mode, - **compile_config.kwargs - ) - - rl_module._forward_inference = compiled_forward_inference - - compiled_forward_exploration = torch.compile( - rl_module._forward_exploration, - backend=compile_config.torch_dynamo_backend, - mode=compile_config.torch_dynamo_mode, - **compile_config.kwargs - ) - - rl_module._forward_exploration = compiled_forward_exploration - - return rl_module - - class TorchRLModule(nn.Module, RLModule): """A base class for RLlib PyTorch RLModules. @@ -233,3 +192,44 @@ class TorchDDPRLModuleWithTargetNetworksInterface( @override(RLModuleWithTargetNetworksInterface) def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: return self.module.get_target_network_pairs() + + +def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig): + """A wrapper that compiles the forward methods of a TorchRLModule.""" + + # TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0 + # Check if torch framework supports torch.compile. 
+ if ( + torch is not None + and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION + ): + raise ValueError("torch.compile is only supported from torch 2.0.0") + + compiled_forward_train = torch.compile( + rl_module._forward_train, + backend=compile_config.torch_dynamo_backend, + mode=compile_config.torch_dynamo_mode, + **compile_config.kwargs + ) + + rl_module._forward_train = compiled_forward_train + + compiled_forward_inference = torch.compile( + rl_module._forward_inference, + backend=compile_config.torch_dynamo_backend, + mode=compile_config.torch_dynamo_mode, + **compile_config.kwargs + ) + + rl_module._forward_inference = compiled_forward_inference + + compiled_forward_exploration = torch.compile( + rl_module._forward_exploration, + backend=compile_config.torch_dynamo_backend, + mode=compile_config.torch_dynamo_mode, + **compile_config.kwargs + ) + + rl_module._forward_exploration = compiled_forward_exploration + + return rl_module diff --git a/rllib/examples/rl_modules/action_masking_rlm.py b/rllib/examples/rl_modules/action_masking_rlm.py deleted file mode 100644 index 68bde8c8a8f2..000000000000 --- a/rllib/examples/rl_modules/action_masking_rlm.py +++ /dev/null @@ -1,6 +0,0 @@ -msg = """ -This script has been moved to -`ray.rllib.examples.rl_modules.classes.action_masking_rlm.py` -""" - -raise NotImplementedError(msg) diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py new file mode 100644 index 000000000000..a3d9d6cbd2a4 --- /dev/null +++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py @@ -0,0 +1,135 @@ +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.models.torch.misc import normc_initializer +from ray.rllib.models.torch.misc import same_padding, valid_padding +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() + + +class TinyAtariCNN(TorchRLModule): + """A tiny CNN stack for fast-learning of Atari envs. + + The architecture here is the exact same as the one used by the old API stack as + CNN default ModelV2. + + We stack 3 CNN layers based on the config, then a 4th one with linear activation + and n 1x1 filters, where n is the number of actions in the (discrete) action space. + Simple reshaping (no flattening or extra linear layers necessary) lead to the + action logits, which can directly be used inside a distribution or loss. + """ + @override(TorchRLModule) + def setup(self): + # Define the layers that this CNN stack needs. + conv_filters = [ + [16, 4, 2], # num filters, kernel wxh, stride wxh + [32, 4, 2], + [256, 11, 1, "valid"], # , ... , [padding type] + ] + + # Build the CNN layers. + layers = [] + + # Add user-specified hidden convolutional layers first + width, height, in_depth = self.config.observation_space.shape + in_size = [width, height] + for filter_specs in conv_filters: + # Padding information not provided -> Use "same" as default. + if len(filter_specs) == 3: + out_depth, kernel_size, strides = filter_specs + padding = "same" + # Padding information provided. + else: + out_depth, kernel_size, strides, padding = filter_specs + + # Pad like in tensorflow's SAME/VALID mode. 
+ if padding == "same": + padding_size, out_size = same_padding(in_size, kernel_size, strides) + layers.append(nn.ZeroPad2d(padding_size)) + # No actual padding is performed for "valid" mode, but we will still + # compute the output size (input for the next layer). + else: + out_size = valid_padding(in_size, kernel_size, strides) + + layer = nn.Conv2d(in_depth, out_depth, kernel_size, strides, bias=True) + # Initialize CNN layer kernel. + nn.init.xavier_uniform_(layer.weight) + # Initialize CNN layer bias. + nn.init.zeros_(layer.bias) + + layers.append(layer) + + # Activation. + layers.append(nn.ReLU()) + + in_size = out_size + in_depth = out_depth + + self._base_cnn_stack = nn.Sequential(*layers) + + # Add the final CNN 1x1 layer with num_filters == num_actions to be reshaped to + # yield the logits (no flattening, no additional linear layers required). + self._logits = nn.Sequential( + nn.ZeroPad2d(same_padding(in_size, 1, 1)[0]), + nn.Conv2d(in_depth, self.config.action_space.n, 1, 1, bias=True), + ) + self._values = nn.Linear(in_depth, 1) + # Mimick old API stack behavior of initializing the value function with `normc` + # std=0.01. + normc_initializer(0.01)(self._values.weight) + + @override(TorchRLModule) + def _forward_inference(self, batch, **kwargs): + _, logits = self._compute_features_and_logits(batch) + return { + Columns.ACTION_DIST_INPUTS: logits + } + + @override(TorchRLModule) + def _forward_exploration(self, batch, **kwargs): + return self._forward_inference(batch, **kwargs) + + @override(TorchRLModule) + def _forward_train(self, batch, **kwargs): + features, logits = self._compute_features_and_logits(batch) + # Besides the action logits, we also have to return value predictions here + # (to be used inside the loss function). + vf = self._values(features) + return { + Columns.ACTION_DIST_INPUTS: logits, + Columns.VF_PREDS: vf.squeeze(-1), + } + + def _compute_features_and_logits(self, batch): + obs = batch[Columns.OBS].permute(0, 3, 1, 2) + features = self._base_cnn_stack(obs) + logits = self._logits(features) + return torch.squeeze(features, dim=[-1, -2]), torch.squeeze(logits, dim=[-1, -2]) + + +if __name__ == "__main__": + import numpy as np + import gymnasium as gym + from ray.rllib.core.rl_module.rl_module import RLModuleConfig + + rl_module_config = RLModuleConfig( + observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32), + action_space=gym.spaces.Discrete(4), + ) + my_net = TinyAtariCNN(rl_module_config) + + + B = 10 + w = 42 + h = 42 + c = 4 + data = torch.from_numpy( + np.random.random_sample(size=(B, w, h, c)).astype(np.float32) + ) + print(my_net.forward_inference({"obs": data})) + print(my_net.forward_exploration({"obs": data})) + print(my_net.forward_train({"obs": data})) + + num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters()) + print(f"num params = {num_all_params}") \ No newline at end of file diff --git a/rllib/examples/rl_modules/custom_rl_module.py b/rllib/examples/rl_modules/custom_rl_module.py new file mode 100644 index 000000000000..9f86bcbd826b --- /dev/null +++ b/rllib/examples/rl_modules/custom_rl_module.py @@ -0,0 +1,43 @@ +"""Example of implementing and configuring a custom (torch) RLModule. + +This example: + - demonstrates how you can subclass the TorchRLModule base class and setup your + own neural network architecture by overriding `setup()`. + - how to override the 3 forward methods: `_forward_inference`, `_forward_exploration`, + and `forward_train` to implement your own custom forward logic(s). 
You will also learn, + when each of these 3 methods is called by RLlib or the users of your RLModule. + - shows how you then configure an RLlib Algorithm such that it uses your custom + RLModule (instead of a default RLModule). + +We implement a tiny CNN stack here, the exact same one that is used by the old API +stack as default CNN net. It comprises 4 convolutional layers, the last of which +ends in a 1x1 filter size and the number of filters exactly matches the number of +discrete actions (logits). This way, the (non-activated) output of the last layer only +needs to be reshaped in order to receive the policy's logit outputs. No flattening +or additional dense layer required. + +The network is then used in a fast ALE/Pong-v5 experiment. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack` + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +You should see the following output (at the end of the experiment) in your console: + +""" + + + diff --git a/rllib/examples/rl_modules/episode_env_aware_rlm.py b/rllib/examples/rl_modules/episode_env_aware_rlm.py deleted file mode 100644 index 9cafd034ec0b..000000000000 --- a/rllib/examples/rl_modules/episode_env_aware_rlm.py +++ /dev/null @@ -1,6 +0,0 @@ -msg = """ -This script has been moved to -`ray.rllib.examples.rl_modules.classes.random_rlm.py::StatefulRandomRLModule` -""" - -raise NotImplementedError(msg) diff --git a/rllib/examples/rl_modules/frame_stacking_rlm.py b/rllib/examples/rl_modules/frame_stacking_rlm.py deleted file mode 100644 index 4ed592fa8705..000000000000 --- a/rllib/examples/rl_modules/frame_stacking_rlm.py +++ /dev/null @@ -1,12 +0,0 @@ -msg = """ -This script has been taken out of RLlib b/c: -- This script used `ViewRequirements` ("Trajectory View API") to set up the RLModule, -however, this API will not be part of the new API stack. -Instead, you can use RLlib's built-in ConnectorV2 for frame stacking (or write a custom -ConnectorV2). Take a look at this example script here, which shows how you can do frame- -stacking with RLlib's new ConnectorV2 API. 
- -`ray.rllib.examples.connectors.frame_stacking.py` -""" - -raise NotImplementedError(msg) diff --git a/rllib/examples/rl_modules/mobilenet_rlm.py b/rllib/examples/rl_modules/mobilenet_rlm.py deleted file mode 100644 index 84f57d0566e0..000000000000 --- a/rllib/examples/rl_modules/mobilenet_rlm.py +++ /dev/null @@ -1,6 +0,0 @@ -msg = """ -This script has been moved to -`ray.rllib.examples.rl_modules.classes.mobilenet_rlm.py` -""" - -raise NotImplementedError(msg) diff --git a/rllib/examples/rl_modules/random_rl_module.py b/rllib/examples/rl_modules/random_rl_module.py deleted file mode 100644 index eac2d59ddf61..000000000000 --- a/rllib/examples/rl_modules/random_rl_module.py +++ /dev/null @@ -1,6 +0,0 @@ -msg = """ -This script has been moved to -`ray.rllib.examples.rl_modules.classes.random_rlm.py` -""" - -raise NotImplementedError(msg) From 0af3228efb5869cf34e2def9a59117b51ab6e67c Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 15 May 2024 20:57:19 +0200 Subject: [PATCH 2/8] wip Signed-off-by: sven1977 --- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 1d923e772233..551f9327cb33 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -22,8 +22,7 @@ def setup(self): super().setup() # If not an inference-only module (e.g., for evaluation), set up the - # parameter names to be removed or renamed when syncing from the state dict - # when synching. + # parameter names to be removed or renamed when syncing from the state dict. if not self.inference_only: # Set the expected and unexpected keys for the inference-only module. self._set_inference_only_state_dict_keys() From df4192b5f1e34800a4c60516861e34c52812e783 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 6 Jun 2024 13:08:03 +0200 Subject: [PATCH 3/8] wip Signed-off-by: sven1977 --- rllib/core/rl_module/rl_module.py | 16 +++++----- rllib/core/rl_module/torch/torch_rl_module.py | 32 ++++++++++++++++++- .../rl_modules/classes/tiny_atari_cnn.py | 20 ++++++++++-- rllib/examples/rl_modules/custom_rl_module.py | 14 ++++---- 4 files changed, 63 insertions(+), 19 deletions(-) diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index afca0c4bb829..5d71fecf13d7 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -465,14 +465,14 @@ def setup(self): def get_train_action_dist_cls(self) -> Type[Distribution]: """Returns the action distribution class for this RLModule used for training. - This class is used to create action distributions from outputs of the - forward_train method. If the case that no action distribution class is needed, + This class is used to get the correct action distribution class to be used by + the training components. In case that no action distribution class is needed, this method can return None. Note that RLlib's distribution classes all implement the `Distribution` interface. This requires two special methods: `Distribution.from_logits()` and - `Distribution.to_deterministic()`. See the documentation for `Distribution` - for more detail. + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. 
""" raise NotImplementedError @@ -486,8 +486,8 @@ def get_exploration_action_dist_cls(self) -> Type[Distribution]: Note that RLlib's distribution classes all implement the `Distribution` interface. This requires two special methods: `Distribution.from_logits()` and - `Distribution.to_deterministic()`. See the documentation for `Distribution` - for more detail. + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. """ raise NotImplementedError @@ -501,8 +501,8 @@ def get_inference_action_dist_cls(self) -> Type[Distribution]: Note that RLlib's distribution classes all implement the `Distribution` interface. This requires two special methods: `Distribution.from_logits()` and - `Distribution.to_deterministic()`. See the documentation for `Distribution` - for more detail. + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. """ raise NotImplementedError diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index a1b605d46498..576f5dab8908 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -1,6 +1,7 @@ import pathlib from typing import Any, List, Mapping, Tuple, Union, Type +import gymnasium as gym from packaging import version from ray.rllib.core.rl_module import RLModule @@ -8,7 +9,11 @@ RLModuleWithTargetNetworksInterface, ) from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig -from ray.rllib.models.torch.torch_distributions import TorchDistribution +from ray.rllib.models.torch.torch_distributions import ( + TorchCategorical, + TorchDiagGaussian, + TorchDistribution, +) from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import convert_to_numpy @@ -44,6 +49,18 @@ def __init__(self, *args, **kwargs) -> None: nn.Module.__init__(self) RLModule.__init__(self, *args, **kwargs) + @override(RLModule) + def get_inference_action_dist_cls(self) -> Type[TorchDistribution]: + return self._get_default_action_dist_class("inference") + + @override(RLModule) + def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]: + return self._get_default_action_dist_class("exploration") + + @override(RLModule) + def get_train_action_dist_cls(self) -> Type[TorchDistribution]: + return self._get_default_action_dist_class("train") + def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]: """forward pass of the module. @@ -115,6 +132,19 @@ def _inference_only_get_state_hook( """ pass + def _get_default_action_dist_class(self, what: str) -> Type[TorchDistribution]: + # The default implementation is to return TorchCategorical for Discrete action + # spaces and TorchDiagGaussian for Box action spaces. For all other spaces, + # raise a NotImplementedError + if isinstance(self.config.action_space, gym.spaces.Discrete): + return TorchCategorical + elif isinstance(self.config.action_space, gym.spaces.Box): + return TorchDiagGaussian + else: + raise NotImplementedError( + f"Override your RLModule's `get_{what}_action_dist_cls` method and " + "return the correct TorchDistribution class from it!" 
+ ) class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel): def __init__(self, *args, **kwargs) -> None: diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py index c043cc06f401..e5ca591fc22f 100644 --- a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py +++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py @@ -4,6 +4,7 @@ from ray.rllib.models.torch.misc import same_padding, valid_padding from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_utils import convert_to_torch_tensor torch, nn = try_import_torch() @@ -75,7 +76,9 @@ def setup(self): @override(TorchRLModule) def _forward_inference(self, batch, **kwargs): + # Compute the basic 1D feature tensor (inputs to policy- and value-heads). _, logits = self._compute_features_and_logits(batch) + # Return logits as ACTION_DIST_INPUTS (categorical distribution). return { Columns.ACTION_DIST_INPUTS: logits } @@ -86,20 +89,31 @@ def _forward_exploration(self, batch, **kwargs): @override(TorchRLModule) def _forward_train(self, batch, **kwargs): + # Compute the basic 1D feature tensor (inputs to policy- and value-heads). features, logits = self._compute_features_and_logits(batch) # Besides the action logits, we also have to return value predictions here # (to be used inside the loss function). - vf = self._values(features) + values = self._values(features).squeeze(-1) return { Columns.ACTION_DIST_INPUTS: logits, - Columns.VF_PREDS: vf.squeeze(-1), + Columns.VF_PREDS: values, } def _compute_features_and_logits(self, batch): obs = batch[Columns.OBS].permute(0, 3, 1, 2) features = self._base_cnn_stack(obs) logits = self._logits(features) - return torch.squeeze(features, dim=[-1, -2]), torch.squeeze(logits, dim=[-1, -2]) + return ( + torch.squeeze(features, dim=[-1, -2]), + torch.squeeze(logits, dim=[-1, -2]), + ) + + def _compute_values(self, batch, device): + obs = convert_to_torch_tensor(batch[Columns.OBS], device=device) + features = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) + logits = self._logits(features) + features = torch.squeeze(features, dim=[-1, -2]) + return self._values(features).squeeze(-1) if __name__ == "__main__": diff --git a/rllib/examples/rl_modules/custom_rl_module.py b/rllib/examples/rl_modules/custom_rl_module.py index c724e2920762..4421188906f4 100644 --- a/rllib/examples/rl_modules/custom_rl_module.py +++ b/rllib/examples/rl_modules/custom_rl_module.py @@ -50,14 +50,18 @@ from ray.tune.registry import get_trainable_cls, register_env parser = add_rllib_example_script_args(default_iters=100, default_timesteps=600000) +parser.set_defaults(env="ALE/Pong-v5") if __name__ == "__main__": args = parser.parse_args() + assert ( + args.enable_new_api_stack + ), "Must set --enable-new-api-stack when running this script!" 
+ register_env("env", lambda cfg: wrap_atari_for_new_api_stack( - #TODO(sven) pull from master to get args.env - gym.make("ALE/Pong-v5", **cfg), #args.env + gym.make(args.env, **cfg), dim=42, # <- need images to be "tiny" for our custom model framestack=4, )) @@ -65,10 +69,6 @@ base_config = ( get_trainable_cls(args.algo) .get_default_config() - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) .environment( env="env", env_config=dict( @@ -84,4 +84,4 @@ ) ) - run_rllib_example_script_experiment(base_config, args) + run_rllib_example_script_experiment(base_config, args, stop={}) From b8b459ffdd462678dcb58d461cd8cb877e406a4c Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 6 Jun 2024 16:38:13 +0200 Subject: [PATCH 4/8] LINT and test case and docstring Signed-off-by: sven1977 --- rllib/BUILD | 16 +++--- rllib/core/rl_module/torch/torch_rl_module.py | 7 +-- .../rl_modules/classes/tiny_atari_cnn.py | 54 +++++++++++++------ rllib/examples/rl_modules/custom_rl_module.py | 49 +++++++++++++---- 4 files changed, 89 insertions(+), 37 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 6948f17e903c..1f46b6618f21 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2119,7 +2119,6 @@ py_test( # subdirectory: checkpoints/ # .................................... - py_test( name = "examples/checkpoints/checkpoint_by_custom_criteria", main = "examples/checkpoints/checkpoint_by_custom_criteria.py", @@ -2283,7 +2282,6 @@ py_test( # subdirectory: curriculum/ # .................................... - py_test( name = "examples/curriculum/curriculum_learning", main = "examples/curriculum/curriculum_learning.py", @@ -2295,7 +2293,6 @@ py_test( # subdirectory: debugging/ # .................................... - #@OldAPIStack py_test( name = "examples/debugging/deterministic_training_torch", @@ -2308,7 +2305,6 @@ py_test( # subdirectory: envs/ # .................................... - py_test( name = "examples/envs/custom_gym_env", main = "examples/envs/custom_gym_env.py", @@ -2449,7 +2445,6 @@ py_test( # subdirectory: gpus/ # .................................... - py_test( name = "examples/gpus/fractional_0.5_gpus_per_learner", main = "examples/gpus/fractional_gpus_per_learner.py", @@ -2469,7 +2464,6 @@ py_test( # subdirectory: hierarchical/ # .................................... - #@OldAPIStack py_test( name = "examples/hierarchical/hierarchical_training_tf", @@ -2492,7 +2486,6 @@ py_test( # subdirectory: inference/ # .................................... - #@OldAPIStack py_test( name = "examples/inference/policy_inference_after_training_tf", @@ -2905,6 +2898,15 @@ py_test( # subdirectory: rl_modules/ # .................................... +py_test( + name = "examples/rl_modules/custom_rl_module", + main = "examples/rl_modules/custom_rl_module.py", + tags = ["team:rllib", "examples"], + size = "medium", + srcs = ["examples/rl_modules/custom_rl_module.py"], + args = ["--enable-new-api-stack", "--stop-iters=3"], +) + #@OldAPIStack @HybridAPIStack py_test( name = "examples/rl_modules/classes/mobilenet_rlm_hybrid_api_stack", diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 576f5dab8908..0a68f3a27254 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -146,6 +146,7 @@ def _get_default_action_dist_class(self, what: str) -> Type[TorchDistribution]: "return the correct TorchDistribution class from it!" 
) + class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel): def __init__(self, *args, **kwargs) -> None: nn.parallel.DistributedDataParallel.__init__(self, *args, **kwargs) @@ -240,7 +241,7 @@ def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConf rl_module._forward_train, backend=compile_config.torch_dynamo_backend, mode=compile_config.torch_dynamo_mode, - **compile_config.kwargs + **compile_config.kwargs, ) rl_module._forward_train = compiled_forward_train @@ -249,7 +250,7 @@ def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConf rl_module._forward_inference, backend=compile_config.torch_dynamo_backend, mode=compile_config.torch_dynamo_mode, - **compile_config.kwargs + **compile_config.kwargs, ) rl_module._forward_inference = compiled_forward_inference @@ -258,7 +259,7 @@ def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConf rl_module._forward_exploration, backend=compile_config.torch_dynamo_backend, mode=compile_config.torch_dynamo_mode, - **compile_config.kwargs + **compile_config.kwargs, ) rl_module._forward_exploration = compiled_forward_exploration diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py index e5ca591fc22f..f7f6e9311fbb 100644 --- a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py +++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py @@ -20,14 +20,31 @@ class TinyAtariCNN(TorchRLModule): Simple reshaping (no flattening or extra linear layers necessary) lead to the action logits, which can directly be used inside a distribution or loss. """ + @override(TorchRLModule) def setup(self): - # Define the layers that this CNN stack needs. - conv_filters = [ - [16, 4, 2, "same"], # num filters, kernel wxh, stride wxh, padding type - [32, 4, 2, "same"], - [256, 11, 1, "valid"], - ] + """Use this method to create all the model components that you require. + + Feel free to access the following useful properties in this class: + - `self.config.model_config_dict`: The config dict for this RLModule class, + which should contain flxeible settings, for example: {"hiddens": [256, 256]}. + - `self.config.observation|action_space`: The observation and action space that + this RLModule is subject to. Note that the observation space might not be the + exact space from your env, but that it might have already gone through + preprocessing through a connector pipeline (for example, flattening, + frame-stacking, mean/std-filtering, etc..). + """ + # Get the CNN stack config from our RLModuleConfig's (self.config) + # `model_config_dict` property: + if "conv_filters" in self.config.model_config_dict: + conv_filters = self.config.model_config_dict["conv_filters"] + # Default CNN stack with 3 layers: + else: + conv_filters = [ + [16, 4, 2, "same"], # num filters, kernel wxh, stride wxh, padding type + [32, 4, 2, "same"], + [256, 11, 1, "valid"], + ] # Build the CNN layers. layers = [] @@ -79,9 +96,7 @@ def _forward_inference(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). _, logits = self._compute_features_and_logits(batch) # Return logits as ACTION_DIST_INPUTS (categorical distribution). 
- return { - Columns.ACTION_DIST_INPUTS: logits - } + return {Columns.ACTION_DIST_INPUTS: logits} @override(TorchRLModule) def _forward_exploration(self, batch, **kwargs): @@ -99,6 +114,18 @@ def _forward_train(self, batch, **kwargs): Columns.VF_PREDS: values, } + # TODO (sven): In order for this RLModule to work with PPO, we must define + # our own `_compute_values()` method. This would become more obvious, if we simply + # subclassed the `PPOTorchRLModule` directly here (which we didn't do for + # simplicity and to keep some generality). We might change even get rid of algo- + # specific RLModule subclasses altogether in the future and replace them + # by mere algo-specific APIs (w/o any actual implementations). + def _compute_values(self, batch, device): + obs = convert_to_torch_tensor(batch[Columns.OBS], device=device) + features = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) + features = torch.squeeze(features, dim=[-1, -2]) + return self._values(features).squeeze(-1) + def _compute_features_and_logits(self, batch): obs = batch[Columns.OBS].permute(0, 3, 1, 2) features = self._base_cnn_stack(obs) @@ -108,13 +135,6 @@ def _compute_features_and_logits(self, batch): torch.squeeze(logits, dim=[-1, -2]), ) - def _compute_values(self, batch, device): - obs = convert_to_torch_tensor(batch[Columns.OBS], device=device) - features = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) - logits = self._logits(features) - features = torch.squeeze(features, dim=[-1, -2]) - return self._values(features).squeeze(-1) - if __name__ == "__main__": import numpy as np @@ -139,4 +159,4 @@ def _compute_values(self, batch, device): print(my_net.forward_train({"obs": data})) num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters()) - print(f"num params = {num_all_params}") \ No newline at end of file + print(f"num params = {num_all_params}") diff --git a/rllib/examples/rl_modules/custom_rl_module.py b/rllib/examples/rl_modules/custom_rl_module.py index 4421188906f4..7f6ce25b1cce 100644 --- a/rllib/examples/rl_modules/custom_rl_module.py +++ b/rllib/examples/rl_modules/custom_rl_module.py @@ -3,9 +3,10 @@ This example: - demonstrates how you can subclass the TorchRLModule base class and setup your own neural network architecture by overriding `setup()`. - - how to override the 3 forward methods: `_forward_inference`, `_forward_exploration`, - and `forward_train` to implement your own custom forward logic(s). You will also learn, - when each of these 3 methods is called by RLlib or the users of your RLModule. + - how to override the 3 forward methods: `_forward_inference()`, + `_forward_exploration()`, and `forward_train()` to implement your own custom forward + logic(s). You will also learn, when each of these 3 methods is called by RLlib or + the users of your RLModule. - shows how you then configure an RLlib Algorithm such that it uses your custom RLModule (instead of a default RLModule). 
@@ -35,8 +36,21 @@ Results to expect ----------------- -You should see the following output (at the end of the experiment) in your console: - +You should see the following output (during the experiment) in your console: + +Number of trials: 1/1 (1 RUNNING) ++---------------------+----------+----------------+--------+------------------+ +| Trial name | status | loc | iter | total time (s) | +| | | | | | +|---------------------+----------+----------------+--------+------------------+ +| PPO_env_82b44_00000 | RUNNING | 127.0.0.1:9718 | 1 | 98.3585 | ++---------------------+----------+----------------+--------+------------------+ ++------------------------+------------------------+------------------------+ +| num_env_steps_sample | num_env_steps_traine | num_episodes_lifetim | +| d_lifetime | d_lifetime | e | +|------------------------+------------------------+------------------------| +| 4000 | 4000 | 4 | ++------------------------+------------------------+------------------------+ """ import gymnasium as gym @@ -60,11 +74,14 @@ args.enable_new_api_stack ), "Must set --enable-new-api-stack when running this script!" - register_env("env", lambda cfg: wrap_atari_for_new_api_stack( - gym.make(args.env, **cfg), - dim=42, # <- need images to be "tiny" for our custom model - framestack=4, - )) + register_env( + "env", + lambda cfg: wrap_atari_for_new_api_stack( + gym.make(args.env, **cfg), + dim=42, # <- need images to be "tiny" for our custom model + framestack=4, + ), + ) base_config = ( get_trainable_cls(args.algo) @@ -78,9 +95,21 @@ ), ) .rl_module( + # Plug-in our custom RLModule class. rl_module_spec=SingleAgentRLModuleSpec( module_class=TinyAtariCNN, ), + # Feel free to specify your own `model_config_dict` settings below. + # The `model_config_dict` defined here will be available inside your custom + # RLModule class through the `self.config.model_config_dict` property. + # model_config_dict={ + # "conv_filters": [ + # # num filters, kernel wxh, stride wxh, padding type + # [16, 4, 2, "same"], + # [32, 4, 2, "same"], + # [64, 4, 2, "same"], + # ], + # }, ) ) From 682da4dd2b1d4cea2d32d14a2f8ecb4b44af608c Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 6 Jun 2024 16:55:10 +0200 Subject: [PATCH 5/8] SAEnvRunner bug fix overriding an already provided model_config_dict with RLlib default ones. Signed-off-by: sven1977 --- rllib/env/single_agent_env_runner.py | 7 ++----- rllib/examples/rl_modules/custom_rl_module.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index 3025689255f2..582806dd193b 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -91,12 +91,9 @@ def __init__(self, config: AlgorithmConfig, **kwargs): try: module_spec: SingleAgentRLModuleSpec = self.config.rl_module_spec module_spec.observation_space = self._env_to_module.observation_space - # TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should - # actually hold the spaces for a single env, but for boxes the - # shape is (1, 1) which brings a problem with the action dists. - # shape=(1,) is expected. module_spec.action_space = self.env.envs[0].action_space - module_spec.model_config_dict = self.config.model_config + if module_spec.model_config_dict is None: + module_spec.model_config_dict = self.config.model_config # Only load a light version of the module, if available. 
This is useful
            # if the module has target or critic networks not needed in sampling
            # or inference.
diff --git a/rllib/examples/rl_modules/custom_rl_module.py b/rllib/examples/rl_modules/custom_rl_module.py
index 7f6ce25b1cce..b2f407946071 100644
--- a/rllib/examples/rl_modules/custom_rl_module.py
+++ b/rllib/examples/rl_modules/custom_rl_module.py
@@ -98,18 +98,19 @@
         # Plug-in our custom RLModule class.
         rl_module_spec=SingleAgentRLModuleSpec(
             module_class=TinyAtariCNN,
+            model_config_dict={"a": "b"},
         ),
         # Feel free to specify your own `model_config_dict` settings below.
         # The `model_config_dict` defined here will be available inside your custom
         # RLModule class through the `self.config.model_config_dict` property.
-        # model_config_dict={
-        #     "conv_filters": [
-        #         # num filters, kernel wxh, stride wxh, padding type
-        #         [16, 4, 2, "same"],
-        #         [32, 4, 2, "same"],
-        #         [64, 4, 2, "same"],
-        #     ],
-        # },
+        model_config_dict={
+            "conv_filters": [
+                # num filters, kernel wxh, stride wxh, padding type
+                [16, 4, 2, "same"],
+                [32, 4, 2, "same"],
+                [64, 4, 2, "same"],
+            ],
+        },
     )
 )

From 6f4672a7da21f9b78df7bb0fe2b5fca548b127ea Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 7 Jun 2024 09:00:37 +0200
Subject: [PATCH 6/8] wip

Signed-off-by: sven1977
---
 rllib/algorithms/sac/torch/sac_torch_rl_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/algorithms/sac/torch/sac_torch_rl_module.py b/rllib/algorithms/sac/torch/sac_torch_rl_module.py
index 9b30e5bbaf89..ca6176336c23 100644
--- a/rllib/algorithms/sac/torch/sac_torch_rl_module.py
+++ b/rllib/algorithms/sac/torch/sac_torch_rl_module.py
@@ -22,7 +22,7 @@
 torch, nn = try_import_torch()


-class SACTorchRLModule(TorchRLModule, SACRLModule):
+class SACTorchRLModule(SACRLModule, TorchRLModule):
     framework: str = "torch"

     @override(SACRLModule)

From a30f9b6539b2b4f8741653ac3925c915b08dfb5a Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 7 Jun 2024 09:00:48 +0200
Subject: [PATCH 7/8] wip

Signed-off-by: sven1977
---
 rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
index 78c041878d1e..ac0dfcf4df7c 100644
--- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
+++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -14,7 +14,7 @@
 torch, nn = try_import_torch()


-class PPOTorchRLModule(TorchRLModule, PPORLModule):
+class PPOTorchRLModule(PPORLModule, TorchRLModule):
     framework: str = "torch"

     @override(PPORLModule)

From 587b5f009834c7d7fb0a78f8fb48d500c476a15e Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 7 Jun 2024 13:01:00 +0200
Subject: [PATCH 8/8] fixes

Signed-off-by: sven1977
---
 .../ppo/torch/ppo_torch_rl_module.py          |  2 +-
 .../sac/torch/sac_torch_rl_module.py          |  2 +-
 rllib/core/rl_module/torch/torch_rl_module.py | 33 +------------------
 .../rl_modules/classes/tiny_atari_cnn.py      | 16 +++++++++
 4 files changed, 19 insertions(+), 34 deletions(-)

diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
index ac0dfcf4df7c..78c041878d1e 100644
--- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
+++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -14,7 +14,7 @@
 torch, nn = try_import_torch()


-class PPOTorchRLModule(PPORLModule, TorchRLModule):
+class PPOTorchRLModule(TorchRLModule, PPORLModule):
     framework: str = "torch"

     @override(PPORLModule)
diff --git
a/rllib/algorithms/sac/torch/sac_torch_rl_module.py b/rllib/algorithms/sac/torch/sac_torch_rl_module.py index ca6176336c23..9b30e5bbaf89 100644 --- a/rllib/algorithms/sac/torch/sac_torch_rl_module.py +++ b/rllib/algorithms/sac/torch/sac_torch_rl_module.py @@ -22,7 +22,7 @@ torch, nn = try_import_torch() -class SACTorchRLModule(SACRLModule, TorchRLModule): +class SACTorchRLModule(TorchRLModule, SACRLModule): framework: str = "torch" @override(SACRLModule) diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 0a68f3a27254..883b39f26f99 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -1,7 +1,6 @@ import pathlib from typing import Any, List, Mapping, Tuple, Union, Type -import gymnasium as gym from packaging import version from ray.rllib.core.rl_module import RLModule @@ -9,11 +8,7 @@ RLModuleWithTargetNetworksInterface, ) from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig -from ray.rllib.models.torch.torch_distributions import ( - TorchCategorical, - TorchDiagGaussian, - TorchDistribution, -) +from ray.rllib.models.torch.torch_distributions import TorchDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import convert_to_numpy @@ -49,18 +44,6 @@ def __init__(self, *args, **kwargs) -> None: nn.Module.__init__(self) RLModule.__init__(self, *args, **kwargs) - @override(RLModule) - def get_inference_action_dist_cls(self) -> Type[TorchDistribution]: - return self._get_default_action_dist_class("inference") - - @override(RLModule) - def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]: - return self._get_default_action_dist_class("exploration") - - @override(RLModule) - def get_train_action_dist_cls(self) -> Type[TorchDistribution]: - return self._get_default_action_dist_class("train") - def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]: """forward pass of the module. @@ -132,20 +115,6 @@ def _inference_only_get_state_hook( """ pass - def _get_default_action_dist_class(self, what: str) -> Type[TorchDistribution]: - # The default implementation is to return TorchCategorical for Discrete action - # spaces and TorchDiagGaussian for Box action spaces. For all other spaces, - # raise a NotImplementedError - if isinstance(self.config.action_space, gym.spaces.Discrete): - return TorchCategorical - elif isinstance(self.config.action_space, gym.spaces.Box): - return TorchDiagGaussian - else: - raise NotImplementedError( - f"Override your RLModule's `get_{what}_action_dist_cls` method and " - "return the correct TorchDistribution class from it!" 
-            )
-

 class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel):
     def __init__(self, *args, **kwargs) -> None:
diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py
index f7f6e9311fbb..2f45cf219734 100644
--- a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py
+++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn.py
@@ -1,5 +1,6 @@
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.rl_module.torch import TorchRLModule
+from ray.rllib.models.torch.torch_distributions import TorchCategorical
 from ray.rllib.models.torch.misc import normc_initializer
 from ray.rllib.models.torch.misc import same_padding, valid_padding
 from ray.rllib.utils.annotations import override
@@ -114,6 +115,21 @@ def _forward_train(self, batch, **kwargs):
             Columns.VF_PREDS: values,
         }

+    # TODO (sven): We still need to define the distribution to use here, even though
+    # we have a pretty standard action space (Discrete), which should simply always map
+    # to a categorical dist. by default.
+    @override(TorchRLModule)
+    def get_inference_action_dist_cls(self):
+        return TorchCategorical
+
+    @override(TorchRLModule)
+    def get_exploration_action_dist_cls(self):
+        return TorchCategorical
+
+    @override(TorchRLModule)
+    def get_train_action_dist_cls(self):
+        return TorchCategorical
+
     # TODO (sven): In order for this RLModule to work with PPO, we must define
     # our own `_compute_values()` method. This would become more obvious, if we simply
     # subclassed the `PPOTorchRLModule` directly here (which we didn't do for