diff --git a/README.rst b/README.rst index a2a3495406cb..c2a389f80617 100644 --- a/README.rst +++ b/README.rst @@ -35,6 +35,8 @@ Or more about `Ray Core`_ and its key abstractions: - `Actors`_: Stateful worker processes created in the cluster. - `Objects`_: Immutable values accessible across the cluster. +Monitor and debug Ray applications and clusters using the `Ray dashboard `__. + Ray runs on any machine, cluster, cloud provider, and Kubernetes, and features a growing `ecosystem of community integrations`_. diff --git a/doc/source/_toc.yml b/doc/source/_toc.yml index 066a552b5f0e..38f6dbd3c391 100644 --- a/doc/source/_toc.yml +++ b/doc/source/_toc.yml @@ -141,6 +141,7 @@ parts: - file: data/random-access - file: data/faq - file: data/api/api + - file: data/glossary - file: data/integrations - file: train/train @@ -361,9 +362,6 @@ parts: - file: ray-observability/monitoring-debugging/monitoring-debugging title: "Monitoring and Debugging" - sections: - - file: ray-observability/index - title: Tools - file: ray-references/api title: References @@ -380,4 +378,4 @@ parts: - file: ray-contribute/fake-autoscaler - file: ray-core/examples/testing-tips - file: ray-core/configure - - file: ray-contribute/whitepaper \ No newline at end of file + - file: ray-contribute/whitepaper diff --git a/doc/source/data/getting-started.rst b/doc/source/data/getting-started.rst index 0ddb73e10a83..6ff45965180b 100644 --- a/doc/source/data/getting-started.rst +++ b/doc/source/data/getting-started.rst @@ -144,3 +144,8 @@ or remote filesystems. To learn more about saving datasets, read :ref:`Saving datasets `. + +Next Steps +---------- + +* To check how your application is doing, you can use the :ref:`Ray dashboard`. \ No newline at end of file diff --git a/doc/source/data/glossary.rst b/doc/source/data/glossary.rst new file mode 100644 index 000000000000..067491ec9cd8 --- /dev/null +++ b/doc/source/data/glossary.rst @@ -0,0 +1,137 @@ +.. _datasets_glossary: + +===================== +Ray Datasets Glossary +===================== + +.. glossary:: + + Batch format + The way batches of data are represented. + + Set ``batch_format`` in methods like + :meth:`Dataset.iter_batches() ` and + :meth:`Dataset.map_batches() ` to specify the + batch type. + + .. doctest:: + + >>> import ray + >>> dataset = ray.data.range_table(10) + >>> next(iter(dataset.iter_batches(batch_format="numpy", batch_size=5))) + {'value': array([0, 1, 2, 3, 4])} + >>> next(iter(dataset.iter_batches(batch_format="pandas", batch_size=5))) + value + 0 0 + 1 1 + 2 2 + 3 3 + 4 4 + + To learn more about batch formats, read + :ref:`UDF Input Batch Formats `. + + Block + A processing unit of data. A :class:`~ray.data.Dataset` consists of a + collection of blocks. + + Under the hood, :term:`Datasets ` partition :term:`records ` + into a set of distributed data blocks. This allows Datasets to perform operations + in parallel. + + Unlike a batch, which is a user-facing object, a block is an internal abstraction. + + Block format + The way :term:`blocks ` are represented. + + Blocks are represented as + `Arrow tables `_, + `pandas DataFrames `_, + and Python lists. To determine the block format, call + :meth:`Dataset.dataset_format() `. + + Datasets (library) + A library for distributed data processing. + + Datasets isn’t intended as a replacement for more general data processing systems. + Its utility is as the last-mile bridge from ETL pipeline outputs to distributed + ML applications and libraries in Ray. 
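A minimal sketch of the batch-format behavior described above, using only the `range_table`/`iter_batches`-era Datasets API quoted in this glossary (Ray ~2.3); the `add_one` UDF is purely illustrative:

```python
import ray

# Ten rows in a single "value" column, partitioned into Arrow blocks.
ds = ray.data.range_table(10)

# With batch_format="numpy", each batch arrives as a dict of column name ->
# NumPy array, matching the iter_batches() doctest above.
def add_one(batch):
    batch["value"] = batch["value"] + 1
    return batch

ds = ds.map_batches(add_one, batch_format="numpy", batch_size=5)
print(ds.take(3))  # first three rows, each value incremented by one
```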
+ + To learn more about Ray Datasets, read :ref:`Key Concepts `. + + Dataset (object) + A class that represents a distributed collection of data. + + :class:`~ray.data.Dataset` exposes methods to read, transform, and consume data at scale. + + To learn more about Datasets and the operations they support, read the :ref:`Datasets API Reference `. + + Datasource + A :class:`~ray.data.Datasource` specifies how to read and write from + a variety of external storage and data formats. + + Examples of Datasources include :class:`~ray.data.datasource.ParquetDatasource`, + :class:`~ray.data.datasource.ImageDatasource`, + :class:`~ray.data.datasource.TFRecordDatasource`, + :class:`~ray.data.datasource.CSVDatasource`, and + :class:`~ray.data.datasource.MongoDatasource`. + + To learn more about Datasources, read :ref:`Creating a Custom Datasource `. + + Record + A single data item. + + If your dataset is :term:`tabular `, then records are :class:`TableRows `. + If your dataset is :term:`simple `, then records are arbitrary Python objects. + If your dataset is :term:`tensor `, then records are `NumPy ndarrays `_. + + Schema + The data type of a dataset. + + If your dataset is :term:`tabular `, then the schema describes + the column names and data types. If your dataset is :term:`simple `, + then the schema describes the Python object type. If your dataset is + :term:`tensor `, then the schema describes the per-element + tensor shape and data type. + + To determine a dataset's schema, call + :meth:`Dataset.schema() `. + + Simple Dataset + A Dataset that represents a collection of arbitrary Python objects. + + .. doctest:: + + >>> import ray + >>> ray.data.from_items(["spam", "ham", "eggs"]) + Dataset(num_blocks=3, num_rows=3, schema=) + + Tensor Dataset + A Dataset that represents a collection of ndarrays. + + :term:`Tabular datasets ` that contain tensor columns aren’t tensor datasets. + + .. doctest:: + + >>> import numpy as np + >>> import ray + >>> ray.data.from_numpy(np.zeros((100, 32, 32, 3))) + Dataset(num_blocks=1, num_rows=100, schema={__value__: ArrowTensorType(shape=(32, 32, 3), dtype=double)}) + + Tabular Dataset + A Dataset that represents columnar data. + + .. doctest:: + + >>> import ray + >>> ray.data.read_csv("s3://anonymous@air-example-data/iris.csv") + Dataset(num_blocks=1, num_rows=150, schema={sepal length (cm): double, sepal width (cm): double, petal length (cm): double, petal width (cm): double, target: int64}) + + User-defined function (UDF) + A callable that transforms batches or :term:`records ` of data. UDFs let you arbitrarily transform datasets. + + Call :meth:`Dataset.map_batches() `, + :meth:`Dataset.map() `, or + :meth:`Dataset.flat_map() ` to apply UDFs. + + To learn more about UDFs, read :ref:`Writing User-Defined Functions `. diff --git a/doc/source/ray-air/getting-started.rst b/doc/source/ray-air/getting-started.rst index 1e9cb01225de..9a16203b7a34 100644 --- a/doc/source/ray-air/getting-started.rst +++ b/doc/source/ray-air/getting-started.rst @@ -205,3 +205,4 @@ Next Steps - :ref:`air-examples-ref` - :ref:`API reference ` - :ref:`Technical whitepaper ` +- To check how your application is doing, you can use the :ref:`Ray dashboard`. 
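Since the glossary distinguishes simple, tensor, and tabular datasets chiefly by their records and schema, here is a small illustrative sketch (same Ray Datasets calls as the doctests above) showing how `schema()` reflects each flavor:

```python
import numpy as np
import ray

simple = ray.data.from_items(["spam", "ham", "eggs"])            # Python objects
tensor = ray.data.from_numpy(np.zeros((8, 4, 4)))                # ndarray records
tabular = ray.data.from_items([{"x": i, "y": 2 * i} for i in range(5)])  # columns

# A Python type for simple data, a tensor dtype/shape for tensor data, and
# column names/types for tabular data: the three "Schema" cases above.
for ds in (simple, tensor, tabular):
    print(ds.schema())
```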
diff --git a/doc/source/ray-core/ray-dashboard.rst b/doc/source/ray-core/ray-dashboard.rst index 93e4ff537421..d847b9fed0a4 100644 --- a/doc/source/ray-core/ray-dashboard.rst +++ b/doc/source/ray-core/ray-dashboard.rst @@ -2,16 +2,12 @@ Ray Dashboard ============= -Ray's built-in dashboard provides metrics, charts, and other features that help -Ray users to understand Ray clusters and libraries. +Ray provides a web-based dashboard for monitoring and debugging Ray applications. +The dashboard provides a visual representation of the system state, allowing users to track the performance +of their applications and troubleshoot issues. -The dashboard lets you: - -- View cluster metrics including time-series visualizations. -- See errors and exceptions at a glance. -- View logs across many machines. -- See all your ray jobs and the logs for those jobs. -- See your ray actors and their logs +.. image:: https://raw.githubusercontent.com/ray-project/Images/master/docs/new-dashboard/Dashboard-overview.png + :align: center Getting Started --------------- diff --git a/doc/source/ray-core/walkthrough.rst b/doc/source/ray-core/walkthrough.rst index aeca0c67ac2f..80b42f8ac665 100644 --- a/doc/source/ray-core/walkthrough.rst +++ b/doc/source/ray-core/walkthrough.rst @@ -60,6 +60,8 @@ As seen above, Ray stores task and actor call results in its :ref:`distributed o Next Steps ---------- +.. tip:: To check how your application is doing, you can use the :ref:`Ray dashboard `. + Ray's key primitives are simple, but can be composed together to express almost any kind of distributed computation. Learn more about Ray's :ref:`key concepts ` with the following user guides: diff --git a/doc/source/ray-observability/index.rst b/doc/source/ray-observability/index.rst deleted file mode 100644 index f38e9eab58a5..000000000000 --- a/doc/source/ray-observability/index.rst +++ /dev/null @@ -1,18 +0,0 @@ -.. _observability: - -Observability -============= - -.. toctree:: - :maxdepth: 2 - :caption: Observability, Debugging, and Profiling - - overview - ../ray-core/ray-dashboard.rst - state/state-api.rst - ray-debugging.rst - ray-logging.rst - ray-metrics.rst - ray-tracing.rst - ../ray-contribute/debugging.rst - ../ray-contribute/profiling.rst diff --git a/doc/source/ray-observability/monitoring-debugging/monitoring-debugging.rst b/doc/source/ray-observability/monitoring-debugging/monitoring-debugging.rst index 1e147e8e5eb1..4dd9b646d0f7 100644 --- a/doc/source/ray-observability/monitoring-debugging/monitoring-debugging.rst +++ b/doc/source/ray-observability/monitoring-debugging/monitoring-debugging.rst @@ -1,3 +1,5 @@ +.. _observability: + Monitoring and Debugging ======================== @@ -8,9 +10,18 @@ See :ref:`Getting Help ` if your problem is not s .. toctree:: :maxdepth: 0 + ../overview + ../../ray-core/ray-dashboard + ../state/state-api + ../ray-debugging + ../ray-logging + ../ray-metrics + profiling + ../ray-tracing troubleshoot-failures troubleshoot-hangs troubleshoot-performance gotchas - profiling getting-help + ../../ray-contribute/debugging.rst + ../../ray-contribute/profiling.rst diff --git a/doc/source/ray-observability/overview.rst b/doc/source/ray-observability/overview.rst index f45bbfbf69a5..3e073c51a43d 100644 --- a/doc/source/ray-observability/overview.rst +++ b/doc/source/ray-observability/overview.rst @@ -5,6 +5,13 @@ This section covers a list of available monitoring and debugging tools and featu This documentation only covers the high-level description of available tools and features. 
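A hedged sketch of how the dashboard described on this page is reached locally; it assumes `ray[default]` is installed and that `dashboard_url` is exposed on the context returned by `ray.init()` (an assumption about the exact Ray release):

```python
import ray

# Starting Ray locally also starts the dashboard, by default on port 8265.
context = ray.init()
print("Dashboard served at:", context.dashboard_url)  # e.g. "127.0.0.1:8265"

@ray.remote
def square(x: int) -> int:
    return x * x

# Tasks submitted now appear in the dashboard's job and task views.
print(ray.get([square.remote(i) for i in range(4)]))
```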
For more details, see :ref:`Ray Observability `. +Dashboard (Web UI) +------------------ +Ray supports the web-based dashboard to help users monitor the cluster. When a new cluster is started, the dashboard is available +through the default address `localhost:8265` (port can be automatically incremented if port 8265 is already occupied). + +See :ref:`Ray Dashboard ` for more details. + Application Logging ------------------- By default, all stdout and stderr of tasks and actors are streamed to the Ray driver (the entrypoint script that calls ``ray.init``). @@ -79,13 +86,6 @@ The following command will list all the actors from the cluster. See :ref:`Ray State API ` for more details. -Dashboard (Web UI) ------------------- -Ray supports the web-based dashboard to help users monitor the cluster. When a new cluster is started, the dashboard is available -through the default address `localhost:8265` (port can be automatically incremented if port 8265 is already occupied). - -See :ref:`Ray Dashboard ` for more details. - Debugger -------- Ray has a built-in debugger that allows you to debug your distributed applications. diff --git a/doc/source/ray-overview/index.md b/doc/source/ray-overview/index.md index 566df0c30fd2..6814cc658d94 100644 --- a/doc/source/ray-overview/index.md +++ b/doc/source/ray-overview/index.md @@ -615,6 +615,113 @@ ray submit cluster.yaml example.py --start ````` +## Debugging and Monitoring Ray Quick Start + +You can use built-in observability tools to monitor and debug Ray applications and clusters. + +`````{dropdown} ray Ray Dashboard: Web GUI to monitor and debug Ray +:animate: fade-in-slide-down + +Ray dashboard provides a visual interface that displays real-time system metrics, node-level resource monitoring, job profiling, and task visualizations. The dashboard is designed to help users understand the performance of their Ray applications and identify potential issues. + +```{image} https://raw.githubusercontent.com/ray-project/Images/master/docs/new-dashboard/Dashboard-overview.png +:align: center +``` + +````{note} +To get started with ray dashboard install the Ray default installation as follows. + +```bash +pip install "ray[default]" +``` +```` + +```{link-button} ../ray-core/ray-dashboard +:type: ref +:text: Learn more about Ray Dashboard. +:classes: btn-outline-primary btn-block +``` + +````` + +`````{dropdown} ray Ray State APIs: CLI to access cluster states +:animate: fade-in-slide-down + +Ray state APIs allow users to conveniently access the current state (snapshot) of Ray through CLI or Python SDK. + +````{note} +To get started with ray state API install the Ray default installation as follows. + +```bash +pip install "ray[default]" +``` +```` + +Run the following code. + +```{code-block} python + + import ray + import time + + ray.init(num_cpus=4) + + @ray.remote + def task_running_300_seconds(): + print("Start!") + time.sleep(300) + + @ray.remote + class Actor: + def __init__(self): + print("Actor created") + + # Create 2 tasks + tasks = [task_running_300_seconds.remote() for _ in range(2)] + + # Create 2 actors + actors = [Actor.remote() for _ in range(2)] + + ray.get(tasks) + +``` + +See the summarized statistics of Ray tasks using ``ray summary tasks``. 
+ +```{code-block} bash + + ray summary tasks + +``` + +```{code-block} text + + ======== Tasks Summary: 2022-07-22 08:54:38.332537 ======== + Stats: + ------------------------------------ + total_actor_scheduled: 2 + total_actor_tasks: 0 + total_tasks: 2 + + + Table (group by func_name): + ------------------------------------ + FUNC_OR_CLASS_NAME STATE_COUNTS TYPE + 0 task_running_300_seconds RUNNING: 2 NORMAL_TASK + 1 Actor.__init__ FINISHED: 2 ACTOR_CREATION_TASK + +``` + +```{link-button} ../ray-observability/state/state-api +:type: ref +:text: Learn more about Ray State APIs +:classes: btn-outline-primary btn-block +``` + +````` + + + ```{include} learn-more.md ``` diff --git a/doc/source/rllib/rllib-training.rst b/doc/source/rllib/rllib-training.rst index 3f82f3ebd229..e9d13b7b8e3e 100644 --- a/doc/source/rllib/rllib-training.rst +++ b/doc/source/rllib/rllib-training.rst @@ -604,3 +604,8 @@ Stack Traces You can use the ``ray stack`` command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues. + +Next Steps +---------- + +- To check how your application is doing, you can use the :ref:`Ray dashboard`. \ No newline at end of file diff --git a/doc/source/serve/getting_started.md b/doc/source/serve/getting_started.md index ca8b538a1630..b5f6d40bffa0 100644 --- a/doc/source/serve/getting_started.md +++ b/doc/source/serve/getting_started.md @@ -259,6 +259,7 @@ Deployment graphs are useful since they let you deploy each part of your machine - Dive into the {doc}`key-concepts` to get a deeper understanding of Ray Serve. - Learn more about how to deploy your Ray Serve application to production: {ref}`serve-in-production`. - Check more in-depth tutorials for popular machine learning frameworks: {doc}`tutorials/index`. +- To check how your application is doing, you can use the :ref:`Ray dashboard `. ```{rubric} Footnotes ``` diff --git a/doc/source/train/getting-started.rst b/doc/source/train/getting-started.rst index cb59e5f64c9d..d071b2c44ee8 100644 --- a/doc/source/train/getting-started.rst +++ b/doc/source/train/getting-started.rst @@ -179,3 +179,9 @@ Here are examples for some of the commonly used trainers: :end-before: __tf_trainer_end__ See :ref:`train-porting-code` for a more comprehensive example. + + +Next Steps +---------- + +* To check how your application is doing, you can use the :ref:`Ray dashboard`. \ No newline at end of file diff --git a/doc/source/tune/getting-started.rst b/doc/source/tune/getting-started.rst index 14e303b2dae3..ccf805d83326 100644 --- a/doc/source/tune/getting-started.rst +++ b/doc/source/tune/getting-started.rst @@ -164,3 +164,4 @@ Next Steps * Check out the :ref:`Tune tutorials ` for guides on using Tune with your preferred machine learning library. * Browse our :ref:`gallery of examples ` to see how to use Tune with PyTorch, XGBoost, Tensorflow, etc. * `Let us know `__ if you ran into issues or have any questions by opening an issue on our Github. +* To check how your application is doing, you can use the :ref:`Ray dashboard`. 
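The quick start above drives the state API from the CLI; the same snapshot is reachable from Python. A hedged sketch, noting that the `ray.experimental.state.api` module path is the pre-2.4 location and is an assumption about the release this PR targets:

```python
import time

import ray
from ray.experimental.state.api import list_actors, summarize_tasks

ray.init(num_cpus=4)

@ray.remote
def long_task():
    time.sleep(30)

@ray.remote
class Worker:
    pass

refs = [long_task.remote() for _ in range(2)]
workers = [Worker.remote() for _ in range(2)]
time.sleep(2)  # give the tasks a moment to start

print(summarize_tasks())      # same data as `ray summary tasks`
print(list_actors(limit=10))  # one record per live actor
```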
diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index da02cc8e3002..8e5ccccb9ce4 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -1572,7 +1572,20 @@ def no_resource_leaks_excluding_node_resources(): @contextmanager -def simulate_storage(storage_type, root=None): +def simulate_storage( + storage_type: str, + root: Optional[str] = None, + port: int = 5002, + region: str = "us-west-2", +): + """Context that simulates a given storage type and yields the URI. + + Args: + storage_type: The storage type to simiulate ("fs" or "s3") + root: Root directory of the URI to return (e.g., s3 bucket name) + port: The port of the localhost endpoint where s3 is being served (s3 only) + region: The s3 region (s3 only) + """ if storage_type == "fs": if root is None: with tempfile.TemporaryDirectory() as d: @@ -1580,38 +1593,17 @@ def simulate_storage(storage_type, root=None): else: yield "file://" + root elif storage_type == "s3": - import uuid - - from moto import mock_s3 - - from ray.tests.mock_s3_server import start_service, stop_process - - @contextmanager - def aws_credentials(): - old_env = os.environ - os.environ["AWS_ACCESS_KEY_ID"] = "testing" - os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" - os.environ["AWS_SECURITY_TOKEN"] = "testing" - os.environ["AWS_SESSION_TOKEN"] = "testing" - yield - os.environ = old_env - - @contextmanager - def moto_s3_server(): - host = "localhost" - port = 5002 - url = f"http://{host}:{port}" - process = start_service("s3", host, port) - yield url - stop_process(process) - - if root is None: - root = uuid.uuid4().hex - with moto_s3_server() as s3_server, aws_credentials(), mock_s3(): - url = f"s3://{root}?region=us-west-2&endpoint_override={s3_server}" - yield url + from moto.server import ThreadedMotoServer + + root = root or uuid.uuid4().hex + s3_server = f"http://localhost:{port}" + server = ThreadedMotoServer(port=port) + server.start() + url = f"s3://{root}?region={region}&endpoint_override={s3_server}" + yield url + server.stop() else: - raise ValueError(f"Unknown storage type: {storage_type}") + raise NotImplementedError(f"Unknown storage type: {storage_type}") def job_hook(**kwargs): diff --git a/python/ray/air/_internal/uri_utils.py b/python/ray/air/_internal/uri_utils.py new file mode 100644 index 000000000000..c6222198b137 --- /dev/null +++ b/python/ray/air/_internal/uri_utils.py @@ -0,0 +1,61 @@ +from pathlib import Path +import urllib.parse +import os +from typing import Union + + +class URI: + """Represents a URI, supporting path appending and retrieving parent URIs. + + Example Usage: + + >>> s3_uri = URI("s3://bucket/a?scheme=http&endpoint_override=localhost%3A900") + >>> s3_uri + URI + >>> str(s3_uri / "b" / "c") + 's3://bucket/a/b/c?scheme=http&endpoint_override=localhost%3A900' + >>> str(s3_uri.parent) + 's3://bucket?scheme=http&endpoint_override=localhost%3A900' + >>> str(s3_uri) + 's3://bucket/a?scheme=http&endpoint_override=localhost%3A900' + >>> s3_uri.parent.name, s3_uri.name + ('bucket', 'a') + + Args: + uri: The URI to represent. 
+ Ex: s3://bucket?scheme=http&endpoint_override=localhost%3A900 + Ex: file:///a/b/c/d + """ + + def __init__(self, uri: str): + self._parsed = urllib.parse.urlparse(uri) + if not self._parsed.scheme: + raise ValueError(f"Invalid URI: {uri}") + self._path = Path(os.path.normpath(self._parsed.netloc + self._parsed.path)) + + @property + def name(self) -> str: + return self._path.name + + @property + def parent(self) -> "URI": + assert self._path.parent != ".", f"{str(self)} has no valid parent URI" + return URI(self._get_str_representation(self._parsed, self._path.parent)) + + def __truediv__(self, path_to_append): + assert isinstance(path_to_append, str) + return URI( + self._get_str_representation(self._parsed, self._path / path_to_append) + ) + + @classmethod + def _get_str_representation( + cls, parsed_uri: urllib.parse.ParseResult, path: Union[str, Path] + ) -> str: + return parsed_uri._replace(netloc=str(path), path="").geturl() + + def __repr__(self): + return f"URI<{str(self)}>" + + def __str__(self): + return self._get_str_representation(self._parsed, self._path) diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index f1232922f4c3..ddccaf140b0e 100644 --- a/python/ray/air/checkpoint.py +++ b/python/ray/air/checkpoint.py @@ -764,17 +764,26 @@ def __fspath__(self): def get_preprocessor(self) -> Optional["Preprocessor"]: """Return the saved preprocessor, if one exists.""" + if self._override_preprocessor: + return self._override_preprocessor + # The preprocessor will either be stored in an in-memory dict or # written to storage. In either case, it will use the PREPROCESSOR_KEY key. - # First try converting to dictionary. - checkpoint_dict = self.to_dict() - preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None) - - if preprocessor is None: - # Fallback to reading from directory. + # If this is a pure directory checkpoint (not a dict checkpoint saved to dir), + # then do not convert to dictionary as that takes a lot of time and memory. + if self.uri: with self.as_directory() as checkpoint_path: - preprocessor = load_preprocessor_from_dir(checkpoint_path) + if _is_persisted_directory_checkpoint(checkpoint_path): + # If this is a persisted directory checkpoint, then we load the + # files from the temp directory created by the context. + # That way we avoid having to download the files twice. + loaded_checkpoint = self.from_directory(checkpoint_path) + preprocessor = _get_preprocessor(loaded_checkpoint) + else: + preprocessor = load_preprocessor_from_dir(checkpoint_path) + else: + preprocessor = _get_preprocessor(self) return preprocessor @@ -859,3 +868,20 @@ def _make_dir(path: str, acquire_del_lock: bool = True) -> None: open(del_lock_path, "a").close() os.makedirs(path, exist_ok=True) + + +def _is_persisted_directory_checkpoint(path: str) -> bool: + return Path(path, _DICT_CHECKPOINT_FILE_NAME).exists() + + +def _get_preprocessor(checkpoint: "Checkpoint") -> Optional["Preprocessor"]: + # First try converting to dictionary. + checkpoint_dict = checkpoint.to_dict() + preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None) + + if preprocessor is None: + # Fallback to reading from directory. 
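+        # NOTE (editor, not in the original patch): load_preprocessor_from_dir
+        # presumably just reads the serialized preprocessor file written
+        # alongside the other checkpoint files in the materialized directory.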
+ with checkpoint.as_directory() as checkpoint_path: + preprocessor = load_preprocessor_from_dir(checkpoint_path) + + return preprocessor diff --git a/python/ray/air/constants.py b/python/ray/air/constants.py index 12dca587419f..1accc998eebd 100644 --- a/python/ray/air/constants.py +++ b/python/ray/air/constants.py @@ -45,3 +45,13 @@ COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV = ( "TRAIN_COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING" ) + +# Integer value which if set will disable lazy checkpointing +# (avoiding unnecessary serialization if worker is on the same node +# as Trainable) +DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING" + +# Name of the marker dropped by the Trainable. If a worker detects +# the presence of the marker in the trial dir, it will use lazy +# checkpointing. +LAZY_CHECKPOINT_MARKER_FILE = ".lazy_checkpoint_marker" diff --git a/python/ray/air/session.py b/python/ray/air/session.py index 6fe5f3bb7ccd..b8747f56952b 100644 --- a/python/ray/air/session.py +++ b/python/ray/air/session.py @@ -7,6 +7,7 @@ from ray.air.constants import SESSION_MISUSE_LOG_ONCE_KEY from ray.train.session import _TrainSessionImpl from ray.util import log_once +from ray.util.annotations import PublicAPI if TYPE_CHECKING: from ray.data import DatasetIterator @@ -37,6 +38,7 @@ def wrapper(*args, **kwargs): return inner +@PublicAPI(stability="beta") @_warn_session_misuse() def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: """Report metrics and optionally save a checkpoint. @@ -90,6 +92,7 @@ def train_func(): _get_session().report(metrics, checkpoint=checkpoint) +@PublicAPI(stability="beta") @_warn_session_misuse() def get_checkpoint() -> Optional[Checkpoint]: """Access the session's last checkpoint to resume from if applicable. @@ -140,30 +143,35 @@ def train_func(): return _get_session().loaded_checkpoint +@PublicAPI(stability="beta") @_warn_session_misuse() def get_experiment_name() -> str: """Experiment name for the corresponding trial.""" return _get_session().experiment_name +@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_name() -> str: """Trial name for the corresponding trial.""" return _get_session().trial_name +@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_id() -> str: """Trial id for the corresponding trial.""" return _get_session().trial_id +@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_resources() -> "PlacementGroupFactory": """Trial resources for the corresponding trial.""" return _get_session().trial_resources +@PublicAPI(stability="beta") @_warn_session_misuse() def get_trial_dir() -> str: """Log directory corresponding to the trial directory for a Tune session. @@ -186,6 +194,7 @@ def train_func(): return _get_session().trial_dir +@PublicAPI(stability="beta") @_warn_session_misuse(default_value=1) def get_world_size() -> int: """Get the current world size (i.e. total number of workers) for this run. @@ -216,6 +225,7 @@ def train_loop_per_worker(config): return session.world_size +@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_world_rank() -> int: """Get the world rank of this worker. @@ -249,6 +259,7 @@ def train_loop_per_worker(): return session.world_rank +@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_local_rank() -> int: """Get the local rank of this worker (rank of the worker on its node). 
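The `ray.air.session` functions being promoted to `@PublicAPI` in these hunks read most naturally together; a hedged sketch (the `TorchTrainer` launcher and an installed `torch` are assumptions made purely for illustration):

```python
from ray.air import Checkpoint, ScalingConfig, session
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    rank, world = session.get_world_rank(), session.get_world_size()
    resumed = session.get_checkpoint() is not None  # None on a fresh run
    for epoch in range(2):
        session.report(
            {"epoch": epoch, "rank": rank, "world": world, "resumed": resumed},
            checkpoint=Checkpoint.from_dict({"epoch": epoch}),
        )

trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.metrics["epoch"])
```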
@@ -281,6 +292,7 @@ def train_loop_per_worker(): return session.local_rank +@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_local_world_size() -> int: """Get the local rank of this worker (rank of the worker on its node). @@ -311,6 +323,7 @@ def get_local_world_size() -> int: return session.local_world_size +@PublicAPI(stability="beta") @_warn_session_misuse(default_value=0) def get_node_rank() -> int: """Get the local rank of this worker (rank of the worker on its node). @@ -341,6 +354,7 @@ def get_node_rank() -> int: return session.node_rank +@PublicAPI(stability="beta") @_warn_session_misuse() def get_dataset_shard( dataset_name: Optional[str] = None, diff --git a/python/ray/air/tests/test_checkpoints.py b/python/ray/air/tests/test_checkpoints.py index 5fa9f64b3416..feecfc0d432d 100644 --- a/python/ray/air/tests/test_checkpoints.py +++ b/python/ray/air/tests/test_checkpoints.py @@ -759,6 +759,13 @@ def testDirCheckpointSetPreprocessor(self): preprocessor = checkpoint.get_preprocessor() assert preprocessor.multiplier == 1 + # Also check that loading from dir works + new_checkpoint_dir = os.path.join(tmpdir, "new_checkpoint") + checkpoint.to_directory(new_checkpoint_dir) + checkpoint = Checkpoint.from_directory(new_checkpoint_dir) + preprocessor = checkpoint.get_preprocessor() + assert preprocessor.multiplier == 1 + def testDirCheckpointSetPreprocessorAsDict(self): with tempfile.TemporaryDirectory() as tmpdir: preprocessor = DummyPreprocessor(1) diff --git a/python/ray/train/_internal/checkpoint.py b/python/ray/train/_internal/checkpoint.py index 39bcddc8fbcb..a85f05d7c915 100644 --- a/python/ray/train/_internal/checkpoint.py +++ b/python/ray/train/_internal/checkpoint.py @@ -1,8 +1,9 @@ +import os import logging from pathlib import Path from typing import Callable, Dict, List, Optional, Type, Union -from ray.air import Checkpoint, CheckpointConfig +from ray.air import Checkpoint, CheckpointConfig, session from ray.air._internal.checkpoint_manager import CheckpointStorage from ray.air._internal.checkpoint_manager import ( _CheckpointManager as CommonCheckpointManager, @@ -16,6 +17,7 @@ TUNE_CHECKPOINT_ID, TUNE_INSTALLED, CHECKPOINT_METADATA_KEY, + LAZY_CHECKPOINT_MARKER_FILE, ) if TUNE_INSTALLED: @@ -209,6 +211,24 @@ def latest_checkpoint_id(self) -> Optional[int]: class TuneCheckpointManager(CheckpointManager): + def __init__( + self, + run_dir: Optional[Path] = None, + checkpoint_strategy: Optional[CheckpointConfig] = None, + ): + super().__init__(run_dir, checkpoint_strategy) + + # Name of the marker dropped by the Trainable. If a worker detects + # the presence of the marker in the trial dir, it will use lazy + # checkpointing. 
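+        # NOTE (editor): the worker-side half of this handshake is in
+        # ray/train/_internal/session.py (also in this diff), which checks
+        # (Path(trial_info.logdir) / LAZY_CHECKPOINT_MARKER_FILE).exists()
+        # before reporting a local checkpoint path instead of serialized data.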
+ self._lazy_marker_path = None + if tune.is_session_enabled(): + self._lazy_marker_path = ( + Path(session.get_trial_dir()) / LAZY_CHECKPOINT_MARKER_FILE + ) + with open(self._lazy_marker_path, "w"): + pass + def _load_checkpoint( self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]] ) -> Optional[Union[Dict, Checkpoint]]: @@ -247,6 +267,14 @@ def next_checkpoint_path(self) -> Optional[Path]: def _get_next_checkpoint_path(self) -> Optional[Path]: return None + def __del__(self): + try: + assert self._lazy_marker_path + os.remove(str(self._lazy_marker_path)) + except Exception: + pass + return super().__del__() + def _construct_checkpoint_path_name(checkpoint_id: int) -> str: return f"checkpoint_{checkpoint_id:06d}" diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py index a623dc449a45..369261901f46 100644 --- a/python/ray/train/_internal/session.py +++ b/python/ray/train/_internal/session.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum, auto +from pathlib import Path from typing import Callable, Dict, Optional, Type, Union import ray @@ -25,6 +26,7 @@ TIME_TOTAL_S, TIMESTAMP, CHECKPOINT_METADATA_KEY, + LAZY_CHECKPOINT_MARKER_FILE, ) from ray.train.error import SessionMisuseError from ray.train.session import _TrainSessionImpl @@ -300,7 +302,7 @@ def checkpoint(self, checkpoint: Checkpoint): checkpoint and self.enable_lazy_checkpointing and checkpoint._local_path - and self.get_current_ip() == self.trial_info.driver_ip + and (Path(self.trial_info.logdir) / LAZY_CHECKPOINT_MARKER_FILE).exists() ): metadata.update({CHECKPOINT_METADATA_KEY: checkpoint._metadata}) checkpoint = str(checkpoint._local_path) diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py index 71b499befeb0..4b6b1ac61bce 100644 --- a/python/ray/train/constants.py +++ b/python/ray/train/constants.py @@ -13,6 +13,8 @@ TRAIN_DATASET_KEY, WILDCARD_KEY, COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV, + DISABLE_LAZY_CHECKPOINTING_ENV, + LAZY_CHECKPOINT_MARKER_FILE, ) # Autofilled session.report() metrics. Keys should be consistent with Tune. @@ -64,10 +66,6 @@ # PACK to SPREAD. 1 for True, 0 for False. TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD" -# Integer value which if set will disable lazy checkpointing -# (avoiding unnecessary serialization if worker is on the same node -# as Trainable) -DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING" # Blacklist virtualized networking. 
DEFAULT_NCCL_SOCKET_IFNAME = "^lo,docker,veth" diff --git a/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py b/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py index c2caeedb1839..d4fdca28b848 100644 --- a/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py +++ b/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py @@ -43,7 +43,6 @@ def checkpoint_train_func(): ("dict", True), ("dir", True), ("lazy_dir", True), - ("dir", False), ("lazy_dir", False), ) diff --git a/python/ray/tune/execution/ray_trial_executor.py b/python/ray/tune/execution/ray_trial_executor.py index 878360e0b4be..2a95537f45f2 100644 --- a/python/ray/tune/execution/ray_trial_executor.py +++ b/python/ray/tune/execution/ray_trial_executor.py @@ -16,7 +16,10 @@ from ray.actor import ActorHandle from ray.air import Checkpoint, AcquiredResources, ResourceRequest from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint -from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV +from ray.air.constants import ( + COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV, + DISABLE_LAZY_CHECKPOINTING_ENV, +) from ray.air.execution import ResourceManager from ray.air.execution.resources.placement_group import ( PlacementGroupResourceManager, @@ -46,6 +49,7 @@ "PL_DISABLE_FORK": "1" } ENV_VARS_TO_PROPAGATE = { + DISABLE_LAZY_CHECKPOINTING_ENV, COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV, "TUNE_CHECKPOINT_CLOUD_RETRY_NUM", "TUNE_CHECKPOINT_CLOUD_RETRY_WAIT_TIME_S", diff --git a/python/ray/tune/experiment/experiment.py b/python/ray/tune/experiment/experiment.py index 6e52b8de0b2b..7b9c6fe028d6 100644 --- a/python/ray/tune/experiment/experiment.py +++ b/python/ray/tune/experiment/experiment.py @@ -22,6 +22,7 @@ ) from ray.air import CheckpointConfig +from ray.air._internal.uri_utils import URI from ray.tune.error import TuneError from ray.tune.registry import register_trainable, is_function_trainable from ray.tune.result import DEFAULT_RESULTS_DIR @@ -443,14 +444,8 @@ def checkpoint_dir(self): def remote_checkpoint_dir(self) -> Optional[str]: if not self.sync_config.upload_dir or not self.dir_name: return None - - # NOTE: `upload_dir` can contain query strings. For example: - # 's3://bucket?scheme=http&endpoint_override=localhost%3A9000'. - if "?" in self.sync_config.upload_dir: - path, query = self.sync_config.upload_dir.split("?") - return os.path.join(path, self.dir_name) + "?" + query - - return os.path.join(self.sync_config.upload_dir, self.dir_name) + uri = URI(self.sync_config.upload_dir) + return str(uri / self.dir_name) @property def run_identifier(self): diff --git a/python/ray/tune/experiment/trial.py b/python/ray/tune/experiment/trial.py index 2697329981ba..2b4edcf8c4c3 100644 --- a/python/ray/tune/experiment/trial.py +++ b/python/ray/tune/experiment/trial.py @@ -14,6 +14,7 @@ import ray from ray.air import CheckpointConfig +from ray.air._internal.uri_utils import URI from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage import ray.cloudpickle as cloudpickle from ray.exceptions import RayActorError, RayTaskError @@ -601,7 +602,7 @@ def generate_id(cls): return str(uuid.uuid4().hex)[:8] @property - def remote_checkpoint_dir(self): + def remote_checkpoint_dir(self) -> str: """This is the **per trial** remote checkpoint dir. This is different from **per experiment** remote checkpoint dir. 
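The query-string handling that `Experiment.remote_checkpoint_dir` previously did by hand is exactly what the new `URI` helper encapsulates; a short sketch of the behavior the updated tests assert:

```python
from ray.air._internal.uri_utils import URI

upload_dir = "s3://bucket?scheme=http&endpoint_override=localhost%3A9000"
uri = URI(upload_dir)

# Appending path components keeps the query string attached at the end,
# which os.path.join() on the raw string would not do.
print(str(uri / "spam"))                    # per-experiment dir
print(str(uri / "spam" / "trial_dirname"))  # per-trial dir
# s3://bucket/spam?scheme=http&endpoint_override=localhost%3A9000
# s3://bucket/spam/trial_dirname?scheme=http&endpoint_override=localhost%3A9000
```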
@@ -609,9 +610,8 @@ def remote_checkpoint_dir(self): assert self.logdir, "Trial {}: logdir not initialized.".format(self) if not self.sync_config.upload_dir or not self.experiment_dir_name: return None - return os.path.join( - self.sync_config.upload_dir, self.experiment_dir_name, self.relative_logdir - ) + uri = URI(self.sync_config.upload_dir) + return str(uri / self.experiment_dir_name / self.relative_logdir) @property def uses_cloud_checkpointing(self): diff --git a/python/ray/tune/impl/tuner_internal.py b/python/ray/tune/impl/tuner_internal.py index 3ee84e3349c2..e6857f7c848e 100644 --- a/python/ray/tune/impl/tuner_internal.py +++ b/python/ray/tune/impl/tuner_internal.py @@ -7,15 +7,16 @@ import tempfile from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING, Tuple -import urllib.parse import ray import ray.cloudpickle as pickle from ray.util import inspect_serializability +from ray.air._internal.uri_utils import URI from ray.air._internal.remote_storage import download_from_uri, is_non_local_path_uri from ray.air.config import RunConfig, ScalingConfig from ray.tune import Experiment, TuneError, ExperimentAnalysis from ray.tune.execution.experiment_state import _ResumeConfig +from ray.tune.tune import _Config from ray.tune.registry import is_function_trainable from ray.tune.result_grid import ResultGrid from ray.tune.trainable import Trainable @@ -99,20 +100,29 @@ def __init__( "Tuner(..., run_config=RunConfig(...))" ) + self.trainable = trainable + param_space = param_space or {} + if isinstance(param_space, _Config): + param_space = param_space.to_dict() + if not isinstance(param_space, dict): + raise ValueError( + "The `param_space` passed to the Tuner` must be a dict. " + f"Got '{type(param_space)}' instead." + ) + self.param_space = param_space + self._tune_config = tune_config or TuneConfig() self._run_config = run_config or RunConfig() self._missing_params_error_message = None - self._param_space = param_space or {} - self._process_scaling_config() - # Restore from Tuner checkpoint. 
if restore_path: self._restore_from_path_or_uri( path_or_uri=restore_path, resume_config=resume_config, overwrite_trainable=trainable, + overwrite_param_space=param_space, ) return @@ -121,7 +131,6 @@ def __init__( raise TuneError("You need to provide a trainable to tune.") self._is_restored = False - self.trainable = trainable self._resume_config = None self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {} @@ -301,6 +310,7 @@ def _restore_from_path_or_uri( path_or_uri: str, resume_config: Optional[_ResumeConfig], overwrite_trainable: Optional[TrainableTypeOrTrainer], + overwrite_param_space: Optional[Dict[str, Any]], ): # Sync down from cloud storage if needed synced, experiment_checkpoint_dir = self._maybe_sync_down_tuner_state( @@ -332,6 +342,8 @@ def _restore_from_path_or_uri( self._is_restored = True self.trainable = trainable + if overwrite_param_space: + self.param_space = overwrite_param_space self._resume_config = resume_config if not synced: @@ -347,14 +359,9 @@ def _restore_from_path_or_uri( self._run_config.name = experiment_path.name else: # Set the experiment `name` and `upload_dir` according to the URI - parsed_uri = urllib.parse.urlparse(path_or_uri) - remote_path = Path(os.path.normpath(parsed_uri.netloc + parsed_uri.path)) - upload_dir = parsed_uri._replace( - netloc="", path=str(remote_path.parent) - ).geturl() - - self._run_config.name = remote_path.name - self._run_config.sync_config.upload_dir = upload_dir + uri = URI(path_or_uri) + self._run_config.name = uri.name + self._run_config.sync_config.upload_dir = str(uri.parent) # If we synced, `experiment_checkpoint_dir` will contain a temporary # directory. Create an experiment checkpoint dir instead and move @@ -435,6 +442,15 @@ def trainable(self, trainable: TrainableTypeOrTrainer): self._trainable = trainable self._converted_trainable = self._convert_trainable(trainable) + @property + def param_space(self) -> Dict[str, Any]: + return self._param_space + + @param_space.setter + def param_space(self, param_space: Dict[str, Any]): + self._param_space = param_space + self._process_scaling_config() + def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType: """Converts an AIR Trainer to a Tune trainable and saves the converted trainable. 
If not using an AIR Trainer, this leaves the trainable as is.""" @@ -449,7 +465,7 @@ def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType def fit(self) -> ResultGrid: trainable = self.converted_trainable assert self._experiment_checkpoint_dir - param_space = copy.deepcopy(self._param_space) + param_space = copy.deepcopy(self.param_space) if not self._is_restored: analysis = self._fit_internal(trainable, param_space) else: @@ -552,14 +568,14 @@ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]: ) def _fit_internal( - self, trainable: TrainableType, param_space: Dict[str, Any] + self, trainable: TrainableType, param_space: Optional[Dict[str, Any]] ) -> ExperimentAnalysis: """Fitting for a fresh Tuner.""" args = { **self._get_tune_run_arguments(trainable), **dict( run_or_experiment=trainable, - config={**param_space}, + config=param_space, num_samples=self._tune_config.num_samples, search_alg=self._tune_config.search_alg, scheduler=self._tune_config.scheduler, @@ -575,7 +591,7 @@ def _fit_internal( return analysis def _fit_resume( - self, trainable: TrainableType, param_space: Dict[str, Any] + self, trainable: TrainableType, param_space: Optional[Dict[str, Any]] ) -> ExperimentAnalysis: """Fitting for a restored Tuner.""" if self._missing_params_error_message: @@ -599,7 +615,7 @@ def _fit_resume( **self._get_tune_run_arguments(trainable), **dict( run_or_experiment=trainable, - config={**param_space}, + config=param_space, resume=resume, search_alg=self._tune_config.search_alg, scheduler=self._tune_config.scheduler, diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index dc4a739c66a9..11fcc64b0d71 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -27,6 +27,7 @@ delete_at_uri, is_non_local_path_uri, ) +from ray.air.constants import LAZY_CHECKPOINT_MARKER_FILE from ray.exceptions import RayActorError from ray.tune import TuneError from ray.tune.callback import Callback @@ -57,6 +58,7 @@ "./checkpoint_tmp*", "./save_to_object*", "./rank_*", + f"./{LAZY_CHECKPOINT_MARKER_FILE}", ] diff --git a/python/ray/tune/tests/test_experiment.py b/python/ray/tune/tests/test_experiment.py index 729b1e3a9633..be3774b9ec2a 100644 --- a/python/ray/tune/tests/test_experiment.py +++ b/python/ray/tune/tests/test_experiment.py @@ -4,19 +4,30 @@ import ray from ray.air import CheckpointConfig from ray.tune import register_trainable, SyncConfig -from ray.tune.experiment import Experiment, _convert_to_experiment_list +from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list from ray.tune.error import TuneError from ray.tune.utils import diagnose_serialization -def test_remote_checkpoint_dir_with_query_string(): +def test_remote_checkpoint_dir_with_query_string(tmp_path): + sync_config = SyncConfig(syncer="auto", upload_dir="s3://bucket?scheme=http") experiment = Experiment( name="spam", run=lambda config: config, - sync_config=SyncConfig(syncer="auto", upload_dir="s3://bucket?scheme=http"), + sync_config=sync_config, ) assert experiment.remote_checkpoint_dir == "s3://bucket/spam?scheme=http" + trial = Trial( + "mock", + stub=True, + sync_config=sync_config, + experiment_dir_name="spam", + local_dir=str(tmp_path), + ) + trial.relative_logdir = "trial_dirname" + assert trial.remote_checkpoint_dir == "s3://bucket/spam/trial_dirname?scheme=http" + class ExperimentTest(unittest.TestCase): def tearDown(self): diff --git a/python/ray/tune/tests/test_syncer.py b/python/ray/tune/tests/test_syncer.py index 
83d03b285da4..2e39dac28ecf 100644 --- a/python/ray/tune/tests/test_syncer.py +++ b/python/ray/tune/tests/test_syncer.py @@ -4,12 +4,11 @@ import subprocess import tempfile import time -from pathlib import Path from typing import List, Optional from unittest.mock import patch -import pytest import boto3 +import pytest from freezegun import freeze_time import ray @@ -23,6 +22,14 @@ from ray._private.test_utils import simulate_storage +@pytest.fixture +def ray_start_4_cpus(): + address_info = ray.init(num_cpus=4, configure_logging=False) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + @pytest.fixture def ray_start_2_cpus(): address_info = ray.init(num_cpus=2, configure_logging=False) @@ -65,6 +72,23 @@ def temp_data_dirs(): shutil.rmtree(tmp_target) +@pytest.fixture +def mock_s3_bucket_uri(): + bucket_name = "test_syncer_bucket" + port = 5002 + region = "us-west-2" + with simulate_storage("s3", root=bucket_name, port=port, region=region) as s3_uri: + s3 = boto3.client( + "s3", region_name=region, endpoint_url=f"http://localhost:{port}" + ) + s3.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": region}, + ) + + yield s3_uri + + def assert_file(exists: bool, root: str, path: str): full_path = os.path.join(root, path) @@ -620,7 +644,7 @@ def test_syncer_serialize(temp_data_dirs): pickle.dumps(syncer) -def test_final_experiment_checkpoint_sync(tmpdir): +def test_final_experiment_checkpoint_sync(ray_start_2_cpus, tmpdir): class SlowSyncer(_DefaultSyncer): def __init__(self, **kwargs): super(_DefaultSyncer, self).__init__(**kwargs) @@ -676,28 +700,19 @@ def train_func(config): ) -def test_sync_folder_with_many_files_s3(tmpdir): +def test_sync_folder_with_many_files_s3(mock_s3_bucket_uri, tmp_path): + source_dir = tmp_path / "source" + check_dir = tmp_path / "check" + source_dir.mkdir() + check_dir.mkdir() + # Create 256 files to upload for i in range(256): - (tmpdir / str(i)).write_text("", encoding="utf-8") + (source_dir / str(i)).write_text("", encoding="utf-8") - root = "bucket_test_syncer/dir" - with simulate_storage("s3", root) as s3_uri: - # Upload to S3 - - s3 = boto3.client( - "s3", region_name="us-west-2", endpoint_url="http://localhost:5002" - ) - s3.create_bucket( - Bucket="bucket_test_syncer", - CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, - ) - upload_to_uri(tmpdir, s3_uri) - - with tempfile.TemporaryDirectory() as download_dir: - download_from_uri(s3_uri, download_dir) - - assert (Path(download_dir) / "255").exists() + upload_to_uri(source_dir, mock_s3_bucket_uri) + download_from_uri(mock_s3_bucket_uri, check_dir) + assert (check_dir / "255").exists() def test_sync_folder_with_many_files_fs(tmpdir): @@ -713,6 +728,50 @@ def test_sync_folder_with_many_files_fs(tmpdir): assert (tmpdir / "255").exists() +def test_e2e_sync_to_s3(ray_start_4_cpus, mock_s3_bucket_uri, tmp_path): + """Tests an end to end Tune run with syncing to a mock s3 bucket.""" + download_dir = tmp_path / "upload_dir" + download_dir.mkdir() + + local_dir = str(tmp_path / "local_dir") + + exp_name = "test_e2e_sync_to_s3" + + def train_fn(config): + session.report({"score": 1}, checkpoint=Checkpoint.from_dict({"data": 1})) + + tuner = tune.Tuner( + train_fn, + param_space={"id": tune.grid_search([0, 1, 2, 3])}, + run_config=RunConfig( + name=exp_name, + local_dir=local_dir, + sync_config=tune.SyncConfig(upload_dir=mock_s3_bucket_uri), + ), + tune_config=tune.TuneConfig( + trial_dirname_creator=lambda t: 
str(t.config.get("id")) + ), + ) + result_grid = tuner.fit() + + # Download remote dir to do some sanity checks + download_from_uri(uri=mock_s3_bucket_uri, local_path=str(download_dir)) + + assert not result_grid.errors + + def get_remote_trial_dir(trial_id: int): + return os.path.join(download_dir, exp_name, str(trial_id)) + + # Check that each remote trial dir has a checkpoint + for result in result_grid: + trial_id = result.config["id"] + remote_dir = get_remote_trial_dir(trial_id) + num_checkpoints = len( + [file for file in os.listdir(remote_dir) if file.startswith("checkpoint_")] + ) + assert num_checkpoints == 1 + + if __name__ == "__main__": import sys diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index a582688f1435..82fa8e5347df 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -12,6 +12,7 @@ import time from typing import List import unittest +from unittest import mock import ray from ray import tune @@ -314,10 +315,8 @@ def testResourceUpdateInResume(self): ) assert len(analysis.trials) == 27 - # Unfinished trials' resources should be updated. + @mock.patch.dict(os.environ, {"TUNE_MAX_PENDING_TRIALS_PG": "1"}) def testConfigUpdateInResume(self): - os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" - class FakeDataset: def __init__(self, name): self.name = name diff --git a/python/ray/tune/tests/test_tuner.py b/python/ray/tune/tests/test_tuner.py index 843186693f7a..6f3e51f726c5 100644 --- a/python/ray/tune/tests/test_tuner.py +++ b/python/ray/tune/tests/test_tuner.py @@ -531,6 +531,33 @@ def train_func(config): assert artifact_data == f"{result.config['id']}" +def test_invalid_param_space(shutdown_only): + """Check that Tune raises an error on invalid param_space types.""" + + def trainable(config): + return {"metric": 1} + + with pytest.raises(ValueError): + Tuner(trainable, param_space="not allowed") + + from ray.tune.tune import _Config + + class CustomConfig(_Config): + def to_dict(self) -> dict: + return {"hparam": 1} + + with pytest.raises(ValueError): + Tuner(trainable, param_space="not allowed").fit() + + with pytest.raises(ValueError): + tune.run(trainable, config="not allowed") + + # Dict and custom _Config subclasses are fine + Tuner(trainable, param_space={}).fit() + Tuner(trainable, param_space=CustomConfig()).fit() + tune.run(trainable, config=CustomConfig()) + + if __name__ == "__main__": import sys diff --git a/python/ray/tune/tests/test_tuner_restore.py b/python/ray/tune/tests/test_tuner_restore.py index d46f5eb5668a..c62b03c481bd 100644 --- a/python/ray/tune/tests/test_tuner_restore.py +++ b/python/ray/tune/tests/test_tuner_restore.py @@ -96,7 +96,7 @@ def _train_fn_sometimes_failing(config): class _FailOnStats(Callback): """Fail when at least num_trials exist and num_finished have finished.""" - def __init__(self, num_trials: int, num_finished: int, delay: int = 1): + def __init__(self, num_trials: int, num_finished: int = 0, delay: int = 1): self.num_trials = num_trials self.num_finished = num_finished self.delay = delay @@ -574,7 +574,7 @@ def train_func_1(config): with pytest.raises(ValueError): tuner = Tuner.restore( str(tmpdir / "overwrite_trainable"), - overwrite_trainable="__fake", + trainable="__fake", resume_errored=True, ) @@ -586,7 +586,7 @@ def train_func_2(config): with pytest.raises(ValueError): tuner = Tuner.restore( str(tmpdir / "overwrite_trainable"), - overwrite_trainable=train_func_2, + trainable=train_func_2, resume_errored=True, ) 
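These test updates track an API change worth spelling out: `Tuner.restore` now takes `trainable` (rather than `overwrite_trainable`) and, with this PR, an optional `param_space`. A hedged sketch with a hypothetical experiment path:

```python
from ray import tune
from ray.air import session
from ray.tune import Tuner

def train_fn(config):
    session.report({"score": config["x"]})

tuner = Tuner.restore(
    "/tmp/ray_results/my_experiment",  # hypothetical path of the run to resume
    trainable=train_fn,
    # Re-specify param_space only to refresh object references (e.g. Datasets);
    # changing the actual search space on restore is not supported.
    param_space={"x": tune.grid_search([1, 2, 3])},
    resume_errored=True,
)
results = tuner.fit()
```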
@@ -599,7 +599,7 @@ def train_func_1(config): with caplog.at_level(logging.WARNING, logger="ray.tune.impl.tuner_internal"): tuner = Tuner.restore( str(tmpdir / "overwrite_trainable"), - overwrite_trainable=train_func_1, + trainable=train_func_1, resume_errored=True, ) assert "The trainable will be overwritten" in caplog.text @@ -680,7 +680,7 @@ def create_trainable_with_params(): tuner = Tuner.restore( str(tmp_path / exp_name), resume_errored=True, - overwrite_trainable=create_trainable_with_params(), + trainable=create_trainable_with_params(), ) results = tuner.fit() assert not results.errors @@ -1011,6 +1011,86 @@ def test_tuner_can_restore(tmp_path, upload_dir): assert not Tuner.can_restore(tmp_path / "new_exp") +def testParamSpaceOverwrite(tmp_path, monkeypatch): + """Test that overwriting param space on restore propagates new refs to existing + trials and newly generated trials.""" + + # Limit the number of generated trial configs -- so restore tests + # newly generated trials. + monkeypatch.setenv("TUNE_MAX_PENDING_TRIALS_PG", "1") + + class FakeDataset: + def __init__(self, name): + self.name = name + + def __repr__(self): + return f"" + + def train_fn(config): + raise RuntimeError("Failing!") + + param_space = { + "test": tune.grid_search( + [FakeDataset("1"), FakeDataset("2"), FakeDataset("3")] + ), + "test2": tune.grid_search( + [ + FakeDataset("4"), + FakeDataset("5"), + FakeDataset("6"), + FakeDataset("7"), + ] + ), + } + + tuner = Tuner( + train_fn, + param_space=param_space, + tune_config=TuneConfig(num_samples=1), + run_config=RunConfig( + local_dir=str(tmp_path), + name="param_space_overwrite", + callbacks=[_FailOnStats(num_trials=4, num_finished=2)], + ), + ) + with pytest.raises(RuntimeError): + tuner.fit() + + # Just suppress the error this time with a new trainable + def train_fn(config): + pass + + param_space = { + "test": tune.grid_search( + [FakeDataset("8"), FakeDataset("9"), FakeDataset("10")] + ), + "test2": tune.grid_search( + [ + FakeDataset("11"), + FakeDataset("12"), + FakeDataset("13"), + FakeDataset("14"), + ] + ), + } + + tuner = Tuner.restore( + str(tmp_path / "param_space_overwrite"), + trainable=train_fn, + param_space=param_space, + resume_errored=True, + ) + tuner._local_tuner._run_config.callbacks = None + result_grid = tuner.fit() + assert not result_grid.errors + assert len(result_grid) == 12 + + for r in result_grid: + # Make sure that test and test2 are updated. 
+ assert r.config["test"].name in ["8", "9", "10"] + assert r.config["test2"].name in ["11", "12", "13", "14"] + + if __name__ == "__main__": import sys diff --git a/python/ray/tune/trainable/util.py b/python/ray/tune/trainable/util.py index 489faad4113e..94d6e0261c1d 100644 --- a/python/ray/tune/trainable/util.py +++ b/python/ray/tune/trainable/util.py @@ -14,6 +14,7 @@ PlacementGroupFactory, resource_dict_to_pg_factory, ) +from ray.air._internal.uri_utils import URI from ray.air.config import ScalingConfig from ray.tune.registry import _ParameterRegistry from ray.tune.resources import Resources @@ -194,7 +195,8 @@ def get_remote_storage_path( ``logdir`` is assumed to be a prefix of ``local_path``.""" rel_local_path = os.path.relpath(local_path, logdir) - return os.path.join(remote_checkpoint_dir, rel_local_path) + uri = URI(remote_checkpoint_dir) + return str(uri / rel_local_path) @DeveloperAPI diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 3bafc7d00b92..5684e1a8e70e 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -1,3 +1,4 @@ +import abc import copy import datetime import logging @@ -61,6 +62,7 @@ from ray.util.annotations import PublicAPI from ray.util.queue import Queue + logger = logging.getLogger(__name__) @@ -173,6 +175,12 @@ def signal_interrupt_tune_run(sig: int, frame): return experiment_interrupted_event +class _Config(abc.ABC): + def to_dict(self) -> dict: + """Converts this configuration to a dict format.""" + raise NotImplementedError + + @PublicAPI def run( run_or_experiment: Union[str, Callable, Type], @@ -477,6 +485,14 @@ class and registered trainables. set_verbosity(verbose) config = config or {} + if isinstance(config, _Config): + config = config.to_dict() + if not isinstance(config, dict): + raise ValueError( + "The `config` passed to `tune.run()` must be a dict. " + f"Got '{type(config)}' instead." + ) + sync_config = sync_config or SyncConfig() sync_config.validate_upload_dir() diff --git a/python/ray/tune/tuner.py b/python/ray/tune/tuner.py index c15ade6cb597..2ca1a253cc3b 100644 --- a/python/ray/tune/tuner.py +++ b/python/ray/tune/tuner.py @@ -172,6 +172,7 @@ def restore( overwrite_trainable: Optional[ Union[str, Callable, Type[Trainable], "BaseTrainer"] ] = None, + param_space: Optional[Dict[str, Any]] = None, ) -> "Tuner": """Restores Tuner after a previously failed run. @@ -202,6 +203,15 @@ def restore( This should be the same trainable that was used to initialize the original Tuner. NOTE: Starting in 2.5, this will be a required parameter. + param_space: The same `param_space` that was passed to + the original Tuner. This can be optionally re-specified due + to the `param_space` potentially containing Ray object + references (tuning over Ray Datasets or tuning over + several `ray.put` object references). **Tune expects the + `param_space` to be unmodified**, and the only part that + will be used during restore are the updated object references. + Changing the hyperparameter search space then resuming is NOT + supported by this API. resume_unfinished: If True, will continue to run unfinished trials. resume_errored: If True, will re-schedule errored trials and try to restore from their latest checkpoints. 
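The `_Config` hook added in `tune.py` means anything exposing `to_dict()` can stand in for a raw dict; a sketch mirroring the new `test_invalid_param_space` test (note that `_Config` is an internal class):

```python
from ray import tune
from ray.air import session
from ray.tune import Tuner
from ray.tune.tune import _Config  # internal hook introduced in this PR

class MySearchSpace(_Config):
    def to_dict(self) -> dict:
        return {"lr": tune.loguniform(1e-4, 1e-1)}

def train_fn(config):
    session.report({"loss": config["lr"]})

# Both entry points accept the object and call to_dict() up front.
Tuner(train_fn, param_space=MySearchSpace()).fit()
tune.run(train_fn, config=MySearchSpace())
```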
@@ -242,6 +252,7 @@ def restore( restore_path=path, resume_config=resume_config, trainable=trainable, + param_space=param_space, ) return Tuner(_tuner_internal=tuner_internal) else: @@ -251,6 +262,7 @@ def restore( restore_path=path, resume_config=resume_config, trainable=trainable, + param_space=param_space, ) return Tuner(_tuner_internal=tuner_internal) diff --git a/rllib/BUILD b/rllib/BUILD index 6cae4c62ca0a..435f1b2399d3 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1135,10 +1135,10 @@ py_test( py_test( - name = "test_ppo_rl_trainer", + name = "test_ppo_learner", tags = ["team:rllib", "algorithms_dir"], size = "medium", - srcs = ["algorithms/ppo/tests/test_ppo_rl_trainer.py"] + srcs = ["algorithms/ppo/tests/test_ppo_learner.py"] ) # PPO Reproducibility @@ -1859,32 +1859,32 @@ py_test( ) py_test( - name = "test_trainer_runner", + name = "test_learner_group", tags = ["team:rllib", "multi_gpu", "exclusive"], size = "large", - srcs = ["core/rl_trainer/tests/test_trainer_runner.py"] + srcs = ["core/learner/tests/test_learner_group.py"] ) py_test( - name = "test_trainer_runner_config", + name = "test_learner_group_config", tags = ["team:rllib", "core"], size = "medium", - srcs = ["core/rl_trainer/tests/test_trainer_runner_config.py"] + srcs = ["core/learner/tests/test_learner_group_config.py"] ) py_test( - name = "test_rl_trainer", + name = "test_learner", tags = ["team:rllib", "core"], size = "medium", - srcs = ["core/rl_trainer/tests/test_rl_trainer.py"] + srcs = ["core/learner/tests/test_learner.py"] ) -# TODO (Kourosh): to be removed in favor of test_rl_trainer.py +# TODO (Kourosh): to be removed in favor of test_learner.py py_test( - name = "test_torch_rl_trainer", + name = "test_torch_learner", tags = ["team:rllib", "core"], size = "medium", - srcs = ["core/rl_trainer/torch/tests/test_torch_rl_trainer.py"] + srcs = ["core/learner/torch/tests/test_torch_learner.py"] ) py_test( @@ -3855,40 +3855,40 @@ py_test( ) # -------------------------------------------------------------------- -# examples/rl_trainer directory +# examples/learner directory # # # Description: These are RLlib tests for the new multi-gpu enabled -# training stack via RLTrainers. +# training stack via Learners. # # NOTE: Add tests alphabetically to this list. 
# -------------------------------------------------------------------- py_test( - name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch", - main = "examples/rl_trainer/multi_agent_cartpole_ppo.py", + name = "examples/learner/multi_agent_cartpole_ppo_torch", + main = "examples/learner/multi_agent_cartpole_ppo.py", tags = ["team:rllib", "exclusive", "examples"], size = "medium", - srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"], + srcs = ["examples/learner/multi_agent_cartpole_ppo.py"], args = ["--as-test", "--framework=torch", "--num-gpus=0"] ) py_test( - name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch_gpu", - main = "examples/rl_trainer/multi_agent_cartpole_ppo.py", + name = "examples/learner/multi_agent_cartpole_ppo_torch_gpu", + main = "examples/learner/multi_agent_cartpole_ppo.py", tags = ["team:rllib", "exclusive", "examples", "gpu"], size = "medium", - srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"], + srcs = ["examples/learner/multi_agent_cartpole_ppo.py"], args = ["--as-test", "--framework=torch", "--num-gpus=1"] ) py_test( - name = "examples/rl_trainer/multi_agent_cartpole_ppo_torch_multi_gpu", - main = "examples/rl_trainer/multi_agent_cartpole_ppo.py", + name = "examples/learner/multi_agent_cartpole_ppo_torch_multi_gpu", + main = "examples/learner/multi_agent_cartpole_ppo.py", tags = ["team:rllib", "exclusive", "examples", "multi_gpu"], size = "medium", - srcs = ["examples/rl_trainer/multi_agent_cartpole_ppo.py"], + srcs = ["examples/learner/multi_agent_cartpole_ppo.py"], args = ["--as-test", "--framework=torch", "--num-gpus=2"] ) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 8ef170668379..d4b94ebf22b0 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -685,8 +685,8 @@ def setup(self, config: AlgorithmConfig) -> None: # Need to add back method_type in case Algorithm is restored from checkpoint method_config["type"] = method_type - self.trainer_runner = None - if self.config._enable_rl_trainer_api: + self.learner_group = None + if self.config._enable_learner_api: # TODO (Kourosh): This is an interim solution where policies and modules # co-exist. In this world we have both policy_map and MARLModule that need # to be consistent with one another. To make a consistent parity between @@ -694,12 +694,12 @@ def setup(self, config: AlgorithmConfig) -> None: # MARLModule from the RLModule within each policy. local_worker = self.workers.local_worker() module_spec = local_worker.marl_module_spec - trainer_runner_config = self.config.get_trainer_runner_config(module_spec) - self.trainer_runner = trainer_runner_config.build() + learner_group_config = self.config.get_learner_group_config(module_spec) + self.learner_group = learner_group_config.build() # sync the weights from local rollout worker to trainers weights = local_worker.get_weights() - self.trainer_runner.set_weights(weights) + self.learner_group.set_weights(weights) # Run `on_algorithm_init` callback after initialization is done. self.callbacks.on_algorithm_init(algorithm=self) @@ -1345,8 +1345,8 @@ def training_step(self) -> ResultDict: # cases should use the multi-GPU optimizer, even if only using 1 GPU). # TODO: (sven) rename MultiGPUOptimizer into something more # meaningful. 
- if self.config._enable_rl_trainer_api: - train_results = self.trainer_runner.update(train_batch) + if self.config._enable_learner_api: + train_results = self.learner_group.update(train_batch) elif self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: @@ -1361,12 +1361,12 @@ def training_step(self) -> ResultDict: "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], } with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: - # TODO (Avnish): Implement this on trainer_runner.get_weights(). + # TODO (Avnish): Implement this on learner_group.get_weights(). # TODO (Kourosh): figure out how we are going to sync MARLModule # weights to MARLModule weights under the policy_map objects? from_worker_or_trainer = None - if self.config._enable_rl_trainer_api: - from_worker_or_trainer = self.trainer_runner + if self.config._enable_learner_api: + from_worker_or_trainer = self.learner_group self.workers.sync_weights( from_worker_or_trainer=from_worker_or_trainer, policies=list(train_results.keys()), @@ -2120,7 +2120,7 @@ def default_resource_request( eval_cf.freeze() # resources for local worker - if cf._enable_rl_trainer_api: + if cf._enable_learner_api: local_worker = {"CPU": cf.num_cpus_for_local_worker, "GPU": 0} else: local_worker = { @@ -2170,24 +2170,24 @@ def default_resource_request( bundles += rollout_workers + evaluation_bundle - if cf._enable_rl_trainer_api: + if cf._enable_learner_api: # resources for the trainer - if cf.num_trainer_workers == 0: - # if num_trainer_workers is 0, then we need to allocate one gpu if - # num_gpus_per_trainer_worker is greater than 0. + if cf.num_learner_workers == 0: + # if num_learner_workers is 0, then we need to allocate one gpu if + # num_gpus_per_learner_worker is greater than 0. 
trainer_bundle = [ { - "CPU": cf.num_cpus_per_trainer_worker, - "GPU": cf.num_gpus_per_trainer_worker, + "CPU": cf.num_cpus_per_learner_worker, + "GPU": cf.num_gpus_per_learner_worker, } ] else: trainer_bundle = [ { - "CPU": cf.num_cpus_per_trainer_worker, - "GPU": cf.num_gpus_per_trainer_worker, + "CPU": cf.num_cpus_per_learner_worker, + "GPU": cf.num_gpus_per_learner_worker, } - for _ in range(cf.num_trainer_workers) + for _ in range(cf.num_learner_workers) ] bundles += trainer_bundle diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 2fc03b08d735..9ac741cacdbb 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -16,9 +16,9 @@ import ray from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.core.rl_trainer.rl_trainer import RLTrainerHPs -from ray.rllib.core.rl_trainer.trainer_runner_config import ( - TrainerRunnerConfig, +from ray.rllib.core.learner.learner import LearnerHPs +from ray.rllib.core.learner.learner_group_config import ( + LearnerGroupConfig, ModuleSpec, ) from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec @@ -61,6 +61,7 @@ ResultDict, SampleBatchType, ) +from ray.tune.tune import _Config from ray.tune.logger import Logger from ray.tune.registry import get_trainable_cls from ray.tune.result import TRIAL_INFO @@ -92,7 +93,7 @@ if TYPE_CHECKING: from ray.rllib.algorithms.algorithm import Algorithm - from ray.rllib.core.rl_trainer import RLTrainer + from ray.rllib.core.learner import Learner logger = logging.getLogger(__name__) @@ -115,7 +116,7 @@ def _resolve_class_path(module) -> Type: return getattr(module, class_name) -class AlgorithmConfig: +class AlgorithmConfig(_Config): """A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration. Example: @@ -243,9 +244,9 @@ def __init__(self, algo_class=None): self.num_gpus_per_worker = 0 self._fake_gpus = False self.num_cpus_for_local_worker = 1 - self.num_trainer_workers = 0 - self.num_gpus_per_trainer_worker = 0 - self.num_cpus_per_trainer_worker = 1 + self.num_learner_workers = 0 + self.num_gpus_per_learner_worker = 0 + self.num_cpus_per_learner_worker = 1 self.local_gpu_idx = 0 self.custom_resources_per_worker = {} self.placement_strategy = "PACK" @@ -321,12 +322,12 @@ def __init__(self, algo_class=None): self.model = copy.deepcopy(MODEL_DEFAULTS) self.optimizer = {} self.max_requests_in_flight_per_sampler_worker = 2 - self.rl_trainer_class = None - self._enable_rl_trainer_api = False + self.learner_class = None + self._enable_learner_api = False # experimental: this will contain the hyper-parameters that are passed to the - # RLTrainer, for computing loss, etc. New algorithms have to set this to their + # Learner, for computing loss, etc. New algorithms have to set this to their # own default. .training() will modify the fields of this object. - self._rl_trainer_hps = RLTrainerHPs() + self._learner_hps = LearnerHPs() # `self.callbacks()` self.callbacks_class = DefaultCallbacks @@ -453,8 +454,8 @@ def __init__(self, algo_class=None): self.no_done_at_end = DEPRECATED_VALUE @property - def rl_trainer_hps(self) -> RLTrainerHPs: - return self._rl_trainer_hps + def learner_hps(self) -> LearnerHPs: + return self._learner_hps def to_dict(self) -> AlgorithmConfigDict: """Converts all settings into a legacy config dict for backward compatibility. 
@@ -791,7 +792,7 @@ def validate(self) -> None: elif self.num_gpus > 1: # TODO: AlphaStar uses >1 GPUs differently (1 per policy actor), so this is # ok for tf2 here. - # Remove this hacky check, once we have fully moved to the RLTrainer API. + # Remove this hacky check, once we have fully moved to the Learner API. if self.framework_str == "tf2" and type(self).__name__ != "AlphaStar": raise ValueError( "`num_gpus` > 1 not supported yet for " @@ -890,17 +891,17 @@ def validate(self) -> None: "SingleAgentRLModuleSpec or MultiAgentRLModuleSpec." ) - # make sure the resource requirements for trainer runner is valid - if self.num_trainer_workers == 0 and self.num_gpus_per_worker > 1: + # make sure the resource requirements for learner_group is valid + if self.num_learner_workers == 0 and self.num_gpus_per_worker > 1: raise ValueError( "num_gpus_per_worker must be 0 (cpu) or 1 (gpu) when using local mode " - "(i.e. num_trainer_workers = 0)" + "(i.e. num_learner_workers = 0)" ) - # resolve rl_trainer class - if self._enable_rl_trainer_api and self.rl_trainer_class is None: - rl_trainer_class_path = self.get_default_rl_trainer_class() - self.rl_trainer_class = _resolve_class_path(rl_trainer_class_path) + # resolve learner class + if self._enable_learner_api and self.learner_class is None: + learner_class_path = self.get_default_learner_class() + self.learner_class = _resolve_class_path(learner_class_path) def build( self, @@ -973,9 +974,9 @@ def resources( num_cpus_per_worker: Optional[Union[float, int]] = NotProvided, num_gpus_per_worker: Optional[Union[float, int]] = NotProvided, num_cpus_for_local_worker: Optional[int] = NotProvided, - num_trainer_workers: Optional[int] = NotProvided, - num_cpus_per_trainer_worker: Optional[Union[float, int]] = NotProvided, - num_gpus_per_trainer_worker: Optional[Union[float, int]] = NotProvided, + num_learner_workers: Optional[int] = NotProvided, + num_cpus_per_learner_worker: Optional[Union[float, int]] = NotProvided, + num_gpus_per_learner_worker: Optional[Union[float, int]] = NotProvided, local_gpu_idx: Optional[int] = NotProvided, custom_resources_per_worker: Optional[dict] = NotProvided, placement_strategy: Optional[str] = NotProvided, @@ -996,18 +997,18 @@ def resources( fractional. This is usually needed only if your env itself requires a GPU (i.e., it is a GPU-intensive video game), or model inference is unusually expensive. - num_trainer_workers: Number of workers used for training. A value of 0 + num_learner_workers: Number of workers used for training. A value of 0 means training will take place on a local worker on head node CPUs or 1 - GPU (determined by `num_gpus_per_trainer_worker`). For multi-gpu + GPU (determined by `num_gpus_per_learner_worker`). For multi-gpu training, set number of workers greater than 1 and set - `num_gpus_per_trainer_worker` accordingly (e.g. 4 GPUs total, and model - needs 2 GPUs: `num_trainer_workers = 2` and - `num_gpus_per_trainer_worker = 2`) - num_cpus_per_trainer_worker: Number of CPUs allocated per trainer worker. - Only necessary for custom processing pipeline inside each RLTrainer - requiring multiple CPU cores. Ignored if `num_trainer_workers = 0`. - num_gpus_per_trainer_worker: Number of GPUs allocated per worker. If - `num_trainer_workers = 0`, any value greater than 0 will run the + `num_gpus_per_learner_worker` accordingly (e.g. 
4 GPUs total, and model + needs 2 GPUs: `num_learner_workers = 2` and + `num_gpus_per_learner_worker = 2`) + num_cpus_per_learner_worker: Number of CPUs allocated per trainer worker. + Only necessary for custom processing pipeline inside each Learner + requiring multiple CPU cores. Ignored if `num_learner_workers = 0`. + num_gpus_per_learner_worker: Number of GPUs allocated per worker. If + `num_learner_workers = 0`, any value greater than 0 will run the training on a single GPU on the head node, while a value of 0 will run the training on head node CPU cores. local_gpu_idx: if num_gpus_per_worker > 0, and num_workers<2, then this gpu @@ -1054,12 +1055,12 @@ def resources( if placement_strategy is not NotProvided: self.placement_strategy = placement_strategy - if num_trainer_workers is not NotProvided: - self.num_trainer_workers = num_trainer_workers - if num_cpus_per_trainer_worker is not NotProvided: - self.num_cpus_per_trainer_worker = num_cpus_per_trainer_worker - if num_gpus_per_trainer_worker is not NotProvided: - self.num_gpus_per_trainer_worker = num_gpus_per_trainer_worker + if num_learner_workers is not NotProvided: + self.num_learner_workers = num_learner_workers + if num_cpus_per_learner_worker is not NotProvided: + self.num_cpus_per_learner_worker = num_cpus_per_learner_worker + if num_gpus_per_learner_worker is not NotProvided: + self.num_gpus_per_learner_worker = num_gpus_per_learner_worker if local_gpu_idx is not NotProvided: self.local_gpu_idx = local_gpu_idx @@ -1450,8 +1451,8 @@ def training( model: Optional[dict] = NotProvided, optimizer: Optional[dict] = NotProvided, max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided, - _enable_rl_trainer_api: Optional[bool] = NotProvided, - rl_trainer_class: Optional[Type["RLTrainer"]] = NotProvided, + _enable_learner_api: Optional[bool] = NotProvided, + learner_class: Optional[Type["Learner"]] = NotProvided, ) -> "AlgorithmConfig": """Sets the training related configuration. @@ -1475,7 +1476,7 @@ def training( dashboard. If you're seeing that the object store is filling up, turn down the number of remote requests in flight, or enable compression in your experiment of timesteps. - _enable_rl_trainer_api: Whether to enable the TrainerRunner and RLTrainer + _enable_learner_api: Whether to enable the LearnerGroup and Learner for training. This API uses ray.train to run the training loop which allows for a more flexible distributed training. @@ -1518,10 +1519,10 @@ def training( self.max_requests_in_flight_per_sampler_worker = ( max_requests_in_flight_per_sampler_worker ) - if _enable_rl_trainer_api is not NotProvided: - self._enable_rl_trainer_api = _enable_rl_trainer_api - if rl_trainer_class is not NotProvided: - self.rl_trainer_class = rl_trainer_class + if _enable_learner_api is not NotProvided: + self._enable_learner_api = _enable_learner_api + if learner_class is not NotProvided: + self.learner_class = learner_class return self @@ -2662,15 +2663,15 @@ def get_default_rl_module_spec(self) -> ModuleSpec: """ raise NotImplementedError - def get_default_rl_trainer_class(self) -> Union[Type["RLTrainer"], str]: - """Returns the RLTrainer class to use for this algorithm. + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + """Returns the Learner class to use for this algorithm. - Override this method in the sub-class to return the RLTrainer class type given + Override this method in the sub-class to return the Learner class type given the input framework. 
Returns: - The RLTrainer class to use for this algorithm either as a class type or as - a string (e.g. ray.rllib.core.rl_trainer.testing.torch.BCTrainer). + The Learner class to use for this algorithm either as a class type or as + a string (e.g. ray.rllib.core.learner.testing.torch.BCTrainer). """ raise NotImplementedError @@ -2739,27 +2740,27 @@ def get_marl_module_spec( return marl_module_spec - def get_trainer_runner_config(self, module_spec: ModuleSpec) -> TrainerRunnerConfig: + def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig: if not self._is_frozen: raise ValueError( - "Cannot call `get_trainer_runner_config()` on an unfrozen " + "Cannot call `get_learner_group_config()` on an unfrozen " "AlgorithmConfig! Please call `freeze()` first." ) config = ( - TrainerRunnerConfig() + LearnerGroupConfig() .module(module_spec) - .trainer( - trainer_class=self.rl_trainer_class, + .learner( + learner_class=self.learner_class, # TODO (Kourosh): optimizer config can now be more complicated. optimizer_config={"lr": self.lr}, - rl_trainer_hps=self.rl_trainer_hps, + learner_hps=self.learner_hps, ) .resources( - num_trainer_workers=self.num_trainer_workers, - num_cpus_per_trainer_worker=self.num_cpus_per_trainer_worker, - num_gpus_per_trainer_worker=self.num_gpus_per_trainer_worker, + num_learner_workers=self.num_learner_workers, + num_cpus_per_learner_worker=self.num_cpus_per_learner_worker, + num_gpus_per_learner_worker=self.num_gpus_per_learner_worker, local_gpu_idx=self.local_gpu_idx, ) .framework(eager_tracing=self.eager_tracing) diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 0fcf54c11327..0fdcf5b5c587 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -16,7 +16,7 @@ from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.pg import PGConfig -from ray.rllib.algorithms.ppo.ppo_rl_trainer_config import PPORLTrainerHPs +from ray.rllib.algorithms.ppo.ppo_learner_config import PPOLearnerHPs from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec from ray.rllib.execution.rollout_ops import ( standardize_fields, @@ -44,7 +44,7 @@ ) if TYPE_CHECKING: - from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer + from ray.rllib.core.learner.learner import Learner logger = logging.getLogger(__name__) @@ -93,7 +93,7 @@ def __init__(self, algo_class=None): # fmt: off # __sphinx_doc_begin__ # PPO specific settings: - self._rl_trainer_hps = PPORLTrainerHPs() + self._learner_hps = PPOLearnerHPs() self.use_critic = True self.use_gae = True self.lambda_ = 1.0 @@ -137,13 +137,13 @@ def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: raise ValueError(f"The framework {self.framework_str} is not supported.") @override(AlgorithmConfig) - def get_default_rl_trainer_class(self) -> Union[Type["RLTrainer"], str]: + def get_default_learner_class(self) -> Union[Type["Learner"], str]: if self.framework_str == "torch": - from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_trainer import ( - PPOTorchRLTrainer, + from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import ( + PPOTorchLearner, ) - return PPOTorchRLTrainer + return PPOTorchLearner else: raise ValueError(f"The framework {self.framework_str} is not supported.") @@ -217,16 +217,16 @@ def training( self.lr_schedule = lr_schedule if use_critic is not NotProvided: self.use_critic = use_critic - # TODO (Kourosh) This is experimental. 
Set rl_trainer_hps parameters as + # TODO (Kourosh) This is experimental. Set learner_hps parameters as # well. Don't forget to remove .use_critic from algorithm config. - self._rl_trainer_hps.use_critic = use_critic + self._learner_hps.use_critic = use_critic if use_gae is not NotProvided: self.use_gae = use_gae if lambda_ is not NotProvided: self.lambda_ = lambda_ if kl_coeff is not NotProvided: self.kl_coeff = kl_coeff - self._rl_trainer_hps.kl_coeff = kl_coeff + self._learner_hps.kl_coeff = kl_coeff if sgd_minibatch_size is not NotProvided: self.sgd_minibatch_size = sgd_minibatch_size if num_sgd_iter is not NotProvided: @@ -235,24 +235,24 @@ def training( self.shuffle_sequences = shuffle_sequences if vf_loss_coeff is not NotProvided: self.vf_loss_coeff = vf_loss_coeff - self._rl_trainer_hps.vf_loss_coeff = vf_loss_coeff + self._learner_hps.vf_loss_coeff = vf_loss_coeff if entropy_coeff is not NotProvided: self.entropy_coeff = entropy_coeff - self._rl_trainer_hps.entropy_coeff = entropy_coeff + self._learner_hps.entropy_coeff = entropy_coeff if entropy_coeff_schedule is not NotProvided: self.entropy_coeff_schedule = entropy_coeff_schedule - self._rl_trainer_hps.entropy_coeff_schedule = entropy_coeff_schedule + self._learner_hps.entropy_coeff_schedule = entropy_coeff_schedule if clip_param is not NotProvided: self.clip_param = clip_param - self._rl_trainer_hps.clip_param = clip_param + self._learner_hps.clip_param = clip_param if vf_clip_param is not NotProvided: self.vf_clip_param = vf_clip_param - self._rl_trainer_hps.vf_clip_param = vf_clip_param + self._learner_hps.vf_clip_param = vf_clip_param if grad_clip is not NotProvided: self.grad_clip = grad_clip if kl_target is not NotProvided: self.kl_target = kl_target - self._rl_trainer_hps.kl_target = kl_target + self._learner_hps.kl_target = kl_target return self @@ -393,7 +393,7 @@ def training_step(self) -> ResultDict: # Standardize advantages train_batch = standardize_fields(train_batch, ["advantages"]) # Train - if self.config._enable_rl_trainer_api: + if self.config._enable_learner_api: # TODO (Kourosh) Clearly define what train_batch_size # vs. sgd_minibatch_size and num_sgd_iter is in the config. # TODO (Kourosh) Do this inside the RL Trainer so @@ -401,7 +401,7 @@ def training_step(self) -> ResultDict: # communication between driver and the remote # trainer workers - train_results = self.trainer_runner.update( + train_results = self.learner_group.update( train_batch, minibatch_size=self.config.sgd_minibatch_size, num_iters=self.config.num_sgd_iter, @@ -412,7 +412,7 @@ def training_step(self) -> ResultDict: else: train_results = multi_gpu_train_one_step(self, train_batch) - if self.config._enable_rl_trainer_api: + if self.config._enable_learner_api: # the train results's loss keys are pids to their loss values. But we also # return a total_loss key at the same level as the pid keys. So we need to # subtract that to get the total set of pids to update. 
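For orientation on the renamed flags exercised in this hunk, enabling the experimental Learner stack on PPO looks roughly like the sketch below. This is a sketch only; the environment choice and resource numbers are arbitrary, and the only claim taken from this change is that `validate()` resolves the framework-specific default learner class.

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .framework("torch")
        .rollouts(enable_connectors=True)
        # Opt in to the new LearnerGroup/Learner training path.
        .training(_enable_learner_api=True)
        # Renamed resource knobs: 0 learner workers means training runs on the
        # local worker; raise these for multi-GPU training.
        .resources(
            num_learner_workers=0,
            num_gpus_per_learner_worker=0,
            num_cpus_per_learner_worker=1,
        )
    )
    config.validate()
    # After validation, the default learner class for the chosen framework is set,
    # e.g. PPOTorchLearner for torch.
    print(config.learner_class)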
@@ -440,19 +440,19 @@ def training_step(self) -> ResultDict: with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: if self.workers.num_remote_workers() > 0: from_worker_or_trainer = None - if self.config._enable_rl_trainer_api: - # sync weights from trainer_runner to all rollout workers - from_worker_or_trainer = self.trainer_runner + if self.config._enable_learner_api: + # sync weights from learner_group to all rollout workers + from_worker_or_trainer = self.learner_group self.workers.sync_weights( from_worker_or_trainer=from_worker_or_trainer, policies=policies_to_update, global_vars=global_vars, ) - elif self.config._enable_rl_trainer_api: - weights = self.trainer_runner.get_weights() + elif self.config._enable_learner_api: + weights = self.learner_group.get_weights() self.workers.local_worker().set_weights(weights) - if self.config._enable_rl_trainer_api: + if self.config._enable_learner_api: kl_dict = { # TODO (Kourosh): Train results don't match the old format. The thing # that used to be under `kl` is now under `mean_kl_loss`. Fix this. Do @@ -461,7 +461,7 @@ def training_step(self) -> ResultDict: for pid in policies_to_update } # triggers a special update method on RLOptimizer to update the KL values. - self.trainer_runner.additional_update( + self.learner_group.additional_update( sampled_kl_values=kl_dict, timestep=self._counters[NUM_AGENT_STEPS_SAMPLED], ) diff --git a/rllib/algorithms/ppo/ppo_rl_trainer_config.py b/rllib/algorithms/ppo/ppo_learner_config.py similarity index 81% rename from rllib/algorithms/ppo/ppo_rl_trainer_config.py rename to rllib/algorithms/ppo/ppo_learner_config.py index 2f616ca45787..e6850efa6b6a 100644 --- a/rllib/algorithms/ppo/ppo_rl_trainer_config.py +++ b/rllib/algorithms/ppo/ppo_learner_config.py @@ -1,11 +1,11 @@ from dataclasses import dataclass from typing import List, Optional, Union -from ray.rllib.core.rl_trainer.rl_trainer import RLTrainerHPs +from ray.rllib.core.learner.learner import LearnerHPs @dataclass -class PPORLTrainerHPs(RLTrainerHPs): +class PPOLearnerHPs(LearnerHPs): """Hyperparameters for the PPO RL Trainer""" kl_coeff: float = 0.2 @@ -16,6 +16,6 @@ class PPORLTrainerHPs(RLTrainerHPs): entropy_coeff: float = 0.0 vf_loss_coeff: float = 1.0 - # experimental placeholder for things that could be part of the base RLTrainerHPs + # experimental placeholder for things that could be part of the base LearnerHPs lr_schedule: Optional[List[List[Union[int, float]]]] = None entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_trainer.py b/rllib/algorithms/ppo/tests/test_ppo_learner.py similarity index 87% rename from rllib/algorithms/ppo/tests/test_ppo_rl_trainer.py rename to rllib/algorithms/ppo/tests/test_ppo_learner.py index c0914f57f5fe..e4a6069a9676 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_trainer.py +++ b/rllib/algorithms/ppo/tests/test_ppo_learner.py @@ -84,11 +84,11 @@ def test_loss(self): policy_loss = policy.loss(policy.model, policy.dist_class, train_batch) - config.training(_enable_rl_trainer_api=True) + config.training(_enable_learner_api=True) config.validate() config.freeze() - trainer_runner_config = config.get_trainer_runner_config( + learner_group_config = config.get_learner_group_config( SingleAgentRLModuleSpec( module_class=config.rl_module_spec.module_class, observation_space=policy.observation_space, @@ -96,17 +96,17 @@ def test_loss(self): model_config=policy.config["model"], ) ) - trainer_runner = trainer_runner_config.build() + learner_group = 
learner_group_config.build() - # load the policy weights into the trainer runner + # load the policy weights into the learner_group state_dict = {"module_state": {"default_policy": policy.get_weights()}} state_dict = convert_to_torch_tensor(state_dict) - trainer_runner.set_state(state_dict) - results = trainer_runner.update(train_batch.as_multi_agent()) + learner_group.set_state(state_dict) + results = learner_group.update(train_batch.as_multi_agent()) - trainer_runner_loss = results["loss"]["total_loss"] + learner_group_loss = results["loss"]["total_loss"] - check(trainer_runner_loss, policy_loss) + check(learner_group_loss, policy_loss) if __name__ == "__main__": diff --git a/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py b/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py index b320165bfe7e..f9b55a56b7d1 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py @@ -52,7 +52,7 @@ class PPOTfPolicyWithRLModule( # TODO: In the future we will deprecate doing all phases of training, exploration, # and inference via one policy abstraction. Instead, we will use separate # abstractions for each phase. For training (i.e. gradient updates, given the - # sample that have been collected) we will use RLTrainer which will own one or + # sample that have been collected) we will use Learner which will own one or # possibly many RLModules, and RLOptimizer. For exploration, we will use RLSampler # which will own RLModule, and RLTrajectoryProcessor. The exploration and inference # phase details are TBD but the whole point is to make rllib extremely modular. diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_trainer.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py similarity index 94% rename from rllib/algorithms/ppo/torch/ppo_torch_rl_trainer.py rename to rllib/algorithms/ppo/torch/ppo_torch_learner.py index 9b645b5ff6c2..a66fd21ad6b4 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_trainer.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -1,7 +1,7 @@ import logging from typing import Mapping, Any -from ray.rllib.core.rl_trainer.torch.torch_rl_trainer import TorchRLTrainer +from ray.rllib.core.learner.torch.torch_learner import TorchLearner from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch @@ -16,8 +16,8 @@ logger = logging.getLogger(__name__) -class PPOTorchRLTrainer(TorchRLTrainer): - """Implements PPO loss / update logic on top of TorchRLTrainer. +class PPOTorchLearner(TorchLearner): + """Implements PPO loss / update logic on top of TorchLearner. This class implements the ppo loss under `_compute_loss_per_module()` and the additional non-gradient based updates such as KL-coeff and learning rate updates @@ -30,13 +30,13 @@ def __init__(self, *args, **kwargs): # TODO (Kourosh): Move these failures to config.validate() or support them. self.entropy_coeff_scheduler = None if self.hps.entropy_coeff_schedule: - raise ValueError("entropy_coeff_schedule is not supported in RLTrainer yet") + raise ValueError("entropy_coeff_schedule is not supported in Learner yet") # TODO (Kourosh): Create a way on the base class for users to define arbitrary # schedulers for learning rates. self.lr_scheduler = None if self.hps.lr_schedule: - raise ValueError("lr_schedule is not supported in RLTrainer yet") + raise ValueError("lr_schedule is not supported in Learner yet") # TODO (Kourosh): We can still use mix-ins in the new design. 
Do we want that? # Most likely not. I rather be specific about everything. kl_coeff is a @@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs): self.kl_coeff = self.hps.kl_coeff self.kl_target = self.hps.kl_target - @override(TorchRLTrainer) + @override(TorchLearner) def compute_loss_per_module( self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType] ) -> TensorType: @@ -125,7 +125,7 @@ def compute_loss_per_module( "mean_kl_loss": mean_kl_loss, } - @override(TorchRLTrainer) + @override(TorchLearner) def additional_update_per_module( self, module_id: str, sampled_kl_values: dict, timestep: int ) -> Mapping[str, Any]: diff --git a/rllib/algorithms/ppo/torch/ppo_torch_policy_rlm.py b/rllib/algorithms/ppo/torch/ppo_torch_policy_rlm.py index 84411d951652..4405ebb4b87d 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_policy_rlm.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_policy_rlm.py @@ -54,7 +54,7 @@ class PPOTorchPolicyWithRLModule( # TODO: In the future we will deprecate doing all phases of training, exploration, # and inference via one policy abstraction. Instead, we will use separate # abstractions for each phase. For training (i.e. gradient updates, given the - # sample that have been collected) we will use RLTrainer which will own one or + # sample that have been collected) we will use Learner which will own one or # possibly many RLModules, and RLOptimizer. For exploration, we will use RLSampler # which will own RLModule, and RLTrajectoryProcessor. The exploration and inference # phase details are TBD but the whole point is to make rllib extremely modular. diff --git a/rllib/algorithms/qmix/qmix.py b/rllib/algorithms/qmix/qmix.py index 135663e2d9e6..79f02332b8ce 100644 --- a/rllib/algorithms/qmix/qmix.py +++ b/rllib/algorithms/qmix/qmix.py @@ -83,7 +83,7 @@ def __init__(self): # QMix-torch overrides the TorchPolicy's learn_on_batch w/o specifying a # alternative `learn_on_loaded_batch` alternative for the GPU. # TODO: This hack will be resolved once we move all algorithms to the new - # RLModule/RLTrainer APIs. + # RLModule/Learner APIs. self.simple_optimizer = True # Override some of AlgorithmConfig's default values with QMix-specific values. diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index 9e90656cfaa7..03354d640713 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -167,8 +167,8 @@ class A: config.validate() self.assertEqual(config.rl_module_spec.module_class, A) - def test_rl_trainer_api(self): - # TODO (Kourosh): the default rl_trainer of PPO is not implemented yet. When + def test_learner_api(self): + # TODO (Kourosh): the default learner of PPO is not implemented yet. 
When # that's done this test should be updated class A: pass @@ -177,11 +177,11 @@ class A: PPOConfig() .environment("CartPole-v1") .rollouts(enable_connectors=True) - .training(rl_trainer_class=A, _enable_rl_trainer_api=True) + .training(learner_class=A, _enable_learner_api=True) ) config.validate() - self.assertEqual(config.rl_trainer_class, A) + self.assertEqual(config.learner_class, A) if __name__ == "__main__": diff --git a/rllib/core/learner/__init__.py b/rllib/core/learner/__init__.py new file mode 100644 index 000000000000..4fe8d1f67dec --- /dev/null +++ b/rllib/core/learner/__init__.py @@ -0,0 +1,10 @@ +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.learner.learner_group import LearnerGroup +from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig + + +__all__ = [ + "Learner", + "LearnerGroup", + "LearnerGroupConfig", +] diff --git a/rllib/core/rl_trainer/rl_trainer.py b/rllib/core/learner/learner.py similarity index 91% rename from rllib/core/rl_trainer/rl_trainer.py rename to rllib/core/learner/learner.py index 87a8269d70c3..42abf794d8d4 100644 --- a/rllib/core/rl_trainer/rl_trainer.py +++ b/rllib/core/learner/learner.py @@ -36,8 +36,8 @@ MiniBatchDummyIterator, MiniBatchCyclicIterator, ) -from ray.rllib.core.rl_trainer.scaling_config import TrainerScalingConfig -from ray.rllib.core.rl_trainer.reduce_result_dict_fn import _reduce_mean_results +from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig +from ray.rllib.core.learner.reduce_result_dict_fn import _reduce_mean_results from ray.rllib.utils.annotations import ( OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, @@ -71,10 +71,10 @@ class FrameworkHPs: @dataclass -class RLTrainerHPs: - """The hyper-parameters for RLTrainer. +class LearnerHPs: + """The hyper-parameters for Learner. - When creating a new RLTrainer, the new hyper-parameters have to be defined by + When creating a new Learner, the new hyper-parameters have to be defined by subclassing this class and adding the new hyper-parameters as fields. # TODO (Kourosh): The things that could be part of the base class: @@ -85,7 +85,7 @@ class RLTrainerHPs: pass -class RLTrainer: +class Learner: """Base class for learners. This class will be used to train RLModules. It is responsible for defining the loss @@ -112,17 +112,17 @@ class RLTrainer: optimizer_config: The deep learning gradient optimizer configuration to be used. For example lr=0.0001, momentum=0.9, etc. scaling_config: Configuration for scaling the learner actors. - Refer to ray.rllib.core.rl_trainer.scaling_config.TrainerScalingConfig + Refer to ray.rllib.core.learner.scaling_config.LearnerGroupScalingConfig for more info. - trainer_hyperparameters: The hyper-parameters for the Learner. + learner_hyperparameters: The hyper-parameters for the Learner. Algorithm specific learner hyper-parameters will passed in via this argument. For example in PPO the `vf_loss_coeff` hyper-parameter will be passed in via this argument. Refer to - ray.rllib.core.rl_trainer.rl_trainer.RLTrainerHPs for more info. + ray.rllib.core.learner.learner.LearnerHPs for more info. framework_hps: The framework specific hyper-parameters. This will be used to pass in any framework specific hyper-parameter that will impact the module creation. For example eager_tracing in TF or compile in Torch. - Refer to ray.rllib.core.rl_trainer.rl_trainer.FrameworkHPs for more info. + Refer to ray.rllib.core.learner.learner.FrameworkHPs for more info. 
Usage pattern: @@ -166,10 +166,10 @@ class RLTrainer: # will train previous modules only. results = learner.update(batch) - # get the state of the trainer + # get the state of the learner state = learner.get_state() - # set the state of the trainer + # set the state of the learner learner.set_state(state) # get the weights of the underly multi-agent RLModule @@ -203,31 +203,31 @@ def __init__( ] = None, module: Optional[RLModule] = None, optimizer_config: Mapping[str, Any] = None, - trainer_scaling_config: TrainerScalingConfig = TrainerScalingConfig(), - trainer_hyperparameters: Optional[RLTrainerHPs] = RLTrainerHPs(), + learner_scaling_config: LearnerGroupScalingConfig = LearnerGroupScalingConfig(), + learner_hyperparameters: Optional[LearnerHPs] = LearnerHPs(), framework_hyperparameters: Optional[FrameworkHPs] = FrameworkHPs(), ): # TODO (Kourosh): convert optimizer configs to dataclasses if module_spec is not None and module is not None: raise ValueError( - "Only one of module spec or module can be provided to RLTrainer." + "Only one of module spec or module can be provided to Learner." ) if module_spec is None and module is None: raise ValueError( - "Either module_spec or module should be provided to RLTrainer." + "Either module_spec or module should be provided to Learner." ) self._module_spec = module_spec self._module_obj = module self._optimizer_config = optimizer_config - self._hps = trainer_hyperparameters + self._hps = learner_hyperparameters - # pick the configs that we need for the trainer from scaling config - self._distributed = trainer_scaling_config.num_workers > 1 - self._use_gpu = trainer_scaling_config.num_gpus_per_worker > 0 + # pick the configs that we need for the learner from scaling config + self._distributed = learner_scaling_config.num_workers > 1 + self._use_gpu = learner_scaling_config.num_gpus_per_worker > 0 # if we are using gpu but we are not distributed, use this gpu for training - self._local_gpu_idx = trainer_scaling_config.local_gpu_idx + self._local_gpu_idx = learner_scaling_config.local_gpu_idx # These are the attributes that are set during build self._module: MultiAgentRLModule = None @@ -247,8 +247,8 @@ def module(self) -> MultiAgentRLModule: return self._module @property - def hps(self) -> RLTrainerHPs: - """The hyper-parameters for the trainer.""" + def hps(self) -> LearnerHPs: + """The hyper-parameters for the learner.""" return self._hps @abc.abstractmethod @@ -328,7 +328,7 @@ def get_param_ref(self, param: ParamType) -> Hashable: def get_parameters(self, module: RLModule) -> Sequence[ParamType]: """Returns the list of parameters of a module. - This should be overriden in framework specific trainer. For example in torch it + This should be overriden in framework specific learner. For example in torch it will return .parameters(), while in tf it returns .trainable_variables. Args: @@ -695,7 +695,7 @@ def update( return reduce_fn(results) def set_state(self, state: Mapping[str, Any]) -> None: - """Set the state of the trainer. + """Set the state of the learner. Args: state: The state of the optimizer and module. Can be obtained @@ -709,7 +709,7 @@ def set_state(self, state: Mapping[str, Any]) -> None: self._module.set_state(state.get("module_state", {})) def get_state(self) -> Mapping[str, Any]: - """Get the state of the trainer. + """Get the state of the learner. Returns: The state of the optimizer and module. 
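To make the state-handling docstrings above concrete, here is a rough sketch of the `get_state()` / `set_state()` round-trip on a single Learner, using the BC testing classes that this change renames. The observation and action spaces are stand-ins, and the constructor arguments simply mirror the fields of `LearnerSpec`.

    from gymnasium.spaces import Box, Discrete

    from ray.rllib.core.learner.learner import FrameworkHPs
    from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
    from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
    from ray.rllib.core.testing.tf.bc_learner import BCTfLearner
    from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule

    learner = BCTfLearner(
        module_spec=SingleAgentRLModuleSpec(
            module_class=DiscreteBCTFModule,
            observation_space=Box(-1.0, 1.0, (4,)),
            action_space=Discrete(2),
            model_config={"fcnet_hiddens": [32]},
        ),
        optimizer_config={"lr": 1e-3},
        learner_scaling_config=LearnerGroupScalingConfig(),
        framework_hyperparameters=FrameworkHPs(eager_tracing=True),
    )
    learner.build()  # must be called before any other method

    state = learner.get_state()  # currently {"module_state": <module weights>}
    learner.set_state(state)     # restores the module (and, eventually, optimizer) state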
@@ -720,7 +720,7 @@ def get_state(self) -> Mapping[str, Any]: return {"module_state": self._module.get_state()} def _make_module(self) -> MultiAgentRLModule: - """Construct the multi-agent RL module for the trainer. + """Construct the multi-agent RL module for the learner. This method uses `self._module_specs` or `self._module_obj` to construct the module. If the module_class is a single agent RL module it will be wrapped to a @@ -759,8 +759,8 @@ def _update( def __check_if_build_called(self): if self._module is None: raise ValueError( - "RLTrainer.build() must be called after constructing a " - "RLTrainer and before calling any methods on it." + "Learner.build() must be called after constructing a " + "Learner and before calling any methods on it." ) def apply(self, func, *_args, **_kwargs): @@ -768,43 +768,43 @@ def apply(self, func, *_args, **_kwargs): @dataclass -class RLTrainerSpec: - """The spec for constructing RLTrainer actors. +class LearnerSpec: + """The spec for constructing Learner actors. Args: - rl_trainer_class: The RLTrainer class to use. + learner_class: The Learner class to use. module_spec: The underlying (MA)RLModule spec to completely define the module. module: Alternatively the RLModule instance can be passed in directly. This - only works if the RLTrainer is not an actor. + only works if the Learner is not an actor. backend_config: The backend config for properly distributing the RLModule. optimizer_config: The optimizer setting to apply during training. - trainer_hyperparameters: The extra config for the loss/additional update. This - should be a subclass of RLTrainerHPs. This is useful for passing in + learner_hyperparameters: The extra config for the loss/additional update. This + should be a subclass of LearnerHPs. This is useful for passing in algorithm configs that contains the hyper-parameters for loss computation, change of training behaviors, etc. e.g lr, entropy_coeff. 
""" - rl_trainer_class: Type["RLTrainer"] + learner_class: Type["Learner"] module_spec: Union["SingleAgentRLModuleSpec", "MultiAgentRLModuleSpec"] = None module: Optional["RLModule"] = None - trainer_scaling_config: TrainerScalingConfig = field( - default_factory=TrainerScalingConfig + learner_scaling_config: LearnerGroupScalingConfig = field( + default_factory=LearnerGroupScalingConfig ) optimizer_config: Dict[str, Any] = field(default_factory=dict) - trainer_hyperparameters: RLTrainerHPs = field(default_factory=RLTrainerHPs) + learner_hyperparameters: LearnerHPs = field(default_factory=LearnerHPs) framework_hyperparameters: FrameworkHPs = field(default_factory=FrameworkHPs) def get_params_dict(self) -> Dict[str, Any]: - """Returns the parameters than be passed to the RLTrainer constructor.""" + """Returns the parameters than be passed to the Learner constructor.""" return { "module": self.module, "module_spec": self.module_spec, - "trainer_scaling_config": self.trainer_scaling_config, + "learner_scaling_config": self.learner_scaling_config, "optimizer_config": self.optimizer_config, - "trainer_hyperparameters": self.trainer_hyperparameters, + "learner_hyperparameters": self.learner_hyperparameters, "framework_hyperparameters": self.framework_hyperparameters, } - def build(self) -> "RLTrainer": - """Builds the RLTrainer instance.""" - return self.rl_trainer_class(**self.get_params_dict()) + def build(self) -> "Learner": + """Builds the Learner instance.""" + return self.learner_class(**self.get_params_dict()) diff --git a/rllib/core/rl_trainer/trainer_runner.py b/rllib/core/learner/learner_group.py similarity index 83% rename from rllib/core/rl_trainer/trainer_runner.py rename to rllib/core/learner/learner_group.py index 6f2876a79a5b..f95b05bab341 100644 --- a/rllib/core/rl_trainer/trainer_runner.py +++ b/rllib/core/learner/learner_group.py @@ -3,14 +3,14 @@ import ray -from ray.rllib.core.rl_trainer.reduce_result_dict_fn import _reduce_mean_results +from ray.rllib.core.learner.reduce_result_dict_fn import _reduce_mean_results from ray.rllib.core.rl_module.rl_module import ( RLModule, ModuleID, SingleAgentRLModuleSpec, ) -from ray.rllib.core.rl_trainer.rl_trainer import ( - RLTrainerSpec, +from ray.rllib.core.learner.learner import ( + LearnerSpec, ParamOptimizerPairs, Optimizer, ) @@ -22,15 +22,15 @@ from ray.train._internal.backend_executor import BackendExecutor if TYPE_CHECKING: - from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer + from ray.rllib.core.learner.learner import Learner -def _get_backend_config(rl_trainer_class: Type["RLTrainer"]) -> str: - if rl_trainer_class.framework == "torch": +def _get_backend_config(learner_class: Type["Learner"]) -> str: + if learner_class.framework == "torch": from ray.train.torch import TorchConfig backend_config = TorchConfig() - elif rl_trainer_class.framework == "tf": + elif learner_class.framework == "tf": from ray.train.tensorflow import TensorflowConfig backend_config = TensorflowConfig() @@ -40,23 +40,23 @@ def _get_backend_config(rl_trainer_class: Type["RLTrainer"]) -> str: return backend_config -class TrainerRunner: - """Coordinator of RLTrainers. +class LearnerGroup: + """Coordinator of Learners. Public API: .update(batch) -> updates the RLModule based on gradient descent algos. .additional_update() -> any additional non-gradient based updates will get called from this entry point. .get_state() -> returns the state of the RLModule and RLOptimizer from - all of the RLTrainers. 
- .set_state() -> sets the state of all the RLTrainers. - .get_weights() -> returns the weights of the RLModule from the RLTrainer(s). - .set_weights() -> sets the weights of the RLModule in the RLTrainer(s). + all of the Learners. + .set_state() -> sets the state of all the Learners. + .get_weights() -> returns the weights of the RLModule from the Learner(s). + .set_weights() -> sets the weights of the RLModule in the Learner(s). .add_module() -> add a new RLModule to the MultiAgentRLModule being trained by - this TrainerRunner. + this LearnerGroup. .remove_module() -> remove an RLModule from the MultiAgentRLModule being trained - by this TrainerRunner. + by this LearnerGroup. Args: - rl_trainer_spec: The specification for constructing RLTrainers. + learner_spec: The specification for constructing Learners. max_queue_len: The maximum number of batches to queue up if doing non-blocking updates (e.g. `self.update(batch, block=False)`). If the queue is full it will evict the oldest batch first. @@ -64,16 +64,16 @@ class TrainerRunner: def __init__( self, - rl_trainer_spec: RLTrainerSpec, + learner_spec: LearnerSpec, max_queue_len: int = 20, ): - scaling_config = rl_trainer_spec.trainer_scaling_config - rl_trainer_class = rl_trainer_spec.rl_trainer_class + scaling_config = learner_spec.learner_scaling_config + learner_class = learner_spec.learner_class # TODO (Kourosh): Go with a _remote flag instead of _is_local to be more # explicit self._is_local = scaling_config.num_workers == 0 - self._trainer = None + self._learner = None self._workers = None # if a user calls self.shutdown() on their own then this flag is set to true. # When del is called the backend executor isn't shutdown twice if this flag is @@ -82,11 +82,11 @@ def __init__( self._is_shut_down = False if self._is_local: - self._trainer = rl_trainer_class(**rl_trainer_spec.get_params_dict()) - self._trainer.build() + self._learner = learner_class(**learner_spec.get_params_dict()) + self._learner.build() self._worker_manager = None else: - backend_config = _get_backend_config(rl_trainer_class) + backend_config = _get_backend_config(learner_class) backend_executor = BackendExecutor( backend_config=backend_config, num_workers=scaling_config.num_workers, @@ -96,8 +96,8 @@ def __init__( ) backend_executor.start( - train_cls=rl_trainer_class, - train_cls_kwargs=rl_trainer_spec.get_params_dict(), + train_cls=learner_class, + train_cls_kwargs=learner_spec.get_params_dict(), ) self._backend_executor = backend_executor @@ -126,14 +126,14 @@ def update( reduce_fn: Callable[[ResultDict], ResultDict] = _reduce_mean_results, block: bool = True, ) -> List[Mapping[str, Any]]: - """Do one gradient based update to the RLTrainer(s). + """Do one gradient based update to the Learner(s). Args: batch: The data to use for the update. minibatch_size: The minibatch size to use for the update. num_iters: The number of complete passes over all the sub-batches in the input multi-agent batch. - reduce_fn: A function to reduce the results from a list of RLTrainer Actors + reduce_fn: A function to reduce the results from a list of Learner Actors into a single result. This can be any arbitrary function that takes a list of dictionaries and returns a single dictionary. For example you can either take an average (default) or concatenate the results (for @@ -143,7 +143,7 @@ def update( block: Whether to block until the update is complete. 
Returns: - A list of dictionaries of results from the updates from the RLTrainer(s) + A list of dictionaries of results from the updates from the Learner(s) """ if self.is_local: if not block: @@ -152,7 +152,7 @@ def update( "mode with num_workers=0." ) results = [ - self._trainer.update( + self._learner.update( batch, minibatch_size=minibatch_size, num_iters=num_iters, @@ -182,9 +182,9 @@ def _distributed_update( reduce_fn: Callable[[ResultDict], ResultDict] = _reduce_mean_results, block: bool = True, ) -> List[Mapping[str, Any]]: - """Do a gradient based update to the RLTrainers using DDP training. + """Do a gradient based update to the Learners using DDP training. - Note: this function is used if the num_gpus this TrainerRunner is configured + Note: this function is used if the num_gpus this LearnerGroup is configured with is > 0. If _fake_gpus is True then this function will still be used for distributed training, but the workers will be configured to use a different backend than the cuda backend. @@ -193,7 +193,7 @@ def _distributed_update( See `.update()` docstring. Returns: - A list of dictionaries of results from the updates from the RLTrainer(s) + A list of dictionaries of results from the updates from the Learner(s) """ if block: @@ -247,23 +247,23 @@ def additional_update( reduce_fn: Optional[Callable[[ResultDict], ResultDict]] = _reduce_mean_results, **kwargs, ) -> List[Mapping[str, Any]]: - """Apply additional non-gradient based updates to the RLTrainers. + """Apply additional non-gradient based updates to the Learners. For example, this could be used to do a polyak averaging update of a target network in off policy algorithms like SAC or DQN. - By default this is a pass through that calls `RLTrainer.additional_update` + By default this is a pass through that calls `Learner.additional_update` Args: reduce_fn: See `update()` documentation for more details. - **kwargs: Keyword arguments to pass to each RLTrainer. + **kwargs: Keyword arguments to pass to each Learner. Returns: A list of dictionaries of results from the updates from each worker. """ if self.is_local: - results = [self._trainer.additional_update(**kwargs)] + results = [self._learner.additional_update(**kwargs)] else: results = self._worker_manager.foreach_actor( [lambda w: w.additional_update(**kwargs) for worker in self._workers] @@ -281,7 +281,7 @@ def add_module( set_optimizer_fn: Optional[Callable[[RLModule], ParamOptimizerPairs]] = None, optimizer_cls: Optional[Type[Optimizer]] = None, ) -> None: - """Add a module to the RLTrainers maintained by this TrainerRunner. + """Add a module to the Learners maintained by this LearnerGroup. Args: module_id: The id of the module to add. @@ -295,7 +295,7 @@ def add_module( should be provided. """ if self.is_local: - self._trainer.add_module( + self._learner.add_module( module_id=module_id, module_spec=module_spec, set_optimizer_fn=set_optimizer_fn, @@ -313,14 +313,14 @@ def add_module( return self._get_results(results) def remove_module(self, module_id: ModuleID) -> None: - """Remove a module from the RLTrainers maintained by this TrainerRunner. + """Remove a module from the Learners maintained by this LearnerGroup. Args: module_id: The id of the module to remove. 
""" if self.is_local: - self._trainer.remove_module(module_id) + self._learner.remove_module(module_id) else: refs = [] for worker in self._workers: @@ -332,7 +332,7 @@ def set_weights(self, weights) -> None: # TODO (Kourosh) Set / get weight has to be thoroughly # tested across actors and multi-gpus if self.is_local: - self._trainer.set_weights(weights) + self._learner.set_weights(weights) else: results_or_errors = self._worker_manager.foreach_actor( lambda w: w.set_weights(weights) @@ -342,7 +342,7 @@ def set_weights(self, weights) -> None: def get_weights(self, module_ids: Optional[Set[str]] = None) -> Mapping[str, Any]: if self.is_local: - weights = self._trainer.get_weights(module_ids) + weights = self._learner.get_weights(module_ids) else: worker = self._worker_manager.healthy_actor_ids()[0] assert len(self._workers) == self._worker_manager.num_healthy_actors() @@ -354,12 +354,12 @@ def get_weights(self, module_ids: Optional[Set[str]] = None) -> Mapping[str, Any return convert_to_numpy(weights) def get_state(self) -> Mapping[ModuleID, Mapping[str, Any]]: - """Get the states of the first RLTrainers. + """Get the states of the first Learners. - This should be the same across RLTrainers + This should be the same across Learners """ if self.is_local: - return self._trainer.get_state() + return self._learner.get_state() else: worker = self._worker_manager.healthy_actor_ids()[0] assert len(self._workers) == self._worker_manager.num_healthy_actors() @@ -369,19 +369,19 @@ def get_state(self) -> Mapping[ModuleID, Mapping[str, Any]]: return self._get_results(results)[0] def set_state(self, state: List[Mapping[ModuleID, Mapping[str, Any]]]) -> None: - """Sets the states of the RLTrainers. + """Sets the states of the Learners. Args: - state: The state of the RLTrainers + state: The state of the Learners """ if self.is_local: - self._trainer.set_state(state) + self._learner.set_state(state) else: self._worker_manager.foreach_actor(lambda w: w.set_state(state)) def shutdown(self): - """Shuts down the TrainerRunner.""" + """Shuts down the LearnerGroup.""" if not self._is_local: self._backend_executor.shutdown() self._is_shut_down = True diff --git a/rllib/core/rl_trainer/trainer_runner_config.py b/rllib/core/learner/learner_group_config.py similarity index 52% rename from rllib/core/rl_trainer/trainer_runner_config.py rename to rllib/core/learner/learner_group_config.py index 33f73f8a04d0..bf6454886a3d 100644 --- a/rllib/core/rl_trainer/trainer_runner_config.py +++ b/rllib/core/learner/learner_group_config.py @@ -2,44 +2,44 @@ from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_trainer.trainer_runner import TrainerRunner -from ray.rllib.core.rl_trainer.scaling_config import TrainerScalingConfig -from ray.rllib.core.rl_trainer.rl_trainer import ( - RLTrainerSpec, - RLTrainerHPs, +from ray.rllib.core.learner.learner_group import LearnerGroup +from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig +from ray.rllib.core.learner.learner import ( + LearnerSpec, + LearnerHPs, FrameworkHPs, ) from ray.rllib.utils.from_config import NotProvided if TYPE_CHECKING: - from ray.rllib.core.rl_trainer import RLTrainer + from ray.rllib.core.learner import Learner ModuleSpec = Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec] # TODO (Kourosh): We should make all configs come from a standard base class that # defines the general interfaces for validation, from_dict, to_dict etc. 
-class TrainerRunnerConfig: - """Configuration object for TrainerRunner.""" +class LearnerGroupConfig: + """Configuration object for LearnerGroup.""" - def __init__(self, cls: Type[TrainerRunner] = None) -> None: + def __init__(self, cls: Type[LearnerGroup] = None) -> None: - # Define the default TrainerRunner class - self.trainer_runner_class = cls or TrainerRunner + # Define the default LearnerGroup class + self.learner_group_class = cls or LearnerGroup # `self.module()` self.module_spec = None - # `self.trainer()` - self.trainer_class = None + # `self.learner()` + self.learner_class = None self.optimizer_config = None - self.rl_trainer_hps = RLTrainerHPs() + self.learner_hps = LearnerHPs() # `self.resources()` - self.num_gpus_per_trainer_worker = 0 - self.num_cpus_per_trainer_worker = 1 - self.num_trainer_workers = 1 + self.num_gpus_per_learner_worker = 0 + self.num_cpus_per_learner_worker = 1 + self.num_learner_workers = 1 # TODO (Avnishn): We should come back and revise how to specify algorithm # resources this is a stop gap solution for now so that users can specify the @@ -55,14 +55,14 @@ def validate(self) -> None: if self.module_spec is None: raise ValueError( - "Cannot initialize an RLTrainer without the module specs. " + "Cannot initialize an Learner without the module specs. " "Please provide the specs via .module(module_spec)." ) - if self.trainer_class is None: + if self.learner_class is None: raise ValueError( - "Cannot initialize an RLTrainer without an RLTrainer. Please provide " - "the RLTrainer class with .trainer(trainer_class=MyTrainerClass)." + "Cannot initialize an Learner without an Learner class. Please provide " + "the Learner class with .learner(learner_class=MyTrainerClass)." ) if self.optimizer_config is None: @@ -70,32 +70,32 @@ def validate(self) -> None: # TODO (Kourosh): Change the optimizer config to a dataclass object. 
self.optimizer_config = {"lr": 1e-3} - def build(self) -> TrainerRunner: + def build(self) -> LearnerGroup: self.validate() - scaling_config = TrainerScalingConfig( - num_workers=self.num_trainer_workers, - num_gpus_per_worker=self.num_gpus_per_trainer_worker, - num_cpus_per_worker=self.num_cpus_per_trainer_worker, + scaling_config = LearnerGroupScalingConfig( + num_workers=self.num_learner_workers, + num_gpus_per_worker=self.num_gpus_per_learner_worker, + num_cpus_per_worker=self.num_cpus_per_learner_worker, local_gpu_idx=self.local_gpu_idx, ) framework_hps = FrameworkHPs(eager_tracing=self.eager_tracing) - rl_trainer_spec = RLTrainerSpec( - rl_trainer_class=self.trainer_class, + learner_spec = LearnerSpec( + learner_class=self.learner_class, module_spec=self.module_spec, optimizer_config=self.optimizer_config, - trainer_scaling_config=scaling_config, - trainer_hyperparameters=self.rl_trainer_hps, + learner_scaling_config=scaling_config, + learner_hyperparameters=self.learner_hps, framework_hyperparameters=framework_hps, ) - return self.trainer_runner_class(rl_trainer_spec) + return self.learner_group_class(learner_spec) def framework( self, eager_tracing: Optional[bool] = NotProvided - ) -> "TrainerRunnerConfig": + ) -> "LearnerGroupConfig": if eager_tracing is not NotProvided: self.eager_tracing = eager_tracing @@ -104,7 +104,7 @@ def framework( def module( self, module_spec: Optional[ModuleSpec] = NotProvided, - ) -> "TrainerRunnerConfig": + ) -> "LearnerGroupConfig": if module_spec is not NotProvided: self.module_spec = module_spec @@ -113,36 +113,36 @@ def module( def resources( self, - num_trainer_workers: Optional[int] = NotProvided, - num_gpus_per_trainer_worker: Optional[Union[float, int]] = NotProvided, - num_cpus_per_trainer_worker: Optional[Union[float, int]] = NotProvided, + num_learner_workers: Optional[int] = NotProvided, + num_gpus_per_learner_worker: Optional[Union[float, int]] = NotProvided, + num_cpus_per_learner_worker: Optional[Union[float, int]] = NotProvided, local_gpu_idx: Optional[int] = NotProvided, - ) -> "TrainerRunnerConfig": - - if num_trainer_workers is not NotProvided: - self.num_trainer_workers = num_trainer_workers - if num_gpus_per_trainer_worker is not NotProvided: - self.num_gpus_per_trainer_worker = num_gpus_per_trainer_worker - if num_cpus_per_trainer_worker is not NotProvided: - self.num_cpus_per_trainer_worker = num_cpus_per_trainer_worker + ) -> "LearnerGroupConfig": + + if num_learner_workers is not NotProvided: + self.num_learner_workers = num_learner_workers + if num_gpus_per_learner_worker is not NotProvided: + self.num_gpus_per_learner_worker = num_gpus_per_learner_worker + if num_cpus_per_learner_worker is not NotProvided: + self.num_cpus_per_learner_worker = num_cpus_per_learner_worker if local_gpu_idx is not NotProvided: self.local_gpu_idx = local_gpu_idx return self - def trainer( + def learner( self, *, - trainer_class: Optional[Type["RLTrainer"]] = NotProvided, + learner_class: Optional[Type["Learner"]] = NotProvided, optimizer_config: Optional[Dict] = NotProvided, - rl_trainer_hps: Optional[RLTrainerHPs] = NotProvided, - ) -> "TrainerRunnerConfig": + learner_hps: Optional[LearnerHPs] = NotProvided, + ) -> "LearnerGroupConfig": - if trainer_class is not NotProvided: - self.trainer_class = trainer_class + if learner_class is not NotProvided: + self.learner_class = learner_class if optimizer_config is not NotProvided: self.optimizer_config = optimizer_config - if rl_trainer_hps is not NotProvided: - self.rl_trainer_hps = rl_trainer_hps + 
if learner_hps is not NotProvided: + self.learner_hps = learner_hps return self diff --git a/rllib/core/rl_trainer/reduce_result_dict_fn.py b/rllib/core/learner/reduce_result_dict_fn.py similarity index 100% rename from rllib/core/rl_trainer/reduce_result_dict_fn.py rename to rllib/core/learner/reduce_result_dict_fn.py diff --git a/rllib/core/rl_trainer/scaling_config.py b/rllib/core/learner/scaling_config.py similarity index 97% rename from rllib/core/rl_trainer/scaling_config.py rename to rllib/core/learner/scaling_config.py index 4a81f7e12589..8b02494a5efb 100644 --- a/rllib/core/rl_trainer/scaling_config.py +++ b/rllib/core/learner/scaling_config.py @@ -2,7 +2,7 @@ @dataclass -class TrainerScalingConfig: +class LearnerGroupScalingConfig: """Configuratiom for scaling training actors. Attributes: diff --git a/rllib/core/rl_trainer/tests/__init__.py b/rllib/core/learner/tests/__init__.py similarity index 100% rename from rllib/core/rl_trainer/tests/__init__.py rename to rllib/core/learner/tests/__init__.py diff --git a/rllib/core/rl_trainer/tests/test_rl_trainer.py b/rllib/core/learner/tests/test_learner.py similarity index 76% rename from rllib/core/rl_trainer/tests/test_rl_trainer.py rename to rllib/core/learner/tests/test_learner.py index 7489c2a1225e..06878395d0bc 100644 --- a/rllib/core/rl_trainer/tests/test_rl_trainer.py +++ b/rllib/core/learner/tests/test_learner.py @@ -6,18 +6,18 @@ import ray from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer, FrameworkHPs +from ray.rllib.core.learner.learner import Learner, FrameworkHPs from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule -from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer +from ray.rllib.core.testing.tf.bc_learner import BCTfLearner from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.test_utils import check, get_cartpole_dataset_reader -from ray.rllib.core.rl_trainer.scaling_config import TrainerScalingConfig +from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig -def get_trainer() -> RLTrainer: +def get_learner() -> Learner: env = gym.make("CartPole-v1") - trainer = BCTfRLTrainer( + learner = BCTfLearner( module_spec=SingleAgentRLModuleSpec( module_class=DiscreteBCTFModule, observation_space=env.observation_space, @@ -25,16 +25,16 @@ def get_trainer() -> RLTrainer: model_config={"fcnet_hiddens": [32]}, ), optimizer_config={"lr": 1e-3}, - trainer_scaling_config=TrainerScalingConfig(), + learner_scaling_config=LearnerGroupScalingConfig(), framework_hyperparameters=FrameworkHPs(eager_tracing=True), ) - trainer.build() + learner.build() - return trainer + return learner -class TestRLTrainer(unittest.TestCase): +class TestLearner(unittest.TestCase): @classmethod def setUp(cls) -> None: ray.init() @@ -45,13 +45,13 @@ def tearDown(cls) -> None: def test_end_to_end_update(self): - trainer = get_trainer() + learner = get_learner() reader = get_cartpole_dataset_reader(batch_size=512) min_loss = float("inf") for iter_i in range(1000): batch = reader.next() - results = trainer.update(batch.as_multi_agent()) + results = learner.update(batch.as_multi_agent()) loss = results["loss"]["total_loss"] min_loss = min(loss, min_loss) @@ -68,12 +68,12 @@ def test_compute_gradients(self): Tests that if we sum all the trainable variables the gradient of output w.r.t. the weights is all ones. 
""" - trainer = get_trainer() + learner = get_learner() with tf.GradientTape() as tape: - params = trainer.module[DEFAULT_POLICY_ID].trainable_variables + params = learner.module[DEFAULT_POLICY_ID].trainable_variables loss = {"total_loss": sum([tf.reduce_sum(param) for param in params])} - gradients = trainer.compute_gradients(loss, tape) + gradients = learner.compute_gradients(loss, tape) # type should be a mapping from ParamRefs to gradients self.assertIsInstance(gradients, dict) @@ -88,18 +88,18 @@ def test_apply_gradients(self): standard SGD/Adam update rule. """ - trainer = get_trainer() + learner = get_learner() # calculated the expected new params based on gradients of all ones. - params = trainer.module[DEFAULT_POLICY_ID].trainable_variables + params = learner.module[DEFAULT_POLICY_ID].trainable_variables n_steps = 100 expected = [ - param - n_steps * trainer._optimizer_config["lr"] * np.ones(param.shape) + param - n_steps * learner._optimizer_config["lr"] * np.ones(param.shape) for param in params ] for _ in range(n_steps): - gradients = {trainer.get_param_ref(p): tf.ones_like(p) for p in params} - trainer.apply_gradients(gradients) + gradients = {learner.get_param_ref(p): tf.ones_like(p) for p in params} + learner.apply_gradients(gradients) check(params, expected) @@ -111,7 +111,7 @@ def test_add_remove_module(self): all variables the updated parameters follow the SGD update rule. """ env = gym.make("CartPole-v1") - trainer = get_trainer() + learner = get_learner() # add a test module with SGD optimizer with a known lr lr = 1e-4 @@ -121,7 +121,7 @@ def set_optimizer_fn(module): (module.trainable_variables, tf.keras.optimizers.SGD(learning_rate=lr)) ] - trainer.add_module( + learner.add_module( module_id="test", module_spec=SingleAgentRLModuleSpec( module_class=DiscreteBCTFModule, @@ -132,20 +132,20 @@ def set_optimizer_fn(module): set_optimizer_fn=set_optimizer_fn, ) - trainer.remove_module(DEFAULT_POLICY_ID) + learner.remove_module(DEFAULT_POLICY_ID) # only test module should be left - self.assertEqual(set(trainer.module.keys()), {"test"}) + self.assertEqual(set(learner.module.keys()), {"test"}) # calculated the expected new params based on gradients of all ones. 
- params = trainer.module["test"].trainable_variables + params = learner.module["test"].trainable_variables n_steps = 100 expected = [param - n_steps * lr * np.ones(param.shape) for param in params] for _ in range(n_steps): with tf.GradientTape() as tape: loss = {"total_loss": sum([tf.reduce_sum(param) for param in params])} - gradients = trainer.compute_gradients(loss, tape) - trainer.apply_gradients(gradients) + gradients = learner.compute_gradients(loss, tape) + learner.apply_gradients(gradients) check(params, expected) diff --git a/rllib/core/rl_trainer/tests/test_trainer_runner.py b/rllib/core/learner/tests/test_learner_group.py similarity index 73% rename from rllib/core/rl_trainer/tests/test_trainer_runner.py rename to rllib/core/learner/tests/test_learner_group.py index bc13442dfccd..3bf6535bb380 100644 --- a/rllib/core/rl_trainer/tests/test_trainer_runner.py +++ b/rllib/core/learner/tests/test_learner_group.py @@ -6,29 +6,29 @@ import ray from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch from ray.rllib.utils.test_utils import check, get_cartpole_dataset_reader -from ray.rllib.core.rl_trainer.scaling_config import TrainerScalingConfig +from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig from ray.rllib.core.testing.utils import ( - get_trainer_runner, - get_rl_trainer, - add_module_to_runner_or_trainer, + get_learner_group, + get_learner, + add_module_to_learner_or_learner_group, ) from ray.util.timer import _Timer REMOTE_SCALING_CONFIGS = { - "remote-cpu": TrainerScalingConfig(num_workers=1), - "remote-gpu": TrainerScalingConfig(num_workers=1, num_gpus_per_worker=0.5), - "multi-gpu-ddp": TrainerScalingConfig(num_workers=2, num_gpus_per_worker=1), - "multi-cpu-ddp": TrainerScalingConfig(num_workers=2, num_cpus_per_worker=2), - # "multi-gpu-ddp-pipeline": TrainerScalingConfig( + "remote-cpu": LearnerGroupScalingConfig(num_workers=1), + "remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=0.5), + "multi-gpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_gpus_per_worker=1), + "multi-cpu-ddp": LearnerGroupScalingConfig(num_workers=2, num_cpus_per_worker=2), + # "multi-gpu-ddp-pipeline": LearnerGroupScalingConfig( # num_workers=2, num_gpus_per_worker=2 # ), } LOCAL_SCALING_CONFIGS = { - "local-cpu": TrainerScalingConfig(num_workers=0, num_gpus_per_worker=0), - "local-gpu": TrainerScalingConfig(num_workers=0, num_gpus_per_worker=0.5), + "local-cpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0), + "local-gpu": LearnerGroupScalingConfig(num_workers=0, num_gpus_per_worker=0.5), } @@ -39,44 +39,44 @@ class RemoteTrainingHelper: def local_training_helper(self, fw, scaling_mode) -> None: env = gym.make("CartPole-v1") scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode] - runner = get_trainer_runner(fw, env, scaling_config, eager_tracing=True) - local_trainer = get_rl_trainer(fw, env) - local_trainer.build() + learner_group = get_learner_group(fw, env, scaling_config, eager_tracing=True) + local_learner = get_learner(fw, env) + local_learner.build() - # make the state of the trainer and the local runner identical - local_trainer.set_state(runner.get_state()) + # make the state of the learner and the local learner_group identical + local_learner.set_state(learner_group.get_state()) reader = get_cartpole_dataset_reader(batch_size=500) batch = reader.next() batch = batch.as_multi_agent() - check(local_trainer.update(batch), runner.update(batch)) + check(local_learner.update(batch), 
learner_group.update(batch)) new_module_id = "test_module" - add_module_to_runner_or_trainer(fw, env, new_module_id, runner) - add_module_to_runner_or_trainer(fw, env, new_module_id, local_trainer) + add_module_to_learner_or_learner_group(fw, env, new_module_id, learner_group) + add_module_to_learner_or_learner_group(fw, env, new_module_id, local_learner) - # make the state of the trainer and the local runner identical - local_trainer.set_state(runner.get_state()) + # make the state of the learner and the local learner_group identical + local_learner.set_state(learner_group.get_state()) # do another update batch = reader.next() ma_batch = MultiAgentBatch( {new_module_id: batch, DEFAULT_POLICY_ID: batch}, env_steps=batch.count ) - check(local_trainer.update(ma_batch), runner.update(ma_batch)) + check(local_learner.update(ma_batch), learner_group.update(ma_batch)) - check(local_trainer.get_state(), runner.get_state()) + check(local_learner.get_state(), learner_group.get_state()) -class TestTrainerRunner(unittest.TestCase): +class TestLearnerGroup(unittest.TestCase): def setUp(self) -> None: ray.init() def tearDown(self) -> None: ray.shutdown() - def test_trainer_runner_local(self): + def test_learner_group_local(self): fws = ["tf", "torch"] test_iterator = itertools.product(fws, LOCAL_SCALING_CONFIGS) @@ -99,13 +99,15 @@ def test_update_multigpu(self): env = gym.make("CartPole-v1") scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode] - runner = get_trainer_runner(fw, env, scaling_config, eager_tracing=True) + learner_group = get_learner_group( + fw, env, scaling_config, eager_tracing=True + ) reader = get_cartpole_dataset_reader(batch_size=1024) min_loss = float("inf") for iter_i in range(1000): batch = reader.next() - results = runner.update(batch.as_multi_agent(), reduce_fn=None) + results = learner_group.update(batch.as_multi_agent(), reduce_fn=None) loss = np.mean([res["loss"]["total_loss"] for res in results]) min_loss = min(loss, min_loss) @@ -123,9 +125,10 @@ def test_update_multigpu(self): self.assertLess(min_loss, 0.57) - # make sure the runner resources are freed up so that we don't autoscale - runner.shutdown() - del runner + # make sure the learner_group resources are freed up so that we don't + # autoscale + learner_group.shutdown() + del learner_group def test_add_remove_module(self): fws = ["tf", "torch"] @@ -136,20 +139,24 @@ def test_add_remove_module(self): print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.") env = gym.make("CartPole-v1") scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode] - runner = get_trainer_runner(fw, env, scaling_config, eager_tracing=True) + learner_group = get_learner_group( + fw, env, scaling_config, eager_tracing=True + ) reader = get_cartpole_dataset_reader(batch_size=512) batch = reader.next() # update once with the default policy - results = runner.update(batch.as_multi_agent(), reduce_fn=None) + results = learner_group.update(batch.as_multi_agent(), reduce_fn=None) module_ids_before_add = {DEFAULT_POLICY_ID} new_module_id = "test_module" # add a test_module - add_module_to_runner_or_trainer(fw, env, new_module_id, runner) + add_module_to_learner_or_learner_group( + fw, env, new_module_id, learner_group + ) # do training that includes the test_module - results = runner.update( + results = learner_group.update( MultiAgentBatch( {new_module_id: batch, DEFAULT_POLICY_ID: batch}, batch.count ), @@ -173,10 +180,10 @@ def test_add_remove_module(self): ) # remove the test_module - runner.remove_module(module_id=new_module_id) + 
learner_group.remove_module(module_id=new_module_id) # run training without the test_module - results = runner.update(batch.as_multi_agent(), reduce_fn=None) + results = learner_group.update(batch.as_multi_agent(), reduce_fn=None) # check that module weights are updated across workers and synchronized for i in range(1, len(results)): @@ -194,9 +201,10 @@ def test_add_remove_module(self): set(result["loss"]) - {"total_loss"}, module_ids_before_add ) - # make sure the runner resources are freed up so that we don't autoscale - runner.shutdown() - del runner + # make sure the learner_group resources are freed up so that we don't + # autoscale + learner_group.shutdown() + del learner_group def test_async_update(self): """Test that async style updates converge to the same result as sync.""" @@ -210,16 +218,16 @@ def test_async_update(self): print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.") env = gym.make("CartPole-v1") scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode] - runner = get_trainer_runner(fw, env, scaling_config) + learner_group = get_learner_group(fw, env, scaling_config) reader = get_cartpole_dataset_reader(batch_size=512) min_loss = float("inf") batch = reader.next() timer_sync = _Timer() timer_async = _Timer() with timer_sync: - runner.update(batch.as_multi_agent(), block=True, reduce_fn=None) + learner_group.update(batch.as_multi_agent(), block=True, reduce_fn=None) with timer_async: - result_async = runner.update( + result_async = learner_group.update( batch.as_multi_agent(), block=False, reduce_fn=None ) # ideally the the first async update will return nothing, and an easy @@ -230,7 +238,7 @@ def test_async_update(self): self.assertEqual(len(result_async), 0) for iter_i in range(1000): batch = reader.next() - results = runner.update( + results = learner_group.update( batch.as_multi_agent(), block=False, reduce_fn=None ) if not results: @@ -248,7 +256,7 @@ def test_async_update(self): res1["mean_weight"]["default_policy"], res2["mean_weight"]["default_policy"], ) - runner.shutdown() + learner_group.shutdown() self.assertLess(min_loss, 0.57) diff --git a/rllib/core/rl_trainer/tests/test_trainer_runner_config.py b/rllib/core/learner/tests/test_learner_group_config.py similarity index 61% rename from rllib/core/rl_trainer/tests/test_trainer_runner_config.py rename to rllib/core/learner/tests/test_learner_group_config.py index 7a4f111c17d9..476ba86f2cec 100644 --- a/rllib/core/rl_trainer/tests/test_trainer_runner_config.py +++ b/rllib/core/learner/tests/test_learner_group_config.py @@ -5,9 +5,9 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_trainer.trainer_runner_config import TrainerRunnerConfig +from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule -from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer +from ray.rllib.core.testing.tf.bc_learner import BCTfLearner from ray.rllib.core.testing.utils import get_module_spec @@ -20,28 +20,28 @@ def setUpClass(cls): def tearDownClass(cls): ray.shutdown() - def test_trainer_runner_build(self): - """Tests whether the trainer_runner can be constructed and built.""" + def test_learner_group_build(self): + """Tests whether the learner_group can be constructed and built.""" env = gym.make("CartPole-v1") config = ( - TrainerRunnerConfig() + LearnerGroupConfig() .module(get_module_spec("tf", env)) - .trainer( - 
trainer_class=BCTfRLTrainer, + .learner( + learner_class=BCTfLearner, ) ) config.build() - def test_trainer_runner_build_from_algorithm_config(self): - """Tests whether we can build a trainer runner object from algorithm_config.""" + def test_learner_group_build_from_algorithm_config(self): + """Tests whether we can build a learner_group object from algorithm_config.""" env = gym.make("CartPole-v1") - config = AlgorithmConfig().training(rl_trainer_class=BCTfRLTrainer) + config = AlgorithmConfig().training(learner_class=BCTfLearner) config.freeze() - runner_config = config.get_trainer_runner_config( + learner_group_config = config.get_learner_group_config( SingleAgentRLModuleSpec( module_class=DiscreteBCTFModule, observation_space=env.observation_space, @@ -49,7 +49,7 @@ def test_trainer_runner_build_from_algorithm_config(self): model_config={"fcnet_hiddens": [32]}, ) ) - runner_config.build() + learner_group_config.build() if __name__ == "__main__": diff --git a/rllib/core/rl_trainer/tf/__init__.py b/rllib/core/learner/tf/__init__.py similarity index 100% rename from rllib/core/rl_trainer/tf/__init__.py rename to rllib/core/learner/tf/__init__.py diff --git a/rllib/core/rl_trainer/tf/tf_rl_trainer.py b/rllib/core/learner/tf/tf_learner.py similarity index 94% rename from rllib/core/rl_trainer/tf/tf_rl_trainer.py rename to rllib/core/learner/tf/tf_learner.py index 6944cc1db1d3..77a5d3d5ed5a 100644 --- a/rllib/core/rl_trainer/tf/tf_rl_trainer.py +++ b/rllib/core/learner/tf/tf_learner.py @@ -10,9 +10,9 @@ Hashable, ) -from ray.rllib.core.rl_trainer.rl_trainer import ( +from ray.rllib.core.learner.learner import ( FrameworkHPs, - RLTrainer, + Learner, ParamOptimizerPairs, Optimizer, ParamType, @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) -class TfRLTrainer(RLTrainer): +class TfLearner(Learner): framework: str = "tf" @@ -58,7 +58,7 @@ def __init__( # cpu only case, build will override this if needed. self._strategy = tf.distribute.get_strategy() - @override(RLTrainer) + @override(Learner) def configure_optimizers(self) -> ParamOptimizerPairs: """Configures the optimizers for the Learner. @@ -75,14 +75,14 @@ def configure_optimizers(self) -> ParamOptimizerPairs: for key in self._module.keys() ] - @override(RLTrainer) + @override(Learner) def compute_gradients( self, loss: Union[TensorType, Mapping[str, Any]], tape: "tf.GradientTape" ) -> ParamDictType: grads = tape.gradient(loss[self.TOTAL_LOSS_KEY], self._params) return grads - @override(RLTrainer) + @override(Learner) def apply_gradients(self, gradients: ParamDictType) -> None: # TODO (Avnishn, kourosh): apply gradients doesn't work in cases where # only some agents have a sample batch that is passed but not others. @@ -93,32 +93,32 @@ def apply_gradients(self, gradients: ParamDictType) -> None: gradient_list = [gradients[param_ref] for param_ref in param_ref_seq] optim.apply_gradients(zip(gradient_list, variable_list)) - @override(RLTrainer) + @override(Learner) def get_weights(self) -> Mapping[str, Any]: # TODO (Kourosh) Implement this. raise NotImplementedError - @override(RLTrainer) + @override(Learner) def set_weights(self, weights: Mapping[str, Any]) -> None: # TODO (Kourosh) Implement this.
raise NotImplementedError - @override(RLTrainer) + @override(Learner) def get_param_ref(self, param: ParamType) -> Hashable: return param.ref() - @override(RLTrainer) + @override(Learner) def get_parameters(self, module: RLModule) -> Sequence[ParamType]: return list(module.trainable_variables) - @override(RLTrainer) + @override(Learner) def get_optimizer_obj( self, module: RLModule, optimizer_cls: Type[Optimizer] ) -> Optimizer: lr = self._optimizer_config["lr"] return optimizer_cls(learning_rate=lr) - @override(RLTrainer) + @override(Learner) def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]: """Convert the arrays of batch to tf.Tensor's. @@ -141,7 +141,7 @@ def _convert_batch_type(self, batch: MultiAgentBatch) -> NestedDict[TensorType]: batch[key] = tf.convert_to_tensor(value, dtype=tf.float32) return batch.asdict() - @override(RLTrainer) + @override(Learner) def add_module( self, *, @@ -168,14 +168,14 @@ def add_module( if self._enable_tf_function: self._update_fn = tf.function(self._do_update_fn, reduce_retracing=True) - @override(RLTrainer) + @override(Learner) def remove_module(self, module_id: ModuleID) -> None: with self._strategy.scope(): super().remove_module(module_id) if self._enable_tf_function: self._update_fn = tf.function(self._do_update_fn, reduce_retracing=True) - @override(RLTrainer) + @override(Learner) def build(self) -> None: """Build the TfLearner. @@ -205,7 +205,7 @@ def build(self) -> None: else: self._update_fn = self._do_update_fn - @override(RLTrainer) + @override(Learner) def update( self, batch: MultiAgentBatch, @@ -214,12 +214,12 @@ def update( num_iters: int = 1, reduce_fn: Callable[[ResultDict], ResultDict] = ..., ) -> Mapping[str, Any]: - # TODO (Kourosh): The update of rl_trainer is vastly differnet than the base + # TODO (Kourosh): The update of learner is vastly different than the base # class. So we need to unify them. if set(batch.policy_batches.keys()) != set(self._module.keys()): raise ValueError( - "Batch keys must match module keys. RLTrainer does not " + "Batch keys must match module keys. 
Learner does not " "currently support training of only some modules and not others" ) diff --git a/rllib/core/rl_trainer/torch/__init__.py b/rllib/core/learner/torch/__init__.py similarity index 100% rename from rllib/core/rl_trainer/torch/__init__.py rename to rllib/core/learner/torch/__init__.py diff --git a/rllib/core/rl_trainer/torch/tests/__init__.py b/rllib/core/learner/torch/tests/__init__.py similarity index 100% rename from rllib/core/rl_trainer/torch/tests/__init__.py rename to rllib/core/learner/torch/tests/__init__.py diff --git a/rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py b/rllib/core/learner/torch/tests/test_torch_learner.py similarity index 76% rename from rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py rename to rllib/core/learner/torch/tests/test_torch_learner.py index 95b68da7ceb3..222c817760da 100644 --- a/rllib/core/rl_trainer/torch/tests/test_torch_rl_trainer.py +++ b/rllib/core/learner/torch/tests/test_torch_learner.py @@ -6,23 +6,23 @@ import ray from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer +from ray.rllib.core.learner.learner import Learner from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.test_utils import check, get_cartpole_dataset_reader from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.core.testing.utils import get_rl_trainer +from ray.rllib.core.testing.utils import get_learner -def _get_trainer() -> RLTrainer: +def _get_learner() -> Learner: env = gym.make("CartPole-v1") - trainer = get_rl_trainer("torch", env) - trainer.build() + learner = get_learner("torch", env) + learner.build() - return trainer + return learner -class TestRLTrainer(unittest.TestCase): +class TestLearner(unittest.TestCase): @classmethod def setUp(cls) -> None: ray.init() @@ -33,13 +33,13 @@ def tearDown(cls) -> None: def test_end_to_end_update(self): - trainer = _get_trainer() + learner = _get_learner() reader = get_cartpole_dataset_reader(batch_size=512) min_loss = float("inf") for iter_i in range(1000): batch = reader.next() - results = trainer.update(batch.as_multi_agent()) + results = learner.update(batch.as_multi_agent()) loss = results["loss"]["total_loss"] min_loss = min(loss, min_loss) @@ -56,11 +56,11 @@ def test_compute_gradients(self): Tests that if we sum all the trainable variables the gradient of output w.r.t. the weights is all ones. """ - trainer = _get_trainer() + learner = _get_learner() - params = trainer.get_parameters(trainer.module[DEFAULT_POLICY_ID]) + params = learner.get_parameters(learner.module[DEFAULT_POLICY_ID]) loss = {"total_loss": sum([param.sum() for param in params])} - gradients = trainer.compute_gradients(loss) + gradients = learner.compute_gradients(loss) # type should be a mapping from ParamRefs to gradients self.assertIsInstance(gradients, dict) @@ -75,19 +75,19 @@ def test_apply_gradients(self): standard SGD/Adam update rule. """ - trainer = _get_trainer() + learner = _get_learner() # calculated the expected new params based on gradients of all ones. 
- params = trainer.get_parameters(trainer.module[DEFAULT_POLICY_ID]) + params = learner.get_parameters(learner.module[DEFAULT_POLICY_ID]) n_steps = 100 expected = [ convert_to_numpy(param) - - n_steps * trainer._optimizer_config["lr"] * np.ones(param.shape) + - n_steps * learner._optimizer_config["lr"] * np.ones(param.shape) for param in params ] for _ in range(n_steps): - gradients = {trainer.get_param_ref(p): torch.ones_like(p) for p in params} - trainer.apply_gradients(gradients) + gradients = {learner.get_param_ref(p): torch.ones_like(p) for p in params} + learner.apply_gradients(gradients) check(params, expected) @@ -99,7 +99,7 @@ def test_add_remove_module(self): all variables the updated parameters follow the SGD update rule. """ env = gym.make("CartPole-v1") - trainer = _get_trainer() + learner = _get_learner() # add a test module with SGD optimizer with a known lr lr = 1e-4 @@ -107,7 +107,7 @@ def test_add_remove_module(self): def set_optimizer_fn(module): return [(module.parameters(), torch.optim.Adam(module.parameters(), lr=lr))] - trainer.add_module( + learner.add_module( module_id="test", module_spec=SingleAgentRLModuleSpec( module_class=DiscreteBCTorchModule, @@ -118,13 +118,13 @@ def set_optimizer_fn(module): set_optimizer_fn=set_optimizer_fn, ) - trainer.remove_module(DEFAULT_POLICY_ID) + learner.remove_module(DEFAULT_POLICY_ID) # only test module should be left - self.assertEqual(set(trainer.module.keys()), {"test"}) + self.assertEqual(set(learner.module.keys()), {"test"}) # calculated the expected new params based on gradients of all ones. - params = trainer.get_parameters(trainer.module["test"]) + params = learner.get_parameters(learner.module["test"]) n_steps = 100 expected = [ convert_to_numpy(param) - n_steps * lr * np.ones(param.shape) @@ -132,8 +132,8 @@ def set_optimizer_fn(module): ] for _ in range(n_steps): loss = {"total_loss": sum([param.sum() for param in params])} - gradients = trainer.compute_gradients(loss) - trainer.apply_gradients(gradients) + gradients = learner.compute_gradients(loss) + learner.apply_gradients(gradients) check(params, expected) diff --git a/rllib/core/rl_trainer/torch/torch_rl_trainer.py b/rllib/core/learner/torch/torch_learner.py similarity index 95% rename from rllib/core/rl_trainer/torch/torch_rl_trainer.py rename to rllib/core/learner/torch/torch_learner.py index b86daf604a21..5069f6821707 100644 --- a/rllib/core/rl_trainer/torch/torch_rl_trainer.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -17,9 +17,9 @@ SingleAgentRLModuleSpec, ) from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule -from ray.rllib.core.rl_trainer.rl_trainer import ( +from ray.rllib.core.learner.learner import ( FrameworkHPs, - RLTrainer, + Learner, ParamOptimizerPairs, Optimizer, ParamType, @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) -class TorchRLTrainer(RLTrainer): +class TorchLearner(Learner): framework: str = "torch" def __init__( @@ -57,7 +57,7 @@ def __init__( # will be set during build self._device = None - @override(RLTrainer) + @override(Learner) def configure_optimizers(self) -> ParamOptimizerPairs: """Configures the optimizers for the Learner. 
@@ -74,7 +74,7 @@ def configure_optimizers(self) -> ParamOptimizerPairs: for key in self._module.keys() ] - @override(RLTrainer) + @override(Learner) def compute_gradients( self, loss: Union[TensorType, Mapping[str, Any]] ) -> ParamDictType: @@ -86,7 +86,7 @@ def compute_gradients( return grads - @override(RLTrainer) + @override(Learner) def apply_gradients(self, gradients: ParamDictType) -> None: # make sure the parameters do not carry gradients on their own for optim in self._optim_to_param: @@ -100,7 +100,7 @@ def apply_gradients(self, gradients: ParamDictType) -> None: for optim in self._optim_to_param: optim.step() - @override(RLTrainer) + @override(Learner) def get_weights(self, module_ids: Optional[Set[str]] = None) -> Mapping[str, Any]: """Returns the weights of the underlying MultiAgentRLModule""" module_weights = self._module.get_state() @@ -111,21 +111,21 @@ def get_weights(self, module_ids: Optional[Set[str]] = None) -> Mapping[str, Any {k: v for k, v in module_weights.items() if k in module_ids} ) - @override(RLTrainer) + @override(Learner) def set_weights(self, weights: Mapping[str, Any]) -> None: """Sets the weights of the underlying MultiAgentRLModule""" weights = convert_to_torch_tensor(weights, device=self._device) return self._module.set_state(weights) - @override(RLTrainer) + @override(Learner) def get_param_ref(self, param: ParamType) -> Hashable: return param - @override(RLTrainer) + @override(Learner) def get_parameters(self, module: RLModule) -> Sequence[ParamType]: return list(module.parameters()) - @override(RLTrainer) + @override(Learner) def get_optimizer_obj( self, module: RLModule, optimizer_cls: Type[Optimizer] ) -> Optimizer: @@ -134,13 +134,13 @@ def get_optimizer_obj( lr = self._optimizer_config["lr"] return optimizer_cls(module.parameters(), lr=lr) - @override(RLTrainer) + @override(Learner) def _convert_batch_type(self, batch: MultiAgentBatch): batch = convert_to_torch_tensor(batch.policy_batches, device=self._device) batch = NestedDict(batch) return batch - @override(RLTrainer) + @override(Learner) def add_module( self, *, @@ -163,7 +163,7 @@ def add_module( module_id, TorchDDPRLModule(self._module[module_id]), override=True ) - @override(RLTrainer) + @override(Learner) def build(self) -> None: """Builds the TorchLearner. @@ -212,7 +212,7 @@ def build(self) -> None: key, TorchDDPRLModule(self._module[key]), override=True ) - @override(RLTrainer) + @override(Learner) def _make_module(self) -> MultiAgentRLModule: module = super()._make_module() self._map_module_to_device(module) diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 776c50806474..9db414985720 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -333,7 +333,7 @@ def _forward_exploration(self, batch: NestedDict, **kwargs) -> Mapping[str, Any] @check_input_specs("_input_specs_train") @check_output_specs("_output_specs_train") def forward_train(self, batch: SampleBatchType, **kwargs) -> Mapping[str, Any]: - """Forward-pass during training called from the trainer. This method should + """Forward-pass during training called from the learner. This method should not be overriden. Instead, override the _forward_train method. 
Args: diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 03f27693ce29..dcf4e5759678 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -69,9 +69,9 @@ def set_state(self, *args, **kwargs): @override(RLModule) def make_distributed(self, dist_config: Mapping[str, Any] = None) -> None: # TODO (Kourosh): Not to sure about this make_distributed api belonging to - # RLModule or the RLTrainer? For now the logic is kept in RLTrainer. + # RLModule or the Learner? For now the logic is kept in Learner. # We should see if we can use this api end-point for both tf - # and torch instead of doing it in the trainer. + # and torch instead of doing it in the learner. pass @override(RLModule) diff --git a/rllib/core/rl_trainer/__init__.py b/rllib/core/rl_trainer/__init__.py deleted file mode 100644 index 872c801aec49..000000000000 --- a/rllib/core/rl_trainer/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer -from ray.rllib.core.rl_trainer.trainer_runner import TrainerRunner -from ray.rllib.core.rl_trainer.trainer_runner_config import TrainerRunnerConfig - - -__all__ = [ - "RLTrainer", - "TrainerRunner", - "TrainerRunnerConfig", -] diff --git a/rllib/core/testing/bc_algorithm.py b/rllib/core/testing/bc_algorithm.py index 0f195ff4e8fd..40a5234e08cf 100644 --- a/rllib/core/testing/bc_algorithm.py +++ b/rllib/core/testing/bc_algorithm.py @@ -8,9 +8,9 @@ from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule -from ray.rllib.core.testing.torch.bc_rl_trainer import BCTorchRLTrainer +from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule -from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer +from ray.rllib.core.testing.tf.bc_learner import BCTfLearner from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec @@ -24,11 +24,11 @@ def get_default_rl_module_spec(self): elif self.framework_str == "tf2": return SingleAgentRLModuleSpec(module_class=DiscreteBCTFModule) - def get_default_rl_trainer_class(self): + def get_default_learner_class(self): if self.framework_str == "torch": - return BCTorchRLTrainer + return BCTorchLearner elif self.framework_str == "tf2": - return BCTfRLTrainer + return BCTfLearner class BCAlgorithmTest(Algorithm): diff --git a/rllib/core/testing/testing_trainer.py b/rllib/core/testing/testing_learner.py similarity index 90% rename from rllib/core/testing/testing_trainer.py rename to rllib/core/testing/testing_learner.py index 712b73859aae..ff7278e44a42 100644 --- a/rllib/core/testing/testing_trainer.py +++ b/rllib/core/testing/testing_learner.py @@ -1,12 +1,12 @@ from typing import Mapping, Any import numpy as np -from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer +from ray.rllib.core.learner.learner import Learner from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.utils.numpy import convert_to_numpy -class BaseTestingTrainer(RLTrainer): +class BaseTestingLearner(Learner): def compile_results( self, batch: NestedDict, diff --git a/rllib/core/testing/tests/test_bc_algorithm.py b/rllib/core/testing/tests/test_bc_algorithm.py index bc943f6882b6..591889699dba 100644 --- a/rllib/core/testing/tests/test_bc_algorithm.py +++ 
b/rllib/core/testing/tests/test_bc_algorithm.py @@ -18,7 +18,7 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole -class TestRLTrainer(unittest.TestCase): +class TestLearner(unittest.TestCase): @classmethod def setUpClass(cls) -> None: ray.init() @@ -33,7 +33,7 @@ def test_bc_algorithm(self): config = ( BCConfigTest() .rl_module(_enable_rl_module_api=True) - .training(_enable_rl_trainer_api=True, model={"fcnet_hiddens": [32, 32]}) + .training(_enable_learner_api=True, model={"fcnet_hiddens": [32, 32]}) ) # TODO (Kourosh): Add tf2 support @@ -54,7 +54,7 @@ def test_bc_algorithm_marl(self): config = ( BCConfigTest() .rl_module(_enable_rl_module_api=True) - .training(_enable_rl_trainer_api=True, model={"fcnet_hiddens": [32, 32]}) + .training(_enable_learner_api=True, model={"fcnet_hiddens": [32, 32]}) .multi_agent( policies=policies, policy_mapping_fn=lambda agent_id, **kwargs: list(policies)[agent_id], @@ -97,7 +97,7 @@ def test_bc_algorithm_w_custom_marl_module(self): .framework(fw) .rl_module(_enable_rl_module_api=True, rl_module_spec=spec) .training( - _enable_rl_trainer_api=True, + _enable_learner_api=True, model={"fcnet_hiddens": [32, 32]}, ) .multi_agent( diff --git a/rllib/core/testing/tf/bc_rl_trainer.py b/rllib/core/testing/tf/bc_learner.py similarity index 71% rename from rllib/core/testing/tf/bc_rl_trainer.py rename to rllib/core/testing/tf/bc_learner.py index fca29c502828..0bc0d782d094 100644 --- a/rllib/core/testing/tf/bc_rl_trainer.py +++ b/rllib/core/testing/tf/bc_learner.py @@ -1,14 +1,14 @@ import tensorflow as tf from typing import Any, Mapping -from ray.rllib.core.rl_trainer.tf.tf_rl_trainer import TfRLTrainer +from ray.rllib.core.learner.tf.tf_learner import TfLearner from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.core.testing.testing_trainer import BaseTestingTrainer +from ray.rllib.core.testing.testing_learner import BaseTestingLearner from ray.rllib.utils.typing import TensorType -class BCTfRLTrainer(TfRLTrainer, BaseTestingTrainer): +class BCTfLearner(TfLearner, BaseTestingLearner): def compute_loss_per_module( self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType] ) -> Mapping[str, Any]: diff --git a/rllib/core/testing/torch/bc_rl_trainer.py b/rllib/core/testing/torch/bc_learner.py similarity index 69% rename from rllib/core/testing/torch/bc_rl_trainer.py rename to rllib/core/testing/torch/bc_learner.py index 9de288b982fd..123d8566c1ef 100644 --- a/rllib/core/testing/torch/bc_rl_trainer.py +++ b/rllib/core/testing/torch/bc_learner.py @@ -1,13 +1,13 @@ import torch from typing import Any, Mapping -from ray.rllib.core.rl_trainer.torch.torch_rl_trainer import TorchRLTrainer -from ray.rllib.core.testing.testing_trainer import BaseTestingTrainer +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.core.testing.testing_learner import BaseTestingLearner from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.typing import TensorType -class BCTorchRLTrainer(TorchRLTrainer, BaseTestingTrainer): +class BCTorchLearner(TorchLearner, BaseTestingLearner): def compute_loss_per_module( self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType] ) -> Mapping[str, Any]: diff --git a/rllib/core/testing/utils.py b/rllib/core/testing/utils.py index 15f268312817..5401253d0366 100644 --- a/rllib/core/testing/utils.py +++ b/rllib/core/testing/utils.py @@ -3,9 +3,9 @@ from ray.rllib.utils.annotations import DeveloperAPI -from 
ray.rllib.core.rl_trainer.trainer_runner import TrainerRunner -from ray.rllib.core.rl_trainer.rl_trainer import RLTrainerSpec, FrameworkHPs -from ray.rllib.core.rl_trainer.scaling_config import TrainerScalingConfig +from ray.rllib.core.learner.learner_group import LearnerGroup +from ray.rllib.core.learner.learner import LearnerSpec, FrameworkHPs +from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig from ray.rllib.core.rl_module.marl_module import ( MultiAgentRLModuleSpec, @@ -18,7 +18,7 @@ import torch import tensorflow as tf - from ray.rllib.core.rl_trainer.rl_trainer import RLTrainer + from ray.rllib.core.learner.learner import Learner from ray.rllib.core.rl_module import RLModule @@ -26,15 +26,15 @@ @DeveloperAPI -def get_trainer_class(framework: str) -> Type["RLTrainer"]: +def get_learner_class(framework: str) -> Type["Learner"]: if framework == "tf": - from ray.rllib.core.testing.tf.bc_rl_trainer import BCTfRLTrainer + from ray.rllib.core.testing.tf.bc_learner import BCTfLearner - return BCTfRLTrainer + return BCTfLearner elif framework == "torch": - from ray.rllib.core.testing.torch.bc_rl_trainer import BCTorchRLTrainer + from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner - return BCTorchRLTrainer + return BCTorchLearner else: raise ValueError(f"Unsupported framework: {framework}") @@ -88,26 +88,26 @@ def get_optimizer_default_class(framework: str) -> Type[Optimizer]: @DeveloperAPI -def get_rl_trainer( +def get_learner( framework: str, env: "gym.Env", is_multi_agent: bool = False, -) -> "RLTrainer": +) -> "Learner": - _cls = get_trainer_class(framework) + _cls = get_learner_class(framework) spec = get_module_spec(framework=framework, env=env, is_multi_agent=is_multi_agent) return _cls(module_spec=spec, optimizer_config={"lr": 0.1}) @DeveloperAPI -def get_trainer_runner( +def get_learner_group( framework: str, env: "gym.Env", - scaling_config: TrainerScalingConfig, + scaling_config: LearnerGroupScalingConfig, is_multi_agent: bool = False, eager_tracing: bool = False, -) -> TrainerRunner: - """Construct a trainer runner for testing. +) -> LearnerGroup: + """Construct a learner_group for testing. Args: framework: The framework used for training. @@ -119,35 +119,35 @@ def get_trainer_runner( optimizations. Returns: - A trainer runner. + A learner_group. 
""" if framework == "tf": - trainer_hps = FrameworkHPs(eager_tracing=eager_tracing) + learner_hps = FrameworkHPs(eager_tracing=eager_tracing) else: - trainer_hps = None - rl_trainer_spec = RLTrainerSpec( - rl_trainer_class=get_trainer_class(framework), + learner_hps = None + learner_spec = LearnerSpec( + learner_class=get_learner_class(framework), module_spec=get_module_spec( framework=framework, env=env, is_multi_agent=is_multi_agent ), optimizer_config={"lr": 0.1}, - trainer_scaling_config=scaling_config, - trainer_hyperparameters=trainer_hps, + learner_scaling_config=scaling_config, + learner_hyperparameters=learner_hps, ) - runner = TrainerRunner(rl_trainer_spec) + lg = LearnerGroup(learner_spec) - return runner + return lg @DeveloperAPI -def add_module_to_runner_or_trainer( +def add_module_to_learner_or_learner_group( framework: str, env: "gym.Env", module_id: str, - runner_or_trainer: Union[TrainerRunner, "RLTrainer"], + learner_group_or_learner: Union[LearnerGroup, "Learner"], ): - runner_or_trainer.add_module( + learner_group_or_learner.add_module( module_id=module_id, module_spec=get_module_spec(framework, env, is_multi_agent=False), optimizer_cls=get_optimizer_default_class(framework), diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index e5e7243df6d3..b07268f51043 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -18,7 +18,7 @@ from ray.actor import ActorHandle from ray.exceptions import RayActorError -from ray.rllib.core.rl_trainer import TrainerRunner +from ray.rllib.core.learner import LearnerGroup from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.utils.actor_manager import RemoteCallResults from ray.rllib.env.base_env import BaseEnv @@ -382,21 +382,21 @@ def num_remote_worker_restarts(self) -> int: def sync_weights( self, policies: Optional[List[PolicyID]] = None, - from_worker_or_trainer: Optional[Union[RolloutWorker, TrainerRunner]] = None, + from_worker_or_trainer: Optional[Union[RolloutWorker, LearnerGroup]] = None, to_worker_indices: Optional[List[int]] = None, global_vars: Optional[Dict[str, TensorType]] = None, timeout_seconds: Optional[int] = 0, ) -> None: """Syncs model weights from the given weight source to all remote workers. - Weight source can be either a (local) rollout worker or a trainer runner. It + Weight source can be either a (local) rollout worker or a learner_group. It should just implement a `get_weights` method. Args: policies: Optional list of PolicyIDs to sync weights for. If None (default), sync weights to/from all policies. from_worker_or_trainer: Optional (local) RolloutWorker instance or - TrainerRunner instance to sync from. If None (default), + LearnerGroup instance to sync from. If None (default), sync from this WorkerSet's local worker. to_worker_indices: Optional list of worker indices to sync the weights to. If None (default), sync to all remote workers. 
diff --git a/rllib/examples/rl_trainer/multi_agent_cartpole_ppo.py b/rllib/examples/learner/multi_agent_cartpole_ppo.py similarity index 95% rename from rllib/examples/rl_trainer/multi_agent_cartpole_ppo.py rename to rllib/examples/learner/multi_agent_cartpole_ppo.py index 319b0b432f83..8f62f8beeb55 100644 --- a/rllib/examples/rl_trainer/multi_agent_cartpole_ppo.py +++ b/rllib/examples/learner/multi_agent_cartpole_ppo.py @@ -30,7 +30,7 @@ parser.add_argument("--num-policies", type=int, default=2) parser.add_argument( "--framework", - choices=["tf2", "torch"], # tf will be deprecated with the new RLTrainer stack + choices=["tf2", "torch"], # tf will be deprecated with the new Learner stack default="torch", help="The DL framework specifier.", ) @@ -89,8 +89,8 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): return pol_id scaling_config = { - "num_trainer_workers": args.num_gpus, - "num_gpus_per_trainer_worker": int(args.num_gpus > 0), + "num_learner_workers": args.num_gpus, + "num_gpus_per_learner_worker": int(args.num_gpus > 0), } config = ( @@ -101,7 +101,7 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): .training(num_sgd_iter=10) .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn) .rl_module(_enable_rl_module_api=True) - .training(_enable_rl_trainer_api=True) + .training(_enable_learner_api=True) .resources(**scaling_config) ) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index e89b47afe0b8..aa6f747f20c9 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -763,7 +763,7 @@ def export_model(self, export_dir, onnx: Optional[int] = None) -> None: within this TfModelV2 class that is-a tf.keras.Model. This base model will be used here for the export. TODO (kourosh): This restriction will be resolved once we move Policy and - ModelV2 to the new RLTrainer/RLModule APIs. + ModelV2 to the new Learner/RLModule APIs. Args: export_dir: Local writable directory. diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 193c78abfa94..0c6ab850f6b8 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -1218,7 +1218,7 @@ def _get_num_gpus_for_policy(self) -> int: # if we are in the new rl trainer world num_gpus is deprecated. # so use num_gpus_per_worker for policy sampling # we need this .get() syntax here to ensure backwards compatibility. - if self.config.get("_enable_rl_trainer_api", False): + if self.config.get("_enable_learner_api", False): num_gpus = self.config["num_gpus_per_worker"] else: # If head node, take num_gpus.
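Taken together, the renames above replace the old TrainerRunner setup with the fluent `LearnerGroupConfig` API. The following is a minimal sketch of building and using a `LearnerGroup`, modeled on the renamed `test_learner_group_config.py` and `test_learner.py`; it reuses the BC testing utilities touched in this diff, and the `gymnasium` import is an assumption (use whichever gym import the tests themselves use):

    import gymnasium as gym
    import ray

    from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig
    from ray.rllib.core.testing.tf.bc_learner import BCTfLearner
    from ray.rllib.core.testing.utils import get_module_spec
    from ray.rllib.utils.test_utils import get_cartpole_dataset_reader

    ray.init()
    env = gym.make("CartPole-v1")

    learner_group = (
        LearnerGroupConfig()
        .module(get_module_spec("tf", env))
        .learner(learner_class=BCTfLearner, optimizer_config={"lr": 1e-3})
        # num_learner_workers=0 keeps the Learner in the local process, the
        # "local-cpu" mode exercised by test_learner_group.py; values >= 1
        # create remote Learner actors instead.
        .resources(num_learner_workers=0, num_gpus_per_learner_worker=0)
        .framework(eager_tracing=True)
        .build()
    )

    # One training step: feed a multi-agent batch to the learner group.
    batch = get_cartpole_dataset_reader(batch_size=512).next()
    results = learner_group.update(batch.as_multi_agent())

For end-to-end use, `AlgorithmConfig.get_learner_group_config(...)` returns the same kind of config object, as the renamed `test_learner_group_config.py` shows, and the CartPole PPO example above opts into the new stack with `.rl_module(_enable_rl_module_api=True)` and `.training(_enable_learner_api=True)`.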