Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Cleanup examples folder #10: Add custom_rl_module.py example script and matching RLModule example class (tiny CNN).. #45774

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2119,7 +2119,6 @@ py_test(

# subdirectory: checkpoints/
# ....................................

py_test(
name = "examples/checkpoints/checkpoint_by_custom_criteria",
main = "examples/checkpoints/checkpoint_by_custom_criteria.py",
Expand Down Expand Up @@ -2283,7 +2282,6 @@ py_test(

# subdirectory: curriculum/
# ....................................

py_test(
name = "examples/curriculum/curriculum_learning",
main = "examples/curriculum/curriculum_learning.py",
Expand All @@ -2295,7 +2293,6 @@ py_test(

# subdirectory: debugging/
# ....................................

#@OldAPIStack
py_test(
name = "examples/debugging/deterministic_training_torch",
Expand All @@ -2308,7 +2305,6 @@ py_test(

# subdirectory: envs/
# ....................................

py_test(
name = "examples/envs/custom_gym_env",
main = "examples/envs/custom_gym_env.py",
Expand Down Expand Up @@ -2449,7 +2445,6 @@ py_test(

# subdirectory: gpus/
# ....................................

py_test(
name = "examples/gpus/fractional_0.5_gpus_per_learner",
main = "examples/gpus/fractional_gpus_per_learner.py",
Expand All @@ -2469,7 +2464,6 @@ py_test(

# subdirectory: hierarchical/
# ....................................

#@OldAPIStack
py_test(
name = "examples/hierarchical/hierarchical_training_tf",
Expand All @@ -2492,7 +2486,6 @@ py_test(

# subdirectory: inference/
# ....................................

#@OldAPIStack
py_test(
name = "examples/inference/policy_inference_after_training_tf",
Expand Down Expand Up @@ -2905,6 +2898,15 @@ py_test(

# subdirectory: rl_modules/
# ....................................
py_test(
name = "examples/rl_modules/custom_rl_module",
main = "examples/rl_modules/custom_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/custom_rl_module.py"],
args = ["--enable-new-api-stack", "--stop-iters=3"],
)

#@OldAPIStack @HybridAPIStack
py_test(
name = "examples/rl_modules/classes/mobilenet_rlm_hybrid_api_stack",
Expand Down
3 changes: 1 addition & 2 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def setup(self):
super().setup()

# If not an inference-only module (e.g., for evaluation), set up the
# parameter names to be removed or renamed when syncing from the state dict
# when synching.
# parameter names to be removed or renamed when syncing from the state dict.
if not self.inference_only:
# Set the expected and unexpected keys for the inference-only module.
self._set_inference_only_state_dict_keys()
Expand Down
31 changes: 14 additions & 17 deletions rllib/core/rl_module/rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
import json
import pathlib
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Mapping, Any, TYPE_CHECKING, Optional, Type, Dict, Union

import gymnasium as gym
Expand Down Expand Up @@ -203,7 +203,7 @@ class RLModuleConfig:

observation_space: gym.Space = None
action_space: gym.Space = None
model_config_dict: Dict[str, Any] = None
model_config_dict: Dict[str, Any] = field(default_factory=dict)
catalog_class: Type["Catalog"] = None

def get_catalog(self) -> "Catalog":
Expand Down Expand Up @@ -456,22 +456,23 @@ def setup(self):

This is called automatically during the __init__ method of this class,
therefore, the subclass should call super().__init__() in its constructor. This
abstraction can be used to create any component that your RLModule needs.
abstraction can be used to create any components (e.g. NN layers) that your
RLModule needs.
"""
return None

@OverrideToImplementCustomLogic
def get_train_action_dist_cls(self) -> Type[Distribution]:
"""Returns the action distribution class for this RLModule used for training.

This class is used to create action distributions from outputs of the
forward_train method. If the case that no action distribution class is needed,
This class is used to get the correct action distribution class to be used by
the training components. If no action distribution class is needed,
this method can return None.

Note that RLlib's distribution classes all implement the `Distribution`
interface. This requires two special methods: `Distribution.from_logits()` and
`Distribution.to_deterministic()`. See the documentation for `Distribution`
for more detail.
`Distribution.to_deterministic()`. See the documentation of the
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
"""
raise NotImplementedError

Expand All @@ -485,8 +486,8 @@ def get_exploration_action_dist_cls(self) -> Type[Distribution]:

Note that RLlib's distribution classes all implement the `Distribution`
interface. This requires two special methods: `Distribution.from_logits()` and
`Distribution.to_deterministic()`. See the documentation for `Distribution`
for more detail.
`Distribution.to_deterministic()`. See the documentation of the
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
"""
raise NotImplementedError

Expand All @@ -500,8 +501,8 @@ def get_inference_action_dist_cls(self) -> Type[Distribution]:

Note that RLlib's distribution classes all implement the `Distribution`
interface. This requires two special methods: `Distribution.from_logits()` and
`Distribution.to_deterministic()`. See the documentation for `Distribution`
for more detail.
`Distribution.to_deterministic()`. See the documentation of the
:py:class:`~ray.rllib.models.distributions.Distribution` class for more details.
"""
raise NotImplementedError

Expand Down Expand Up @@ -596,9 +597,7 @@ def output_specs_inference(self) -> SpecType:
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
# TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this
# is what most algos will do.
return {"action_dist": Distribution}
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_exploration(self) -> SpecType:
Expand All @@ -609,9 +608,7 @@ def output_specs_exploration(self) -> SpecType:
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
# TODO (sven): We should probably change this to [ACTION_DIST_INPUTS], b/c this
# is what most algos will do.
return {"action_dist": Distribution}
return [Columns.ACTION_DIST_INPUTS]

def output_specs_train(self) -> SpecType:
"""Returns the output specs of the forward_train method."""
Expand Down
115 changes: 73 additions & 42 deletions rllib/core/rl_module/torch/torch_rl_module.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import pathlib
from typing import Any, List, Mapping, Tuple, Union, Type

import gymnasium as gym
from packaging import version

from ray.rllib.core.rl_module import RLModule
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
)
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
from ray.rllib.models.torch.torch_distributions import TorchDistribution
from ray.rllib.models.torch.torch_distributions import (
TorchCategorical,
TorchDiagGaussian,
TorchDistribution,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
Expand All @@ -21,47 +26,6 @@
torch, nn = try_import_torch()


def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig):
"""A wrapper that compiles the forward methods of a TorchRLModule."""

# TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0
# Check if torch framework supports torch.compile.
if (
torch is not None
and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
):
raise ValueError("torch.compile is only supported from torch 2.0.0")

compiled_forward_train = torch.compile(
rl_module._forward_train,
backend=compile_config.torch_dynamo_backend,
mode=compile_config.torch_dynamo_mode,
**compile_config.kwargs
)

rl_module._forward_train = compiled_forward_train

compiled_forward_inference = torch.compile(
rl_module._forward_inference,
backend=compile_config.torch_dynamo_backend,
mode=compile_config.torch_dynamo_mode,
**compile_config.kwargs
)

rl_module._forward_inference = compiled_forward_inference

compiled_forward_exploration = torch.compile(
rl_module._forward_exploration,
backend=compile_config.torch_dynamo_backend,
mode=compile_config.torch_dynamo_mode,
**compile_config.kwargs
)

rl_module._forward_exploration = compiled_forward_exploration

return rl_module


class TorchRLModule(nn.Module, RLModule):
"""A base class for RLlib PyTorch RLModules.

Expand All @@ -85,6 +49,18 @@ def __init__(self, *args, **kwargs) -> None:
nn.Module.__init__(self)
RLModule.__init__(self, *args, **kwargs)

@override(RLModule)
def get_inference_action_dist_cls(self) -> Type[TorchDistribution]:
return self._get_default_action_dist_class("inference")

@override(RLModule)
def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]:
return self._get_default_action_dist_class("exploration")

@override(RLModule)
def get_train_action_dist_cls(self) -> Type[TorchDistribution]:
return self._get_default_action_dist_class("train")

def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
"""forward pass of the module.

Expand Down Expand Up @@ -156,6 +132,20 @@ def _inference_only_get_state_hook(
"""
pass

def _get_default_action_dist_class(self, what: str) -> Type[TorchDistribution]:
# The default implementation is to return TorchCategorical for Discrete action
# spaces and TorchDiagGaussian for Box action spaces. For all other spaces,
# raise a NotImplementedError
if isinstance(self.config.action_space, gym.spaces.Discrete):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not using TorchMultiCategorical and TorchMultiDistribution - things that get assembled inside of the Catalog?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure either, tbh. I just wanted to get the most simple setup automated. I feel like users that just want to "hack together an RLModule" should not be concerned about picking the categorical distr for their CartPole action space :)

Yes, we should extend this method to even more decent defaults, I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's continue brainstorming how to simplify the general RLModule experience for the user ...

return TorchCategorical
elif isinstance(self.config.action_space, gym.spaces.Box):
return TorchDiagGaussian
else:
raise NotImplementedError(
f"Override your RLModule's `get_{what}_action_dist_cls` method and "
"return the correct TorchDistribution class from it!"
)


class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel):
def __init__(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -234,3 +224,44 @@ class TorchDDPRLModuleWithTargetNetworksInterface(
@override(RLModuleWithTargetNetworksInterface)
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
return self.module.get_target_network_pairs()


def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig):
"""A wrapper that compiles the forward methods of a TorchRLModule."""

# TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0
# Check if torch framework supports torch.compile.
if (
torch is not None
and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
):
raise ValueError("torch.compile is only supported from torch 2.0.0")

compiled_forward_train = torch.compile(
rl_module._forward_train,
backend=compile_config.torch_dynamo_backend,
mode=compile_config.torch_dynamo_mode,
**compile_config.kwargs,
)

rl_module._forward_train = compiled_forward_train

compiled_forward_inference = torch.compile(
rl_module._forward_inference,
backend=compile_config.torch_dynamo_backend,
mode=compile_config.torch_dynamo_mode,
**compile_config.kwargs,
)

rl_module._forward_inference = compiled_forward_inference

compiled_forward_exploration = torch.compile(
rl_module._forward_exploration,
backend=compile_config.torch_dynamo_backend,
mode=compile_config.torch_dynamo_mode,
**compile_config.kwargs,
)

rl_module._forward_exploration = compiled_forward_exploration

return rl_module
7 changes: 2 additions & 5 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,9 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
try:
module_spec: SingleAgentRLModuleSpec = self.config.rl_module_spec
module_spec.observation_space = self._env_to_module.observation_space
# TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great that this is gone now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it didn't seem to be a problem anymore (e.g. for PPO Pendulum, everything looks completely fine w/o any weird space errors on the Box actions). So I removed this comment.

# actually hold the spaces for a single env, but for boxes the
# shape is (1, 1) which brings a problem with the action dists.
# shape=(1,) is expected.
module_spec.action_space = self.env.envs[0].action_space
module_spec.model_config_dict = self.config.model_config
if module_spec.model_config_dict is None:
module_spec.model_config_dict = self.config.model_config
# Only load a light version of the module, if available. This is useful
# if the the module has target or critic networks not needed in sampling
# or inference.
Expand Down
6 changes: 0 additions & 6 deletions rllib/examples/rl_modules/action_masking_rlm.py

This file was deleted.

Loading
Loading