From 5eefcba3ef21a3accabfb6ccc5a84f9022d1ade7 Mon Sep 17 00:00:00 2001 From: Peyton Murray Date: Tue, 4 Apr 2023 17:53:15 -0700 Subject: [PATCH] Global logging format changes (#32741) This PR changes how the logging configuration for Ray is set, and changes the format of log messages. After a discussion with @rkooo567 we've decided to split the logging changes into multiple PRs. This is the first in a series which makes changes to library-level logging code for Ray. Signed-off-by: Jack He --- ci/lint/check_api_annotations.py | 8 +- dashboard/agent.py | 18 ++- doc/source/ray-observability/ray-logging.rst | 71 +++++++--- python/ray/__init__.py | 2 + python/ray/_private/log.py | 133 ++++++++++++++++++ python/ray/_private/ray_logging.py | 10 -- python/ray/_private/test_utils.py | 30 ++++ python/ray/_private/worker.py | 2 + python/ray/serve/tests/test_logging.py | 24 ++++ python/ray/tests/test_cli.py | 4 +- python/ray/tests/test_get_or_create_actor.py | 2 +- python/ray/tests/test_logging.py | 72 +++++++++- python/ray/tests/test_multi_node_3.py | 28 ++-- python/ray/tests/test_output.py | 71 ++++++---- .../tests/test_runtime_env_working_dir_2.py | 8 +- python/ray/tune/automl/search_policy.py | 2 +- python/ray/tune/tests/test_commands.py | 66 +++++---- 17 files changed, 442 insertions(+), 109 deletions(-) create mode 100644 python/ray/_private/log.py diff --git a/ci/lint/check_api_annotations.py b/ci/lint/check_api_annotations.py index 2445afc1a72d..4f5ee74b3ce8 100755 --- a/ci/lint/check_api_annotations.py +++ b/ci/lint/check_api_annotations.py @@ -99,7 +99,13 @@ def verify(symbol, scanned, ok, output, prefix=None, ignore=None): verify(ray.air, set(), ok, output) verify(ray.train, set(), ok, output) verify(ray.tune, set(), ok, output) - verify(ray, set(), ok, output, ignore=["ray.workflow", "ray.tune", "ray.serve"]) + verify( + ray, + set(), + ok, + output, + ignore=["ray.workflow", "ray.tune", "ray.serve"], + ) verify(ray.serve, set(), ok, output) assert len(ok) >= 500, 
len(ok) # TODO(ekl) enable it for all modules. diff --git a/dashboard/agent.py b/dashboard/agent.py index 8b01695dd300..1bc58af1ab13 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -5,6 +5,7 @@ import logging import logging.handlers import os +import pathlib import sys import signal @@ -17,7 +18,10 @@ from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher from ray._private.gcs_utils import GcsAioClient, GcsClient -from ray._private.ray_logging import setup_component_logger +from ray._private.ray_logging import ( + setup_component_logger, + configure_log_file, +) from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc from ray.experimental.internal_kv import ( _initialize_internal_kv, @@ -338,6 +342,14 @@ async def _check_parent(): await self.http_server.cleanup() +def open_capture_files(log_dir): + filename = f"agent-{args.agent_id}" + return ( + ray._private.utils.open_log(pathlib.Path(log_dir) / f"{filename}.out"), + ray._private.utils.open_log(pathlib.Path(log_dir) / f"{filename}.err"), + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Dashboard agent.") parser.add_argument( @@ -504,6 +516,10 @@ async def _check_parent(): # w.r.t grpc server init in the DashboardAgent initializer. loop = ray._private.utils.get_or_create_event_loop() + # Setup stdout/stderr redirect files + out_file, err_file = open_capture_files(args.log_dir) + configure_log_file(out_file, err_file) + agent = DashboardAgent( args.node_ip_address, args.dashboard_agent_port, diff --git a/doc/source/ray-observability/ray-logging.rst b/doc/source/ray-observability/ray-logging.rst index 8b236d5a150d..254218348667 100644 --- a/doc/source/ray-observability/ray-logging.rst +++ b/doc/source/ray-observability/ray-logging.rst @@ -2,7 +2,43 @@ Logging ======= -This document will explain Ray's logging system and its best practices. 
+This document explains Ray's logging system and related best practices. + +Internal Ray Logging Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +When ``import ray`` is executed, Ray's logger is initialized, generating a sensible configuration given in ``python/ray/_private/log.py``. The default logging level is ``logging.INFO``. + +All ray loggers are automatically configured in ``ray._private.ray_logging``. To change the Ray library logging configuration: + +.. code-block:: python + + import logging + + logger = logging.getLogger("ray") + logger # Modify the ray logging config + +Similarly, to modify the logging configuration for any Ray subcomponent, specify the appropriate logger name: + +.. code-block:: python + + import logging + + # First, get the handle for the logger you want to modify + ray_data_logger = logging.getLogger("ray.data") + ray_tune_logger = logging.getLogger("ray.tune") + ray_rllib_logger = logging.getLogger("ray.rllib") + ray_air_logger = logging.getLogger("ray.air") + ray_train_logger = logging.getLogger("ray.train") + ray_workflow_logger = logging.getLogger("ray.workflow") + + # Modify the ray.data logging level + ray_data_logger.setLevel(logging.WARNING) + + # Other loggers can be modified similarly. + # Here's how to add an additional file handler for ray tune: + ray_tune_logger.addHandler(logging.FileHandler("extra_ray_tune_log.log")) + +For more information about logging in workers, see :ref:`Customizing worker loggers`. Driver logs ~~~~~~~~~~~ @@ -16,12 +52,12 @@ The log file consists of the stdout of the entrypoint command of the job. For t .. _ray-worker-logs: -Worker logs +Worker stdout and stderr ~~~~~~~~~~~ -Ray's tasks or actors are executed remotely within Ray's worker processes. Ray has special support to improve the visibility of logs produced by workers. 
-- By default, all of the tasks/actors stdout and stderr are redirected to the worker log files. Check out :ref:`Logging directory structure <logging-directory-structure>` to learn how Ray's logging directory is structured. -- By default, all of the tasks/actors stdout and stderr that is redirected to worker log files are published to the driver. Drivers display logs generated from its tasks/actors to its stdout and stderr. +- By default, stdout and stderr from all tasks and actors are redirected to the worker log files, including any log messages generated by the worker. See :ref:`Logging directory structure <logging-directory-structure>` to understand the structure of the Ray logging directory. +- By default, the driver reads the worker log files to which the stdout and stderr for all tasks and actors are redirected. Drivers display all stdout and stderr generated from their tasks or actors to their own stdout and stderr. Let's look at a code example to see how this works. @@ -37,7 +73,7 @@ Let's look at a code example to see how this works. ray.get(task.remote()) -You should be able to see the string `task` from your driver stdout. +You should be able to see the string `task` from your driver stdout. When logs are printed, the process id (pid) and an IP address of the node that executes tasks/actors are printed together. Check out the output below. @@ -129,10 +165,9 @@ Limitations: Tip: To avoid `print` statements from the driver conflicting with tqdm output, use `ray.experimental.tqdm_ray.safe_print` instead. -How to set up loggers +Customizing Worker Loggers ~~~~~~~~~~~~~~~~~~~~~ -When using ray, all of the tasks and actors are executed remotely in Ray's worker processes. -Since Python logger module creates a singleton logger per process, loggers should be configured on per task/actor basis. +When using Ray, all tasks and actors are executed remotely in Ray's worker processes. .. 
note:: @@ -154,7 +189,8 @@ Since Python logger module creates a singleton logger per process, loggers shoul logging.basicConfig(level=logging.INFO) def log(self, msg): - logging.info(msg) + logger = logging.getLogger(__name__) + logger.info(msg) actor = Actor.remote() ray.get(actor.log.remote("A log message for an actor.")) @@ -162,14 +198,15 @@ Since Python logger module creates a singleton logger per process, loggers shoul @ray.remote def f(msg): logging.basicConfig(level=logging.INFO) - logging.info(msg) + logger = logging.getLogger(__name__) + logger.info(msg) - ray.get(f.remote("A log message for a task")) + ray.get(f.remote("A log message for a task.")) .. code-block:: bash - (pid=95193) INFO:root:A log message for a task - (pid=95192) INFO:root:A log message for an actor. + (Actor pid=179641) INFO:__main__:A log message for an actor. + (f pid=177572) INFO:__main__:A log message for a task. How to use structured logging ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -196,11 +233,11 @@ Logging directory structure --------------------------- .. _logging-directory-structure: -By default, Ray logs are stored in a ``/tmp/ray/session_*/logs`` directory. +By default, Ray logs are stored in a ``/tmp/ray/session_*/logs`` directory. .. note:: - The default temp directory is ``/tmp/ray`` (for Linux and Mac OS). If you'd like to change the temp directory, you can specify it when ``ray start`` or ``ray.init()`` is called. + The default temp directory is ``/tmp/ray`` (for Linux and MacOS). To change the temp directory, specify it when you call ``ray start`` or ``ray.init()``. A new Ray instance creates a new session ID to the temp directory. The latest session ID is symlinked to ``/tmp/ray/session_latest``. @@ -225,7 +262,7 @@ Here's a Ray log directory structure. Note that ``.out`` is logs from stdout/std For the logs of the actual installations (including e.g. ``pip install`` logs), see the ``runtime_env_setup-[job_id].log`` file (see below). 
- ``runtime_env_setup-[job_id].log``: Logs from installing :ref:`runtime environments ` for a task, actor or job. This file will only be present if a runtime environment is installed. - ``runtime_env_setup-ray_client_server_[port].log``: Logs from installing :ref:`runtime environments ` for a job when connecting via :ref:`Ray Client `. -- ``worker-[worker_id]-[job_id]-[pid].[out|err]``: Python/Java part of Ray drivers and workers. All of stdout and stderr from tasks/actors are streamed here. Note that job_id is an id of the driver.- +- ``worker-[worker_id]-[job_id]-[pid].[out|err]``: Python or Java part of Ray drivers and workers. All of stdout and stderr from tasks or actors are streamed here. Note that job_id is an id of the driver.- .. _ray-log-rotation: diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 68a7a613fcf7..db059b6183fc 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -1,8 +1,10 @@ # isort: skip_file +from ray._private import log # isort: skip # noqa: F401 import logging import os import sys +log.generate_logging_config() logger = logging.getLogger(__name__) diff --git a/python/ray/_private/log.py b/python/ray/_private/log.py new file mode 100644 index 000000000000..eaa59fd8fc8d --- /dev/null +++ b/python/ray/_private/log.py @@ -0,0 +1,133 @@ +import logging +import re +from logging.config import dictConfig +import threading + + +class ContextFilter(logging.Filter): + """A filter that adds ray context info to log records. + + This filter adds a package name to append to the message as well as information + about what worker emitted the message, if applicable. 
+ """ + + logger_regex = re.compile(r"ray(\.(?P<subpackage>\w+))?(\..*)?") + package_message_names = { + "air": "AIR", + "data": "Data", + "rllib": "RLlib", + "serve": "Serve", + "train": "Train", + "tune": "Tune", + "workflow": "Workflow", + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def filter(self, record: logging.LogRecord) -> bool: + """Add context information to the log record. + + This filter adds a package name from where the message was generated as + well as the worker IP address, if applicable. + + Args: + record: Record to be filtered + + Returns: + True if the record is to be logged, False otherwise. (This filter only + adds context, so records are always logged.) + """ + match = self.logger_regex.search(record.name) + if match and match["subpackage"] in self.package_message_names: + record.package = f"[Ray {self.package_message_names[match['subpackage']]}]" + else: + record.package = "" + + return True + + +class PlainRayHandler(logging.StreamHandler): + """A plain log handler. + + This handler writes to whatever sys.stderr points to at emit-time, + not at instantiation time. See docs for logging._StderrHandler. + """ + + def __init__(self): + super().__init__() + self.plain_handler = logging._StderrHandler() + self.plain_handler.level = self.level + self.plain_handler.formatter = logging.Formatter(fmt="%(message)s") + + def emit(self, record: logging.LogRecord): + """Emit the log message. + + If this is a worker, bypass fancy logging and just emit the log record. + If this is the driver, emit the message using the appropriate console handler. 
+ + Args: + record: Log record to be emitted + """ + import ray + + if ( + hasattr(ray, "_private") + and ray._private.worker.global_worker.mode + == ray._private.worker.WORKER_MODE + ): + self.plain_handler.emit(record) + else: + logging._StderrHandler.emit(self, record) + + +logger_initialized = False +logging_config_lock = threading.Lock() + + +def generate_logging_config(): + """Generate the default Ray logging configuration.""" + with logging_config_lock: + global logger_initialized + if logger_initialized: + return + logger_initialized = True + + formatters = { + "plain": { + "datefmt": "[%Y-%m-%d %H:%M:%S]", + "format": "%(asctime)s %(package)s %(levelname)s %(name)s::%(message)s", + }, + } + filters = {"context_filter": {"()": ContextFilter}} + handlers = { + "default": { + "()": PlainRayHandler, + "formatter": "plain", + "filters": ["context_filter"], + } + } + + loggers = { + # Default ray logger; any log message that gets propagated here will be + # logged to the console + "ray": { + "level": "INFO", + "handlers": ["default"], + }, + # Special handling for ray.rllib: only warning-level messages passed through + # See https://github.com/ray-project/ray/pull/31858 for related PR + "ray.rllib": { + "level": "WARN", + }, + } + + dictConfig( + { + "version": 1, + "formatters": formatters, + "filters": filters, + "handlers": handlers, + "loggers": loggers, + } + ) diff --git a/python/ray/_private/ray_logging.py b/python/ray/_private/ray_logging.py index fb0609a64017..3f022d0bda04 100644 --- a/python/ray/_private/ray_logging.py +++ b/python/ray/_private/ray_logging.py @@ -20,8 +20,6 @@ from ray._private.utils import binary_to_hex from ray.util.debug import log_once -_default_handler = None - def setup_logger( logging_level: int, @@ -32,14 +30,6 @@ def setup_logger( if type(logging_level) is str: logging_level = logging.getLevelName(logging_level.upper()) logger.setLevel(logging_level) - global _default_handler - if _default_handler is None: - _default_handler = 
logging._StderrHandler() - logger.addHandler(_default_handler) - _default_handler.setFormatter(logging.Formatter(logging_format)) - # Setting this will avoid the message - # being propagated to the parent logger. - logger.propagate = False def setup_component_logger( diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 034ae2ea5720..dacaf6381355 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -1842,3 +1842,33 @@ def get_current_unused_port(): port = sock.getsockname()[1] sock.close() return port + + +def search_words(string: str, words: str): + """Check whether each word is in the given string. + + Args: + string: String to search + words: Space-separated string of words to search for + """ + return [word in string for word in words.split(" ")] + + +def has_all_words(string: str, words: str): + """Check that string has all of the given words. + + Args: + string: String to search + words: Space-separated string of words to search for + """ + return all(search_words(string, words)) + + +def has_no_words(string, words): + """Check that string has none of the given words. 
+ + Args: + string: String to search + words: Space-separated string of words to search for + """ + return not any(search_words(string, words)) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 0b899b1afc28..13e2162aa389 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -1271,6 +1271,8 @@ def init( """ if configure_logging: setup_logger(logging_level, logging_format or ray_constants.LOGGER_FORMAT) + else: + logging.getLogger("ray").handlers.clear() # Parse the hidden options: _enable_object_reconstruction: bool = kwargs.pop( diff --git a/python/ray/serve/tests/test_logging.py b/python/ray/serve/tests/test_logging.py index 561dceaa28b2..52e442b035b4 100644 --- a/python/ray/serve/tests/test_logging.py +++ b/python/ray/serve/tests/test_logging.py @@ -156,6 +156,30 @@ def __call__(self, *args): assert replica_tag not in f.getvalue() +def test_deprecated_deployment_logger(serve_instance, capfd): + # NOTE(edoakes): using this logger is no longer recommended as of Ray 1.13. + # The test is maintained for backwards compatibility. 
+ logger = logging.getLogger("ray") + + @serve.deployment(name="counter") + class Counter: + def __init__(self): + self.count = 0 + + def __call__(self, request): + self.count += 1 + logger.info(f"count: {self.count}") + + serve.run(Counter.bind()) + requests.get("http://127.0.0.1:8000/counter/") + + def counter_log_success(): + err = capfd.readouterr().err + return "deployment" in err and "replica" in err and "count" in err + + wait_for_condition(counter_log_success) + + def test_context_information_in_logging(serve_instance): """Make sure all context information exist in the log message""" diff --git a/python/ray/tests/test_cli.py b/python/ray/tests/test_cli.py index 818fb361e41f..959eac6007bf 100644 --- a/python/ray/tests/test_cli.py +++ b/python/ray/tests/test_cli.py @@ -168,6 +168,7 @@ def _die_on_error(result): def _debug_check_line_by_line(result, expected_lines): + """Print the result and expected output line-by-line.""" output_lines = result.output.split("\n") i = 0 @@ -192,10 +193,9 @@ def _debug_check_line_by_line(result, expected_lines): if i < len(expected_lines): print("!!! 
ERROR: Expected extra lines (regex):") for line in expected_lines[i:]: - print(repr(line)) - assert False + assert False, (result.output, expected_lines) @contextmanager diff --git a/python/ray/tests/test_get_or_create_actor.py b/python/ray/tests/test_get_or_create_actor.py index f91055a3e248..17ec8692527b 100644 --- a/python/ray/tests/test_get_or_create_actor.py +++ b/python/ray/tests/test_get_or_create_actor.py @@ -96,7 +96,7 @@ def do_run(name): if "local Ray instance" not in line and "The object store" not in line: out.append(line) valid = "".join(out) - assert valid.strip() == "DONE", out_str + assert "DONE" in valid, out_str if __name__ == "__main__": diff --git a/python/ray/tests/test_logging.py b/python/ray/tests/test_logging.py index 5ef89b6cc761..6796a4d6c940 100644 --- a/python/ray/tests/test_logging.py +++ b/python/ray/tests/test_logging.py @@ -213,7 +213,12 @@ def f(): # Create a runtime env to make sure dashboard agent is alive. ray.get(f.options(runtime_env={"env_vars": {"A": "a", "B": "b"}}).remote()) - paths = list(log_dir_path.iterdir()) + # Filter out only paths that end in .log, .log.1, etc. + # These paths are handled by the logger; the others (.out, .err) are not. + paths = [] + for path in log_dir_path.iterdir(): + if re.search(r".*\.log(\.\d+)?", str(path)): + paths.append(path) def component_exist(component, paths): for path in paths: @@ -380,11 +385,11 @@ def print_after(_obj): assert msgs[0][0] == "done" -def test_log_redirect_to_stderr(shutdown_only, capfd): +def test_log_redirect_to_stderr(shutdown_only): log_components = { ray_constants.PROCESS_TYPE_DASHBOARD: "Dashboard head grpc address", - ray_constants.PROCESS_TYPE_DASHBOARD_AGENT: "Dashboard agent grpc address", + ray_constants.PROCESS_TYPE_DASHBOARD_AGENT: "", ray_constants.PROCESS_TYPE_GCS_SERVER: "Loading job table data", # No log monitor output if all components are writing to stderr. 
ray_constants.PROCESS_TYPE_LOG_MONITOR: "", @@ -438,7 +443,6 @@ def f(): # Make sure that the expected startup log records for each of the # components appears in the stderr stream. - # stderr = capfd.readouterr().err for component, canonical_record in log_components.items(): if not canonical_record: # Process not run or doesn't generate logs; skip. @@ -874,6 +878,66 @@ def test_ray_does_not_break_makeRecord(): logging.setLogRecordFactory(logging.LogRecord) +@pytest.mark.parametrize( + "logger_name,package_name", + ( + ("ray", ""), + ("ray.air", "[Ray AIR]"), + ("ray.data", "[Ray Data]"), + ("ray.rllib", "[Ray RLlib]"), + ("ray.serve", "[Ray Serve]"), + ("ray.train", "[Ray Train]"), + ("ray.tune", "[Ray Tune]"), + ("ray.workflow", "[Ray Workflow]"), + ), +) +def test_log_library_context(logger_name, package_name, caplog): + """Test that the log configuration injects the correct context into log messages.""" + logger = logging.getLogger(logger_name) + logger.critical("Test!") + assert ( + caplog.records[-1].package == package_name + ), "Missing ray package name in log record." + + +@pytest.mark.parametrize( + "logger_name,logger_level", + ( + ("ray", logging.INFO), + ("ray.air", logging.INFO), + ("ray.data", logging.INFO), + ("ray.rllib", logging.WARNING), + ("ray.serve", logging.INFO), + ("ray.train", logging.INFO), + ("ray.tune", logging.INFO), + ("ray.workflow", logging.INFO), + ), +) +@pytest.mark.parametrize( + "test_level", + ( + logging.NOTSET, + logging.DEBUG, + logging.INFO, + logging.WARNING, + logging.ERROR, + logging.CRITICAL, + ), +) +def test_log_level_settings(logger_name, logger_level, test_level, caplog): + """Test that logs of lower level than the ray subpackage is + configured for are rejected. + """ + logger = logging.getLogger(logger_name) + logger.log(test_level, "Test!") + + if test_level >= logger_level: + assert caplog.records, "Log message missing where one is expected." 
+ assert caplog.records[-1].levelno == test_level, "Log message level mismatch." + else: + assert len(caplog.records) == 0, "Log message found where none are expected." + + if __name__ == "__main__": import sys diff --git a/python/ray/tests/test_multi_node_3.py b/python/ray/tests/test_multi_node_3.py index 4c7b03ac8dd7..48634c16ea31 100644 --- a/python/ray/tests/test_multi_node_3.py +++ b/python/ray/tests/test_multi_node_3.py @@ -24,7 +24,6 @@ def test_calling_start_ray_head(call_ray_stop_only): - # Test that we can call ray start with various command line # parameters. @@ -200,7 +199,6 @@ def test_calling_start_ray_head(call_ray_stop_only): def test_ray_start_non_head(call_ray_stop_only, monkeypatch): - # Test that we can call ray start to connect to an existing cluster. # Test starting Ray with a port specified. @@ -433,8 +431,7 @@ def f(): def test_multi_driver_logging(ray_start_regular): - address_info = ray_start_regular - address = address_info["address"] + address = ray_start_regular["address"] # ray.init(address=address) driver1_wait = Semaphore.options(name="driver1_wait").remote(value=0) @@ -479,10 +476,10 @@ def remote_print(s, file=None): """ p1 = run_string_as_driver_nonblocking( - driver_script_template.format(address, "driver1_wait", "1", "2") + driver_script_template.format(address, "driver1_wait", "message1", "message2") ) p2 = run_string_as_driver_nonblocking( - driver_script_template.format(address, "driver2_wait", "3", "4") + driver_script_template.format(address, "driver2_wait", "message3", "message4") ) ray.get(main_wait.acquire.remote()) @@ -492,29 +489,24 @@ def remote_print(s, file=None): ray.get(driver1_wait.release.remote()) ray.get(driver2_wait.release.remote()) - # At this point driver1 should receive '1' and driver2 '3' + # At this point driver1 should receive 'message1' and driver2 'message3' ray.get(main_wait.acquire.remote()) ray.get(main_wait.acquire.remote()) ray.get(driver1_wait.release.remote()) 
ray.get(driver2_wait.release.remote()) - # At this point driver1 should receive '2' and driver2 '4' + # At this point driver1 should receive 'message2' and driver2 'message4' ray.get(main_wait.acquire.remote()) ray.get(main_wait.acquire.remote()) driver1_out = p1.stdout.read().decode("ascii") driver2_out = p2.stdout.read().decode("ascii") - if sys.platform == "win32": - driver1_out = driver1_out.replace("\r", "") - driver2_out = driver2_out.replace("\r", "") - driver1_out_split = driver1_out.split("\n") - driver2_out_split = driver2_out.split("\n") - - assert driver1_out_split[0][-1] == "1", driver1_out_split - assert driver1_out_split[1][-1] == "2", driver1_out_split - assert driver2_out_split[0][-1] == "3", driver2_out_split - assert driver2_out_split[1][-1] == "4", driver2_out_split + + assert "message1" in driver1_out + assert "message2" in driver1_out + assert "message3" in driver2_out + assert "message4" in driver2_out @pytest.fixture diff --git a/python/ray/tests/test_output.py b/python/ray/tests/test_output.py index 6310c3517e00..55eb8c936664 100644 --- a/python/ray/tests/test_output.py +++ b/python/ray/tests/test_output.py @@ -4,6 +4,7 @@ import subprocess import sys import time +import importlib import pytest @@ -11,6 +12,8 @@ from ray._private.test_utils import ( run_string_as_driver, run_string_as_driver_nonblocking, + has_no_words, + has_all_words, ) @@ -36,19 +39,18 @@ def verbose(): assert out_str.count("[repeated 9x across cluster]") == 1 -def test_logger_config(): +def test_logger_config_with_ray_init(): + """Test that the logger is correctly configured when ray.init is called.""" + script = """ import ray ray.init(num_cpus=1) """ - proc = run_string_as_driver_nonblocking(script) - out_str = proc.stdout.read().decode("ascii") - err_str = proc.stderr.read().decode("ascii") - - print(out_str, err_str) - assert "INFO worker.py:" in err_str, err_str + out_str = run_string_as_driver(script) + assert "INFO" in out_str, out_str + assert 
"ray._private.worker" in out_str, out_str @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") @@ -502,17 +504,23 @@ def test_output_local_ray(): """ output = run_string_as_driver(script) lines = output.strip("\n").split("\n") - for line in lines: - print(line) lines = [line for line in lines if "The object store is using /tmp" not in line] - assert len(lines) == 1 - line = lines[0] - print(line) - assert "Started a local Ray instance." in line - if os.environ.get("RAY_MINIMAL") == "1": - assert "View the dashboard" not in line - else: - assert "View the dashboard" in line + assert len(lines) >= 1 + + try: + importlib.import_module("rich") + assert has_all_words(output, "Started a local Ray instance.") + + if os.environ.get("RAY_MINIMAL") == "1": + assert has_no_words(output, "View the dashboard"), output + else: + assert has_all_words(output, "View the dashboard"), output + except ModuleNotFoundError: + assert "Started a local Ray instance." in output + if os.environ.get("RAY_MINIMAL") == "1": + assert "View the dashboard" not in output + else: + assert "View the dashboard" in output def test_output_ray_cluster(call_ray_start): @@ -522,15 +530,26 @@ def test_output_ray_cluster(call_ray_start): """ output = run_string_as_driver(script) lines = output.strip("\n").split("\n") - for line in lines: - print(line) - assert len(lines) == 2 - assert "Connecting to existing Ray cluster at address:" in lines[0] - assert "Connected to Ray cluster." 
in lines[1] - if os.environ.get("RAY_MINIMAL") == "1": - assert "View the dashboard" not in lines[1] - else: - assert "View the dashboard" in lines[1] + assert len(lines) >= 1 + + try: + importlib.import_module("rich") + assert has_all_words( + output, "Connecting to existing Ray cluster at address:" + ), output + assert has_all_words(output, "Connected to Ray cluster."), output + + if os.environ.get("RAY_MINIMAL") == "1": + assert has_no_words(output, "View the dashboard"), output + else: + assert has_all_words(output, "View the dashboard"), output + except ModuleNotFoundError: + assert "Connecting to existing Ray cluster at address:" in output + assert "Connected to Ray cluster." in output + if os.environ.get("RAY_MINIMAL") == "1": + assert "View the dashboard" not in output + else: + assert "View the dashboard" in output @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") diff --git a/python/ray/tests/test_runtime_env_working_dir_2.py b/python/ray/tests/test_runtime_env_working_dir_2.py index ae4dd5de13e5..4afd28672506 100644 --- a/python/ray/tests/test_runtime_env_working_dir_2.py +++ b/python/ray/tests/test_runtime_env_working_dir_2.py @@ -4,10 +4,13 @@ import tempfile import pytest -from ray._private.test_utils import run_string_as_driver +from ray._private.test_utils import ( + chdir, + run_string_as_driver, +) + import ray -from ray._private.test_utils import chdir from ray._private.runtime_env import RAY_WORKER_DEV_EXCLUDES from ray._private.runtime_env.packaging import GCS_STORAGE_MAX_SIZE from ray.exceptions import RuntimeEnvSetupError @@ -147,7 +150,6 @@ def test_large_dir_upload_message(start_cluster, option): output = run_string_as_driver(driver_script) assert "Pushing file package" in output assert "Successfully pushed file package" in output - assert "warning" not in output.lower() # TODO(architkulkarni): Deflake and reenable this test. 
diff --git a/python/ray/tune/automl/search_policy.py b/python/ray/tune/automl/search_policy.py index deb46da332ef..497235980e6b 100644 --- a/python/ray/tune/automl/search_policy.py +++ b/python/ray/tune/automl/search_policy.py @@ -181,7 +181,7 @@ def on_trial_complete(self, trial_id, result=None, error=False): "reward_attr": self.reward_attr, "reward": self.best_trial.best_result[self.reward_attr] if self.best_trial - else None, + else 0, }, ) diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index efe8c6fc0077..70a9d6dea47d 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -4,6 +4,7 @@ import subprocess import sys import time +from unittest import mock try: from cStringIO import StringIO @@ -61,7 +62,11 @@ def test_time(start_ray, tmpdir): assert sum(times) / len(times) < 7.0, "CLI is taking too long!" -def test_ls(start_ray, tmpdir): +@mock.patch( + "ray.tune.cli.commands.print_format_output", + wraps=ray.tune.cli.commands.print_format_output, +) +def test_ls(mock_print_format_output, start_ray, tmpdir): """This test captures output of list_trials.""" experiment_name = "test_ls" experiment_path = os.path.join(str(tmpdir), experiment_name) @@ -76,23 +81,26 @@ def test_ls(start_ray, tmpdir): columns = ["episode_reward_mean", "training_iteration", "trial_id"] limit = 2 - with Capturing() as output: - commands.list_trials(experiment_path, info_keys=columns, limit=limit) - lines = output.captured - - assert all(col in lines[1] for col in columns) - assert lines[1].count("|") == len(columns) + 1 - assert len(lines) == 3 + limit + 1 - - with Capturing() as output: - commands.list_trials( - experiment_path, - sort=["trial_id"], - info_keys=("trial_id", "training_iteration"), - filter_op="training_iteration == 1", - ) - lines = output.captured - assert len(lines) == 3 + num_samples + 1 + commands.list_trials(experiment_path, info_keys=columns, limit=limit) + + # The dataframe that 
is printed as a table is the first arg of the last + # call made to `ray.tune.cli.commands.print_format_output`. + mock_print_format_output.assert_called() + args, _ = mock_print_format_output.call_args_list[-1] + df = args[0] + assert sorted(df.columns.to_list()) == sorted(columns), df + assert len(df.index) == limit, df + + commands.list_trials( + experiment_path, + sort=["trial_id"], + info_keys=("trial_id", "training_iteration"), + filter_op="training_iteration == 1", + ) + args, _ = mock_print_format_output.call_args_list[-1] + df = args[0] + assert sorted(df.columns.to_list()) == sorted(["trial_id", "training_iteration"]) + assert len(df.index) == num_samples with pytest.raises(click.ClickException): commands.list_trials( @@ -103,7 +111,11 @@ def test_ls(start_ray, tmpdir): commands.list_trials(experiment_path, info_keys=("asdf",)) -def test_ls_with_cfg(start_ray, tmpdir): +@mock.patch( + "ray.tune.cli.commands.print_format_output", + wraps=ray.tune.cli.commands.print_format_output, +) +def test_ls_with_cfg(mock_print_format_output, start_ray, tmpdir): experiment_name = "test_ls_with_cfg" experiment_path = os.path.join(str(tmpdir), experiment_name) tune.run( @@ -116,12 +128,16 @@ def test_ls_with_cfg(start_ray, tmpdir): columns = [CONFIG_PREFIX + "/test_variable", "trial_id"] limit = 4 - with Capturing() as output: - commands.list_trials(experiment_path, info_keys=columns, limit=limit) - lines = output.captured - assert all(col in lines[1] for col in columns) - assert lines[1].count("|") == len(columns) + 1 - assert len(lines) == 3 + limit + 1 + + commands.list_trials(experiment_path, info_keys=columns, limit=limit) + + # The dataframe that is printed as a table is the first arg of the last + # call made to `ray.tune.cli.commands.print_format_output`. 
+ mock_print_format_output.assert_called() + args, _ = mock_print_format_output.call_args_list[-1] + df = args[0] + assert sorted(df.columns.to_list()) == sorted(columns), df + assert len(df.index) == limit, df def test_lsx(start_ray, tmpdir):