From 5b3dd988ad7f0f91a2f61f77a8b3380eae3131fd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jan 2024 21:31:25 +0000 Subject: [PATCH] [Feature] Remove and check for prints in codebase using flake8-print (#1758) --- .../unittest/helpers/coverage_run_parallel.py | 6 +- .pre-commit-config.yaml | 1 + benchmarks/benchmark_batched_envs.py | 9 ++- benchmarks/conftest.py | 4 +- benchmarks/ecosystem/gym_env_throughput.py | 1 - ...s_rllib_vs_torchrl_sampling_performance.py | 7 +- .../benchmark_sample_latency_over_rpc.py | 11 +-- examples/a2c/a2c_atari.py | 3 +- examples/a2c/a2c_mujoco.py | 3 +- examples/cql/cql_offline.py | 4 +- examples/cql/cql_online.py | 4 +- examples/cql/discrete_cql_online.py | 4 +- examples/ddpg/ddpg.py | 4 +- examples/decision_transformer/dt.py | 5 +- examples/decision_transformer/online_dt.py | 6 +- examples/discrete_sac/discrete_sac.py | 4 +- .../collectors/multi_nodes/delayed_dist.py | 3 +- .../collectors/multi_nodes/delayed_rpc.py | 3 +- .../collectors/multi_nodes/generic.py | 3 +- .../distributed/collectors/multi_nodes/ray.py | 3 +- .../collectors/multi_nodes/ray_train.py | 4 +- .../distributed/collectors/multi_nodes/rpc.py | 3 +- .../collectors/multi_nodes/sync.py | 3 +- .../collectors/single_machine/generic.py | 3 +- .../collectors/single_machine/rpc.py | 3 +- .../collectors/single_machine/sync.py | 3 +- .../distributed_replay_buffer.py | 33 ++++----- examples/dqn/dqn_atari.py | 4 +- examples/dqn/dqn_cartpole.py | 4 +- examples/dreamer/dreamer.py | 7 +- examples/impala/impala_multi_node_ray.py | 4 +- examples/impala/impala_multi_node_submitit.py | 4 +- examples/impala/impala_single_node.py | 4 +- examples/iql/iql_offline.py | 3 +- examples/iql/iql_online.py | 4 +- examples/multiagent/iql.py | 4 +- examples/multiagent/maddpg_iddpg.py | 4 +- examples/multiagent/mappo_ippo.py | 5 +- examples/multiagent/qmix_vdn.py | 4 +- examples/multiagent/sac.py | 4 +- examples/ppo/ppo_atari.py | 3 +- examples/ppo/ppo_mujoco.py | 4 +- examples/redq/utils.py | 5 +- examples/rlhf/models/reward.py | 3 +- examples/rlhf/models/transformer.py | 5 +- examples/rlhf/train.py | 7 +- examples/rlhf/train_reward.py | 7 +- examples/sac/sac.py | 4 +- examples/td3/td3.py | 4 +- setup.cfg | 5 ++ setup.py | 13 ++-- test/_utils_internal.py | 3 +- test/assets/generate.py | 1 - test/conftest.py | 1 - test/opengl_rendering.py | 4 +- test/test_collector.py | 3 +- test/test_distributed.py | 5 +- test/test_env.py | 3 - test/test_libs.py | 24 +++---- test/test_modules.py | 1 - test/test_rb_distributed.py | 3 +- test/test_shared.py | 19 ++--- torchrl/_utils.py | 13 ++-- torchrl/collectors/collectors.py | 16 ++--- torchrl/collectors/distributed/generic.py | 69 ++++++++++--------- torchrl/collectors/distributed/ray.py | 2 +- torchrl/collectors/distributed/rpc.py | 37 +++++----- torchrl/collectors/distributed/sync.py | 25 ++++--- torchrl/collectors/distributed/utils.py | 7 +- torchrl/data/datasets/d4rl.py | 8 ++- torchrl/data/datasets/minari_data.py | 13 ++-- torchrl/data/datasets/roboset.py | 21 +++--- torchrl/data/datasets/vd4rl.py | 5 +- torchrl/data/replay_buffers/storages.py | 9 +-- torchrl/data/rlhf/dataset.py | 3 +- torchrl/envs/batched_envs.py | 14 ++-- torchrl/envs/env_creator.py | 3 +- torchrl/envs/libs/dm_control.py | 3 +- torchrl/envs/libs/envpool.py | 2 +- torchrl/envs/transforms/vc1.py | 4 +- torchrl/envs/utils.py | 3 +- torchrl/record/loggers/csv.py | 1 - torchrl/trainers/helpers/envs.py | 5 +- torchrl/trainers/helpers/trainers.py | 3 +- torchrl/trainers/trainers.py | 3 +- 85 files changed, 327 insertions(+), 264 deletions(-) diff --git a/.github/unittest/helpers/coverage_run_parallel.py b/.github/unittest/helpers/coverage_run_parallel.py index 001eb6299b6..ca156b72c2d 100644 --- a/.github/unittest/helpers/coverage_run_parallel.py +++ b/.github/unittest/helpers/coverage_run_parallel.py @@ -11,7 +11,7 @@ nevertheless. It writes temporary coverage config files on the fly and invokes coverage with proper arguments """ - +import logging import os import shlex import subprocess @@ -45,7 +45,9 @@ def write_config(config_path: Path, argv: List[str]) -> None: def main(argv: List[str]) -> int: if len(argv) < 1: - print("Usage: 'python coverage_run_parallel.py [command arguments]'") + logging.info( + "Usage: 'python coverage_run_parallel.py [command arguments]'" + ) sys.exit(1) # The temporary config is written into a temp dir that will be deleted # including all contents on context exit. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 971ea8516dc..532445125aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,7 @@ repos: - flake8-bugbear==22.10.27 - flake8-comprehensions==3.10.1 - torchfix==0.0.2 + - flake8-print==5.0.0 - repo: https://github.com/PyCQA/pydocstyle rev: 6.1.1 diff --git a/benchmarks/benchmark_batched_envs.py b/benchmarks/benchmark_batched_envs.py index bc0cd57003a..3c21372a369 100644 --- a/benchmarks/benchmark_batched_envs.py +++ b/benchmarks/benchmark_batched_envs.py @@ -68,8 +68,8 @@ def run_env(env): devices.append("cuda") for device in devices: for num_workers in [1, 4, 16]: - print(f"With num_workers={num_workers}, {device}") - print("Multithreaded...") + logging.info(f"With num_workers={num_workers}, {device}") + logging.info("Multithreaded...") env_multithreaded = create_multithreaded(num_workers, device) res_multithreaded = Timer( stmt="run_env(env)", @@ -78,7 +78,7 @@ def run_env(env): ) time_multithreaded = res_multithreaded.blocked_autorange().mean - print("Serial...") + logging.info("Serial...") env_serial = create_serial(num_workers, device) res_serial = Timer( stmt="run_env(env)", @@ -87,7 +87,7 @@ def run_env(env): ) time_serial = res_serial.blocked_autorange().mean - print("Parallel...") + logging.info("Parallel...") env_parallel = create_parallel(num_workers, device) res_parallel = Timer( stmt="run_env(env)", @@ -96,7 +96,6 @@ def run_env(env): ) time_parallel = res_parallel.blocked_autorange().mean - print(time_serial, time_parallel, time_multithreaded) res[f"num_workers_{num_workers}_{device}"] = { "Serial, s": time_serial, "Parallel, s": time_parallel, diff --git a/benchmarks/conftest.py b/benchmarks/conftest.py index 7f320ff2e8d..bec558ac92d 100644 --- a/benchmarks/conftest.py +++ b/benchmarks/conftest.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import logging import os import time import warnings @@ -32,7 +32,7 @@ def pytest_sessionfinish(maxprint=50): out_str += f"\t{key}{spaces}{item: 4.4f}s\n" if i == maxprint - 1: break - print(out_str) + logging.info(out_str) @pytest.fixture(autouse=True) diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 246c5ee15f0..c69fc985ded 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -63,7 +63,6 @@ def make(envname=envname, gym_backend=gym_backend): global_step = 0 times = [] start = time.time() - print("Timer started.") for _ in tqdm.tqdm(range(total_frames // num_workers)): env.step(env.action_space.sample()) global_step += num_workers diff --git a/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py b/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py index daaf800353f..02526095a60 100644 --- a/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +++ b/benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. +import logging import os import pickle @@ -164,11 +165,11 @@ def run_comparison_torchrl_rllib( evaluation = {} for framework in ["TorchRL", "RLlib"]: if framework not in evaluation.keys(): - print(f"\nFramework {framework}") + logging.info(f"\nFramework {framework}") vmas_times = [] for n_envs in list_n_envs: n_envs = int(n_envs) - print(f"Running {n_envs} environments") + logging.info(f"Running {n_envs} environments") if framework == "TorchRL": vmas_times.append( (n_envs * n_steps) @@ -189,7 +190,7 @@ def run_comparison_torchrl_rllib( device=device, ) ) - print(f"fps {vmas_times[-1]}s") + logging.info(f"fps {vmas_times[-1]}s") evaluation[framework] = vmas_times store_pickled_evaluation(name=figure_name_pkl, evaluation=evaluation) diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index be1055e8b8a..693cbb9a462 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -14,6 +14,7 @@ This code is based on examples/distributed/distributed_replay_buffer.py. """ import argparse +import logging import os import pickle import sys @@ -105,10 +106,10 @@ def _create_replay_buffer(self) -> rpc.RRef: buffer_rref = rpc.remote( replay_buffer_info, ReplayBufferNode, args=(1000000,) ) - print(f"Connected to replay buffer {replay_buffer_info}") + logging.info(f"Connected to replay buffer {replay_buffer_info}") return buffer_rref except Exception: - print("Failed to connect to replay buffer") + logging.info("Failed to connect to replay buffer") time.sleep(RETRY_DELAY_SECS) @@ -143,7 +144,7 @@ def __init__(self, capacity: int): rank = args.rank storage_type = args.storage - print(f"Rank: {rank}; Storage: {storage_type}") + logging.info(f"Rank: {rank}; Storage: {storage_type}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" @@ -166,7 +167,7 @@ def __init__(self, capacity: int): if i == 0: continue results.append(result) - print(i, results[-1]) + logging.info(i, results[-1]) with open( f'./benchmark_{datetime.now().strftime("%d-%m-%Y%H:%M:%S")};batch_size={BATCH_SIZE};tensor_size={TENSOR_SIZE};repeat={REPEATS};storage={storage_type}.pkl', @@ -175,7 +176,7 @@ def __init__(self, capacity: int): pickle.dump(results, f) tensor_results = torch.tensor(results) - print(f"Mean: {torch.mean(tensor_results)}") + logging.info(f"Mean: {torch.mean(tensor_results)}") breakpoint() elif rank == 1: # rank 1 is the replay buffer diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py index 4598c11844b..8d19080f223 100644 --- a/examples/a2c/a2c_atari.py +++ b/examples/a2c/a2c_atari.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import hydra @@ -212,7 +213,7 @@ def main(cfg: "DictConfig"): # noqa: F821 end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py index 7f9e588bbf6..4076631f1ef 100644 --- a/examples/a2c/a2c_mujoco.py +++ b/examples/a2c/a2c_mujoco.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import hydra @@ -197,7 +198,7 @@ def main(cfg: "DictConfig"): # noqa: F821 end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/cql/cql_offline.py b/examples/cql/cql_offline.py index 87739763fa2..c33bce7d65b 100644 --- a/examples/cql/cql_offline.py +++ b/examples/cql/cql_offline.py @@ -9,7 +9,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -145,7 +145,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - print(f"Training time: {time.time() - start_time}") + logging.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/cql/cql_online.py b/examples/cql/cql_online.py index 93427a0d8cf..4ee218da770 100644 --- a/examples/cql/cql_online.py +++ b/examples/cql/cql_online.py @@ -11,7 +11,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -206,7 +206,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") collector.shutdown() diff --git a/examples/cql/discrete_cql_online.py b/examples/cql/discrete_cql_online.py index 0c93875ec9c..cc4f89d667e 100644 --- a/examples/cql/discrete_cql_online.py +++ b/examples/cql/discrete_cql_online.py @@ -10,7 +10,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -192,7 +192,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index 65e1919567c..1eb7af83e02 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -10,7 +10,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -192,7 +192,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/decision_transformer/dt.py b/examples/decision_transformer/dt.py index 8cd56f692bf..894562185d9 100644 --- a/examples/decision_transformer/dt.py +++ b/examples/decision_transformer/dt.py @@ -6,6 +6,7 @@ This is a self-contained example of an offline Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ +import logging import time import hydra @@ -78,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pretrain_log_interval = cfg.logger.pretrain_log_interval reward_scaling = cfg.env.reward_scaling - print(" ***Pretraining*** ") + logging.info(" ***Pretraining*** ") # Pretraining start_time = time.time() for i in range(pretrain_gradient_steps): @@ -115,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - print(f"Training time: {time.time() - start_time}") + logging.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/decision_transformer/online_dt.py b/examples/decision_transformer/online_dt.py index 2fdb2f74cbf..a1df18e5fe6 100644 --- a/examples/decision_transformer/online_dt.py +++ b/examples/decision_transformer/online_dt.py @@ -6,7 +6,7 @@ This is a self-contained example of an Online Decision Transformer training script. The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -81,7 +81,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pretrain_log_interval = cfg.logger.pretrain_log_interval reward_scaling = cfg.env.reward_scaling - print(" ***Pretraining*** ") + logging.info(" ***Pretraining*** ") # Pretraining start_time = time.time() for i in range(pretrain_gradient_steps): @@ -132,7 +132,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - print(f"Training time: {time.time() - start_time}") + logging.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py index 1ff922f41fb..1f052837b2d 100644 --- a/examples/discrete_sac/discrete_sac.py +++ b/examples/discrete_sac/discrete_sac.py @@ -10,7 +10,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -208,7 +208,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index b0fd091e3c0..e026912f698 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -23,6 +23,7 @@ and DEFAULT_SLURM_CONF_MAIN dictionaries below). """ +import logging import time from argparse import ArgumentParser @@ -149,7 +150,7 @@ def make_env(): if i == 10: t0 = time.time() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") collector.shutdown() exit() diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index 7cba1eeef05..0f38d898dfc 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -23,6 +23,7 @@ and DEFAULT_SLURM_CONF_MAIN dictionaries below). """ +import logging import time from argparse import ArgumentParser @@ -147,7 +148,7 @@ def make_env(): if i == 10: t0 = time.time() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") collector.shutdown() exit() diff --git a/examples/distributed/collectors/multi_nodes/generic.py b/examples/distributed/collectors/multi_nodes/generic.py index aa27059a214..07c83ba98fb 100644 --- a/examples/distributed/collectors/multi_nodes/generic.py +++ b/examples/distributed/collectors/multi_nodes/generic.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import time from argparse import ArgumentParser @@ -127,5 +128,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/multi_nodes/ray.py b/examples/distributed/collectors/multi_nodes/ray.py index 588fba55038..21d550281a2 100644 --- a/examples/distributed/collectors/multi_nodes/ray.py +++ b/examples/distributed/collectors/multi_nodes/ray.py @@ -7,6 +7,7 @@ This example should create 3 collector instances, 1 local and 2 remote, but 4 instances seem to be created. Why? """ +import logging from tensordict.nn import TensorDictModule from torch import nn @@ -44,4 +45,4 @@ def env_maker(): for batch in distributed_collector: counter += 1 num_frames += batch.shape.numel() - print(f"batch {counter}, total frames {num_frames}") + logging.info(f"batch {counter}, total frames {num_frames}") diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 360c6daac28..955d97113fe 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -5,7 +5,7 @@ This script reproduces the PPO example in https://pytorch.org/rl/tutorials/coding_ppo.html with a RayCollector. """ - +import logging from collections import defaultdict import matplotlib.pyplot as plt @@ -235,4 +235,4 @@ plt.title("Max step count (test)") save_name = "/tmp/results.jpg" plt.savefig(save_name) - print(f"results saved in {save_name}") + logging.info(f"results saved in {save_name}") diff --git a/examples/distributed/collectors/multi_nodes/rpc.py b/examples/distributed/collectors/multi_nodes/rpc.py index b88d8cb5704..2fdbdc47a4c 100644 --- a/examples/distributed/collectors/multi_nodes/rpc.py +++ b/examples/distributed/collectors/multi_nodes/rpc.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import time from argparse import ArgumentParser @@ -115,5 +116,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/examples/distributed/collectors/multi_nodes/sync.py index d0ef0b3c054..65b93beb294 100644 --- a/examples/distributed/collectors/multi_nodes/sync.py +++ b/examples/distributed/collectors/multi_nodes/sync.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import time from argparse import ArgumentParser @@ -121,5 +122,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index c20e5fb436d..9c1fd9976f0 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -17,6 +17,7 @@ `--env` flag. Any available gym env will work. """ +import logging import time from argparse import ArgumentParser @@ -155,5 +156,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 0a47d8014a3..7de1cf5aad0 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -17,6 +17,7 @@ `--env` flag. Any available gym env will work. """ +import logging import time from argparse import ArgumentParser @@ -123,5 +124,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index d07295302fe..7f3d62efa45 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -18,6 +18,7 @@ `--env` flag. Any available gym env will work. """ +import logging import time from argparse import ArgumentParser @@ -147,5 +148,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + logging.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") exit() diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index 843cdfa5b5c..64f4627e2e5 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -8,6 +8,7 @@ """ import argparse +import logging import os import random import sys @@ -50,7 +51,7 @@ class DummyDataCollectorNode: def __init__(self, replay_buffer: rpc.RRef) -> None: self.id = rpc.get_worker_info().id self.replay_buffer = replay_buffer - print("Data Collector Node constructed") + logging.info("Data Collector Node constructed") def _submit_random_item_async(self) -> rpc.RRef: td = TensorDict({"a": torch.randint(100, (1,))}, []) @@ -68,7 +69,7 @@ def collect(self): """Method that begins experience collection (we just generate random TensorDicts in this example). `accept_remote_rref_invocation` enables this method to be invoked remotely provided the class instantiation `rpc.RRef` is provided in place of the object reference.""" for elem in range(50): time.sleep(random.randint(1, 4)) - print( + logging.info( f"Collector [{self.id}] submission {elem}: {self._submit_random_item_async().to_here()}" ) @@ -77,22 +78,22 @@ class DummyTrainerNode: """Trainer node responsible for learning from experiences sampled from an experience replay buffer.""" def __init__(self) -> None: - print("DummyTrainerNode") + logging.info("DummyTrainerNode") self.id = rpc.get_worker_info().id self.replay_buffer = self._create_replay_buffer() self._create_and_launch_data_collectors() def train(self, iterations: int) -> None: for iteration in range(iterations): - print(f"[{self.id}] Training Iteration: {iteration}") + logging.info(f"[{self.id}] Training Iteration: {iteration}") time.sleep(3) batch = rpc.rpc_sync( self.replay_buffer.owner(), ReplayBufferNode.sample, args=(self.replay_buffer, 16), ) - print(f"[{self.id}] Sample Obtained Iteration: {iteration}") - print(f"{batch}") + logging.info(f"[{self.id}] Sample Obtained Iteration: {iteration}") + logging.info(f"{batch}") def _create_replay_buffer(self) -> rpc.RRef: while True: @@ -101,10 +102,10 @@ def _create_replay_buffer(self) -> rpc.RRef: buffer_rref = rpc.remote( replay_buffer_info, ReplayBufferNode, args=(10000,) ) - print(f"Connected to replay buffer {replay_buffer_info}") + logging.info(f"Connected to replay buffer {replay_buffer_info}") return buffer_rref except Exception as e: - print(f"Failed to connect to replay buffer: {e}") + logging.info(f"Failed to connect to replay buffer: {e}") time.sleep(RETRY_DELAY_SECS) def _create_and_launch_data_collectors(self) -> None: @@ -118,7 +119,7 @@ def _create_and_launch_data_collectors(self) -> None: data_collector_info = rpc.get_worker_info( f"DataCollector{data_collector_number}" ) - print(f"Data collector info: {data_collector_info}") + logging.info(f"Data collector info: {data_collector_info}") dc_ref = rpc.remote( data_collector_info, DummyDataCollectorNode, @@ -130,11 +131,11 @@ def _create_and_launch_data_collectors(self) -> None: retries = 0 except Exception: retries += 1 - print( + logging.info( f"Failed to connect to DataCollector{data_collector_number} with {retries} retries" ) if retries >= RETRY_LIMIT: - print(f"{len(data_collectors)} data collectors") + logging.info(f"{len(data_collectors)} data collectors") for data_collector_info, data_collector in zip( data_collector_infos, data_collectors ): @@ -170,7 +171,7 @@ def __init__(self, capacity: int): if __name__ == "__main__": args = parser.parse_args() rank = args.rank - print(f"Rank: {rank}") + logging.info(f"Rank: {rank}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" @@ -187,21 +188,21 @@ def __init__(self, capacity: int): backend=rpc.BackendType.TENSORPIPE, rpc_backend_options=options, ) - print(f"Initialised Trainer Node {rank}") + logging.info(f"Initialised Trainer Node {rank}") trainer = DummyTrainerNode() trainer.train(100) breakpoint() elif rank == 1: # rank 1 is the replay buffer # replay buffer waits passively for construction instructions from trainer node - print(REPLAY_BUFFER_NODE) + logging.info(REPLAY_BUFFER_NODE) rpc.init_rpc( REPLAY_BUFFER_NODE, rank=rank, backend=rpc.BackendType.TENSORPIPE, rpc_backend_options=options, ) - print(f"Initialised RB Node {rank}") + logging.info(f"Initialised RB Node {rank}") breakpoint() elif rank >= 2: # rank 2+ is a new data collector node @@ -212,7 +213,7 @@ def __init__(self, capacity: int): backend=rpc.BackendType.TENSORPIPE, rpc_backend_options=options, ) - print(f"Initialised DC Node {rank}") + logging.info(f"Initialised DC Node {rank}") breakpoint() else: sys.exit(1) diff --git a/examples/dqn/dqn_atari.py b/examples/dqn/dqn_atari.py index 3481e7c9671..ecfbfa9deab 100644 --- a/examples/dqn/dqn_atari.py +++ b/examples/dqn/dqn_atari.py @@ -7,7 +7,7 @@ DQN: Reproducing experimental results from Mnih et al. 2015 for the Deep Q-Learning Algorithm on Atari Environments. """ - +import logging import tempfile import time @@ -208,7 +208,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/dqn/dqn_cartpole.py b/examples/dqn/dqn_cartpole.py index b0629cbb364..792b1f65477 100644 --- a/examples/dqn/dqn_cartpole.py +++ b/examples/dqn/dqn_cartpole.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import logging import time import hydra @@ -187,7 +187,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 6453edcf7e5..8c1e9da2e46 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -1,4 +1,5 @@ import dataclasses +import logging from pathlib import Path import hydra @@ -83,7 +84,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.model_device) else: device = torch.device("cpu") - print(f"Using device {device}") + logging.info(f"Using device {device}") exp_name = generate_exp_name("Dreamer", cfg.exp_name) logger = get_logger( @@ -184,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_model_explore=exploration_policy, cfg=cfg, ) - print("collector:", collector) + logging.info("collector:", collector) replay_buffer = make_replay_buffer("cpu", cfg) @@ -204,7 +205,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) final_seed = collector.set_seed(cfg.seed) - print(f"init seed: {cfg.seed}, final seed: {final_seed}") + logging.info(f"init seed: {cfg.seed}, final seed: {final_seed}") # Training loop collected_frames = 0 pbar = tqdm.tqdm(total=cfg.total_frames) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index a0d2d88c5a2..46941529c00 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -7,6 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ +import logging + import hydra @@ -271,7 +273,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 3355febbfaf..7eef42ec98f 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -7,6 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ +import logging + import hydra @@ -263,7 +265,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index cd270f4c9e9..9a853e9bc76 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -7,6 +7,8 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ +import logging + import hydra @@ -241,7 +243,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/iql/iql_offline.py b/examples/iql/iql_offline.py index f6612048318..b5df32d7f2d 100644 --- a/examples/iql/iql_offline.py +++ b/examples/iql/iql_offline.py @@ -9,6 +9,7 @@ The helper functions are coded in the utils.py associated with this script. """ +import logging import time import hydra @@ -110,7 +111,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() - print(f"Training time: {time.time() - start_time}") + logging.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py index 290c6c2a8de..8dd7c0fdd07 100644 --- a/examples/iql/iql_online.py +++ b/examples/iql/iql_online.py @@ -11,7 +11,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 351f5c3730e..4af5da62c91 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import logging import time import hydra @@ -144,7 +144,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - print(f"\nIteration {i}") + logging.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/maddpg_iddpg.py b/examples/multiagent/maddpg_iddpg.py index 9301f8a63f2..4e6b821604c 100644 --- a/examples/multiagent/maddpg_iddpg.py +++ b/examples/multiagent/maddpg_iddpg.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import logging import time import hydra @@ -170,7 +170,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - print(f"\nIteration {i}") + logging.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index c2e46174e92..95d340046fa 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -2,8 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - - +import logging import time import hydra @@ -168,7 +167,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - print(f"\nIteration {i}") + logging.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index 222e0434db2..5822bda39da 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import logging import time import hydra @@ -168,7 +168,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - print(f"\nIteration {i}") + logging.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/multiagent/sac.py b/examples/multiagent/sac.py index fb184291c90..1c01b5e50b7 100644 --- a/examples/multiagent/sac.py +++ b/examples/multiagent/sac.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import logging import time import hydra @@ -237,7 +237,7 @@ def train(cfg: "DictConfig"): # noqa: F821 total_frames = 0 sampling_start = time.time() for i, tensordict_data in enumerate(collector): - print(f"\nIteration {i}") + logging.info(f"\nIteration {i}") sampling_time = time.time() - sampling_start diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py index 1bfbccdeba4..86685fa2642 100644 --- a/examples/ppo/ppo_atari.py +++ b/examples/ppo/ppo_atari.py @@ -7,6 +7,7 @@ This script reproduces the Proximal Policy Optimization (PPO) Algorithm results from Schulman et al. 2017 for the on Atari Environments. """ +import logging import hydra @@ -232,7 +233,7 @@ def main(cfg: "DictConfig"): # noqa: F821 end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py index 52b12f688e1..eca985c2069 100644 --- a/examples/ppo/ppo_mujoco.py +++ b/examples/ppo/ppo_mujoco.py @@ -7,6 +7,8 @@ This script reproduces the Proximal Policy Optimization (PPO) Algorithm results from Schulman et al. 2017 for the on MuJoCo Environments. """ +import logging + import hydra @@ -221,7 +223,7 @@ def main(cfg: "DictConfig"): # noqa: F821 end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/redq/utils.py b/examples/redq/utils.py index 076d3bf75b3..76ddf4ad302 100644 --- a/examples/redq/utils.py +++ b/examples/redq/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import logging from copy import copy from typing import Callable, Dict, Optional, Sequence, Tuple, Union @@ -216,7 +217,7 @@ def make_trainer( >>> logger = TensorboardLogger(exp_name=dir) >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration, ... replay_buffer, logger) - >>> print(trainer) + >>> logging.info(trainer) """ @@ -243,7 +244,7 @@ def make_trainer( raise NotImplementedError(f"lr scheduler {cfg.optim.lr_scheduler}") if VERBOSE: - print( + logging.info( f"collector = {collector}; \n" f"loss_module = {loss_module}; \n" f"recorder = {recorder}; \n" diff --git a/examples/rlhf/models/reward.py b/examples/rlhf/models/reward.py index da69e74ab4d..c11f1c02244 100644 --- a/examples/rlhf/models/reward.py +++ b/examples/rlhf/models/reward.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import warnings import torch @@ -30,7 +31,7 @@ def init_reward_model( model.to(device) if compile_model: - print("Compiling the reward model...") + logging.info("Compiling the reward model...") model = torch.compile(model) model = TensorDictModule( diff --git a/examples/rlhf/models/transformer.py b/examples/rlhf/models/transformer.py index a33891a86a5..d1c2b02d0a9 100644 --- a/examples/rlhf/models/transformer.py +++ b/examples/rlhf/models/transformer.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging + import torch from tensordict.nn import TensorDictModule from transformers import GPT2LMHeadModel @@ -27,8 +29,7 @@ def init_transformer( model.to(device) if compile_model: - # TODO: logging instead of printing? - print("Compiling transformer model...") + logging.info("Compiling transformer model...") model = torch.compile(model) if as_tensordictmodule: diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py index 6d9e758503d..f5551e47579 100644 --- a/examples/rlhf/train.py +++ b/examples/rlhf/train.py @@ -9,6 +9,7 @@ To run on a single GPU, example: $ python train.py --batch_size=32 --compile=False """ +import logging import time import hydra @@ -134,20 +135,20 @@ def main(cfg): train_loss = estimate_loss(model, train_loader) val_loss = estimate_loss(model, val_loader) msg = f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}" - print(msg) + logging.info(msg) loss_logger.info(msg) if val_loss < best_val_loss or always_save_checkpoint: best_val_loss = val_loss if it > 0: msg = f"saving checkpoint to {out_dir}" - print(msg) + logging.info(msg) loss_logger.info(msg) model.module.save_pretrained(out_dir) elif it % log_interval == 0: # loss as float. note: this is a CPU-GPU sync point loss = batch.loss.item() msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms" - print(msg) + logging.info(msg) loss_logger.info(msg) diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py index 00813c81d0a..ac1299f0175 100644 --- a/examples/rlhf/train_reward.py +++ b/examples/rlhf/train_reward.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import time import hydra @@ -140,13 +141,13 @@ def main(cfg): f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}, " f"{train_acc=:.4f}, {val_acc=:.4f}" ) - print(msg) + logging.info(msg) loss_logger.info(msg) if val_loss < best_val_loss or always_save_checkpoint: best_val_loss = val_loss if it > 0: msg = f"saving checkpoint to {reward_out_dir}" - print(msg) + logging.info(msg) loss_logger.info(msg) model.module.save_pretrained(reward_out_dir) elif it % log_interval == 0: @@ -155,7 +156,7 @@ def main(cfg): batch.chosen_data.end_scores, batch.rejected_data.end_scores ) msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms" - print(msg) + logging.info(msg) loss_logger.info(msg) diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 76bfea72e45..9a08cd8ef9b 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -10,7 +10,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/examples/td3/td3.py b/examples/td3/td3.py index 6a129b40209..ab21db76b15 100644 --- a/examples/td3/td3.py +++ b/examples/td3/td3.py @@ -10,7 +10,7 @@ The helper functions are coded in the utils.py associated with this script. """ - +import logging import time import hydra @@ -202,7 +202,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") + logging.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/setup.cfg b/setup.cfg index 55e98280d3e..6c2907430f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,11 @@ per-file-ignores = torchrl/objectives/sac.py: TOR101 torchrl/objectives/td3.py: TOR101 torchrl/objectives/value/advantages.py: TOR101 + tutorials/*/**.py: T201 + build_tools/setup_helpers/extension.py: T201 + examples/torchrl_features/*.py: T201 + test/opengl_rendering.py: T201 + */**/run-clang-format.py: T201 exclude = venv extend-select = B901, C401, C408, C409, TOR0, TOR1, TOR2 diff --git a/setup.py b/setup.py index fccf412aea7..033f6dd3bbf 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ import argparse import distutils.command.clean import glob +import logging import os import shutil import subprocess @@ -96,7 +97,7 @@ def run(self): # Remove torchrl extension for path in (ROOT_DIR / "torchrl").glob("**/*.so"): - print(f"removing '{path}'") + logging.info(f"removing '{path}'") path.unlink() # Remove build directory build_dirs = [ @@ -104,7 +105,7 @@ def run(self): ] for path in build_dirs: if path.exists(): - print(f"removing '{path}' (and everything under it)") + logging.info(f"removing '{path}' (and everything under it)") shutil.rmtree(str(path), ignore_errors=True) @@ -128,7 +129,7 @@ def get_extensions(): } debug_mode = os.getenv("DEBUG", "0") == "1" if debug_mode: - print("Compiling in debug mode") + logging.info("Compiling in debug mode") extra_compile_args = { "cxx": [ "-O0", @@ -177,11 +178,11 @@ def _main(argv): else: version = get_version() write_version_file(version) - print("Building wheel {}-{}".format(package_name, version)) - print(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}") + logging.info("Building wheel {}-{}".format(package_name, version)) + logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}") pytorch_package_dep = _get_pytorch_version(is_nightly) - print("-- PyTorch dependency:", pytorch_package_dep) + logging.info("-- PyTorch dependency:", pytorch_package_dep) # branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"]) # tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"]) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index ca98e2cff6e..29adc35a5a0 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +import logging import os import os.path @@ -119,7 +120,7 @@ def f_retry(*args, **kwargs): return f(*args, **kwargs) except ExceptionToCheck as e: msg = "%s, Retrying in %d seconds..." % (str(e), mdelay) - print(msg) + logging.info(msg) time.sleep(mdelay) mtries -= 1 try: diff --git a/test/assets/generate.py b/test/assets/generate.py index deb47f95999..4f50b7c03a4 100644 --- a/test/assets/generate.py +++ b/test/assets/generate.py @@ -54,7 +54,6 @@ def get_minibatch(): for data in dl: data = data.clone().memmap_("test/datasets_mini/tldr_batch/") break - print("done") if __name__ == "__main__": diff --git a/test/conftest.py b/test/conftest.py index f392cb7d4f1..5ce980a4080 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -33,7 +33,6 @@ def pytest_sessionfinish(maxprint=50): out_str += f"\t{key}{spaces}{item: 4.4f}s\n" if i == maxprint - 1: break - print(out_str) @pytest.fixture(autouse=True) diff --git a/test/opengl_rendering.py b/test/opengl_rendering.py index 7533e298069..0e2f86294c1 100644 --- a/test/opengl_rendering.py +++ b/test/opengl_rendering.py @@ -28,7 +28,7 @@ # pylint: disable=unused-import,g-import-not-at-top,g-statement-before-imports try: - import OpenGL + import OpenGL # noqa: F401 except ImportError: print("This module depends on PyOpenGL.") print( @@ -66,7 +66,7 @@ def _find_library_new(name): util.find_library = _find_library_new import OpenGL.EGL as egl - import OpenGL.GL as gl + import OpenGL.GL as gl # noqa: F401 from OpenGL import error from OpenGL.EGL.EXT.device_base import egl_get_devices from OpenGL.raw.EGL.EXT.platform_device import EGL_PLATFORM_DEVICE_EXT diff --git a/test/test_collector.py b/test/test_collector.py index 565bfe1a7fe..61c1d886c24 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import logging import sys @@ -2080,7 +2081,7 @@ def test_num_threads(): c.shutdown() del c except Exception: - print("Failed to shut down collector") + logging.info("Failed to shut down collector") # reset vals collectors._main_async_collector = _main_async_collector_saved torch.set_num_threads(num_threads) diff --git a/test/test_distributed.py b/test/test_distributed.py index 8dcbe33f79d..debfa058ace 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -8,6 +8,7 @@ """ import abc import argparse +import logging import os import sys import time @@ -88,7 +89,7 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): cls._start_worker() env = ContinuousActionVecMockEnv policy = RandomPolicy(env().action_spec) - print("creating collector") + logging.info("creating collector") collector = cls.distributed_class()( [env] * 2, policy, @@ -97,7 +98,7 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): **cls.distributed_kwargs(), ) total = 0 - print("getting data...") + logging.info("getting data...") for data in collector: total += data.numel() assert data.numel() == frames_per_batch diff --git a/test/test_env.py b/test/test_env.py index aed4e07b0b7..fc566749b8c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1800,9 +1800,6 @@ def main_collector(j, q=None): s = "" for key, item in td_equals.items(True, True): if not item.all(): - print(key, "failed") - print("r_p", r_p.get(key)[~item]) - print("r_s", r_s.get(key)[~item]) s = s + f"\t{key}" q.put((f"failed: {s}", j)) else: diff --git a/test/test_libs.py b/test/test_libs.py index 51092ca6618..23b823419ac 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import importlib +import logging from contextlib import nullcontext from torchrl.envs.transforms import ActionMask, TransformedEnv @@ -1948,7 +1949,7 @@ def test_direct_download(self, task, tmpdir): def test_d4rl_dummy(self, task): t0 = time.time() _ = D4RLExperienceReplay(task, split_trajs=True, from_env=True, batch_size=2) - print(f"terminated test after {time.time()-t0}s") + logging.info(f"terminated test after {time.time()-t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -1969,7 +1970,7 @@ def test_dataset_build(self, task, split_trajs, from_env): offline = sample.get(key) # assert sim.dtype == offline.dtype, key assert sim.shape[-1] == offline.shape[-1], key - print(f"terminated test after {time.time()-t0}s") + logging.info(f"terminated test after {time.time()-t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -1988,7 +1989,7 @@ def test_d4rl_iteration(self, task, split_trajs): for sample in data: # noqa: B007 i += 1 assert len(data) // i == batch_size - print(f"terminated test after {time.time()-t0}s") + logging.info(f"terminated test after {time.time()-t0}s") _MINARI_DATASETS = [] @@ -2012,7 +2013,6 @@ def _minari_selected_datasets(): ] assert len(keys) > 5 _MINARI_DATASETS += keys - print("_MINARI_DATASETS", _MINARI_DATASETS) _minari_selected_datasets() @@ -2024,14 +2024,14 @@ def _minari_selected_datasets(): @pytest.mark.slow class TestMinari: def test_load(self, selected_dataset, split): - print("dataset", selected_dataset) + logging.info("dataset", selected_dataset) data = MinariExperienceReplay( selected_dataset, batch_size=32, split_trajs=split ) t0 = time.time() for i, sample in enumerate(data): t1 = time.time() - print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + logging.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") assert data.metadata["action_space"].is_in(sample["action"]) assert data.metadata["observation_space"].is_in(sample["observation"]) t0 = time.time() @@ -2050,7 +2050,7 @@ def test_load(self): t0 = time.time() for i, _ in enumerate(data): t1 = time.time() - print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + logging.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") t0 = time.time() if i == 10: break @@ -2083,7 +2083,7 @@ def test_load(self, image_size): assert (batch.get("pixels") != 0).any() assert (batch.get(("next", "pixels")) != 0).any() t1 = time.time() - print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + logging.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") t0 = time.time() if i == 10: break @@ -2546,16 +2546,16 @@ def test_robohive(self, from_pixels): substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") ): - print("not testing envs with prebuilt rendering") + logging.info("not testing envs with prebuilt rendering") return if "Adroit" in envname: - print("tcdm are broken") + logging.info("tcdm are broken") return try: env = RoboHiveEnv(envname) except AttributeError as err: if "'MjData' object has no attribute 'get_body_xipos'" in str(err): - print("tcdm are broken") + logging.info("tcdm are broken") return else: raise err @@ -2563,7 +2563,7 @@ def test_robohive(self, from_pixels): from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0 ): - print("no camera") + logging.info("no camera") return check_env_specs(env) except Exception as err: diff --git a/test/test_modules.py b/test/test_modules.py index cdd8987022d..68917a10d16 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -1250,7 +1250,6 @@ def test_python_gru_cell(device, bias): h0 = torch.zeros(3, 20, device=device) with torch.no_grad(): for i in range(input.size()[0]): - print(i) h1 = gru_cell1(input[i], h0) h2 = gru_cell2(input[i], h0) diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 548f04dc41d..1d5a2398e92 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +import logging import os import sys @@ -110,7 +111,7 @@ def _construct_buffer(target): buffer_rref = rpc.remote(target, ReplayBufferNode, args=(1000,)) return buffer_rref except Exception as e: - print(f"Failed to connect: {e}") + logging.info(f"Failed to connect: {e}") time.sleep(RETRY_BACKOFF) raise RuntimeError("Unable to connect to replay buffer") diff --git a/test/test_shared.py b/test/test_shared.py index 186c8ae9525..dcfb798e35c 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +import logging import time import warnings @@ -19,7 +20,7 @@ def remote_process(command_pipe_child, command_pipe_parent, tensordict): assert tensordict.is_shared() t0 = time.time() tensordict.zero_() - print(f"zeroing time: {time.time() - t0}") + logging.info(f"zeroing time: {time.time() - t0}") command_pipe_child.send("done") command_pipe_child.close() del command_pipe_child, command_pipe_parent, tensordict @@ -112,7 +113,7 @@ def driver_func(td, stack): command_pipe_child.close() command_pipe_parent.send("stack" if stack else "serial") time_spent = command_pipe_parent.recv() - print(f"stack {stack}: time={time_spent}") + logging.info(f"stack {stack}: time={time_spent}") for item in td.values(): assert (item == 0).all() proc.join() @@ -121,7 +122,7 @@ def driver_func(td, stack): @pytest.mark.parametrize("shared", ["shared", "memmap"]) def test_shared(self, shared): - print(f"test_shared: shared={shared}") + logging.info(f"test_shared: shared={shared}") torch.manual_seed(0) tensordict = TensorDict( source={ @@ -163,36 +164,36 @@ def test_memmap(idx, dtype, large_scale=False): td_sm = td.clone().share_memory_() td_memmap = td.clone().memmap_() - print("\nTesting reading from TD") + logging.info("\nTesting reading from TD") for i in range(2): t0 = time.time() td_sm[idx].clone() if i == 1: - print(f"sm: {time.time() - t0:4.4f} sec") + logging.info(f"sm: {time.time() - t0:4.4f} sec") t0 = time.time() td_memmap[idx].clone() if i == 1: - print(f"memmap: {time.time() - t0:4.4f} sec") + logging.info(f"memmap: {time.time() - t0:4.4f} sec") td_to_copy = td[idx].contiguous() for k in td_to_copy.keys(): td_to_copy.set_(k, torch.ones_like(td_to_copy.get(k))) - print("\nTesting writing to TD") + logging.info("\nTesting writing to TD") for i in range(2): t0 = time.time() sub_td_sm = td_sm.get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: - print(f"sm td: {time.time() - t0:4.4f} sec") + logging.info(f"sm td: {time.time() - t0:4.4f} sec") torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a")) t0 = time.time() sub_td_sm = td_memmap.get_sub_tensordict(idx) sub_td_sm.update_(td_to_copy) if i == 1: - print(f"memmap td: {time.time() - t0:4.4f} sec") + logging.info(f"memmap td: {time.time() - t0:4.4f} sec") torch.testing.assert_close(sub_td_sm.get("a")._tensor, td_to_copy.get("a")) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index b3d768f2d22..323f3554ab0 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -8,6 +8,7 @@ import functools import inspect +import logging import math import os @@ -65,7 +66,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): val[2] = N @staticmethod - def print(prefix=None): + def print(prefix=None): # noqa: T202 keys = list(timeit._REG) keys.sort() for name in keys: @@ -75,7 +76,7 @@ def print(prefix=None): strings.append( f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)" ) - print(" -- ".join(strings)) + logging.info(" -- ".join(strings)) @staticmethod def erase(): @@ -405,7 +406,7 @@ def reset(cls, setters_dict: Dict[str, implement_for] = None): """ if VERBOSE: - print("resetting implement_for") + logging.info("resetting implement_for") if setters_dict is None: setters_dict = copy(cls._implementations) for setter in setters_dict.values(): @@ -652,17 +653,17 @@ def format_size(size): total_size_bytes = get_directory_size(path) formatted_size = format_size(total_size_bytes) - print(f"Directory size: {formatted_size}") + logging.info(f"Directory size: {formatted_size}") if os.path.isdir(path): - print(indent + os.path.basename(path) + "/") + logging.info(indent + os.path.basename(path) + "/") indent += " " for item in os.listdir(path): print_directory_tree( os.path.join(path, item), indent=indent, display_metadata=False ) else: - print(indent + os.path.basename(path)) + logging.info(indent + os.path.basename(path)) def _replace_last(key: NestedKey, new_ending: str) -> NestedKey: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 4ab94e4cd11..8d1762f8465 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -7,6 +7,7 @@ import _pickle import abc import inspect +import logging import os import queue import sys @@ -14,7 +15,6 @@ import warnings from collections import OrderedDict from copy import deepcopy - from multiprocessing import connection, queues from multiprocessing.managers import SyncManager @@ -2125,7 +2125,7 @@ def _main_async_collector( interruptor=interruptor, ) if verbose: - print("Sync data collector created") + logging.info("Sync data collector created") dc_iter = iter(inner_collector) j = 0 pipe_child.send("instantiated") @@ -2138,10 +2138,10 @@ def _main_async_collector( counter = 0 data_in, msg = pipe_child.recv() if verbose: - print(f"worker {idx} received {msg}") + logging.info(f"worker {idx} received {msg}") else: if verbose: - print(f"poll failed, j={j}, worker={idx}") + logging.info(f"poll failed, j={j}, worker={idx}") # default is "continue" (after first iteration) # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe # in that case, the main process probably expects the worker to continue collect data @@ -2161,7 +2161,7 @@ def _main_async_collector( counter += _timeout if verbose: - print(f"worker {idx} has counter {counter}") + logging.info(f"worker {idx} has counter {counter}") if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): raise RuntimeError( f"This process waited for {counter} seconds " @@ -2201,13 +2201,13 @@ def _main_async_collector( try: queue_out.put((data, j), timeout=_TIMEOUT) if verbose: - print(f"worker {idx} successfully sent data") + logging.info(f"worker {idx} successfully sent data") j += 1 has_timed_out = False continue except queue.Full: if verbose: - print(f"worker {idx} has timed out") + logging.info(f"worker {idx} has timed out") has_timed_out = True continue @@ -2253,7 +2253,7 @@ def _main_async_collector( del inner_collector, dc_iter pipe_child.send("closed") if verbose: - print(f"collector {idx} closed") + logging.info(f"collector {idx} closed") break else: diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 752a09231c0..f213f73d160 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -5,6 +5,7 @@ r"""Generic distributed data-collector using torch.distributed backend.""" +import logging import os import socket import warnings @@ -50,10 +51,10 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): os.environ["MASTER_PORT"] = str(tcpport) if verbose: - print( + logging.info( f"Rank0 IP address: '{rank0_ip}' \ttcp port: '{tcpport}', backend={backend}." ) - print( + logging.info( f"node with rank {rank} with world_size {world_size} -- launching distributed" ) torch.distributed.init_process_group( @@ -64,7 +65,7 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - print(f"Connected!\nNode with rank {rank} -- creating store") + logging.info(f"Connected!\nNode with rank {rank} -- creating store") # The store carries instructions for the node _store = torch.distributed.TCPStore( host_name=rank0_ip, @@ -158,7 +159,9 @@ def _run_collector( ): rank = torch.distributed.get_rank() if verbose: - print(f"node with rank {rank} -- creating collector of type {collector_class}") + logging.info( + f"node with rank {rank} -- creating collector of type {collector_class}" + ) if not issubclass(collector_class, SyncDataCollector): env_make = [env_make] * num_workers else: @@ -192,30 +195,30 @@ def _run_collector( ) total_frames = 0 if verbose: - print(f"node with rank {rank} -- loop") + logging.info(f"node with rank {rank} -- loop") while True: instruction = _store.get(f"NODE_{rank}_in") if verbose: - print(f"node with rank {rank} -- new instruction: {instruction}") + logging.info(f"node with rank {rank} -- new instruction: {instruction}") _store.delete_key(f"NODE_{rank}_in") if instruction == b"continue": _store.set(f"NODE_{rank}_status", b"busy") if verbose: - print(f"node with rank {rank} -- new data") + logging.info(f"node with rank {rank} -- new data") data = collector.next() total_frames += data.numel() if verbose: - print(f"got data, total frames = {total_frames}") - print(f"node with rank {rank} -- sending {data}") + logging.info(f"got data, total frames = {total_frames}") + logging.info(f"node with rank {rank} -- sending {data}") if _store.get("TRAINER_status") == b"alive": data.isend(dst=0) if verbose: - print(f"node with rank {rank} -- setting to 'done'") + logging.info(f"node with rank {rank} -- setting to 'done'") if not sync: _store.set(f"NODE_{rank}_status", b"done") elif instruction == b"shutdown": if verbose: - print(f"node with rank {rank} -- shutting down") + logging.info(f"node with rank {rank} -- shutting down") try: collector.shutdown() except Exception: @@ -483,7 +486,7 @@ def _init_master_dist( backend, ): if self._VERBOSE: - print( + logging.info( f"launching main node with tcp port '{self.tcp_port}' and " f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." ) @@ -499,7 +502,7 @@ def _init_master_dist( init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) if self._VERBOSE: - print("main initiated! Launching store...", end="\t") + logging.info("main initiated! Launching store...", end="\t") self._store = torch.distributed.TCPStore( host_name=self.IPAddr, port=int(TCP_PORT) + 1, @@ -508,12 +511,12 @@ def _init_master_dist( timeout=timedelta(10), ) if self._VERBOSE: - print("done. Setting status to 'alive'") + logging.info("done. Setting status to 'alive'") self._store.set("TRAINER_status", b"alive") def _make_container(self): if self._VERBOSE: - print("making container") + logging.info("making container") env_constructor = self.env_constructors[0] pseudo_collector = SyncDataCollector( env_constructor, @@ -525,8 +528,8 @@ def _make_container(self): for _data in pseudo_collector: break if self._VERBOSE: - print("got data", _data) - print("expanding...") + logging.info("got data", _data) + logging.info("expanding...") if not issubclass(self.collector_class, SyncDataCollector): # Multi-data collectors self._tensordict_out = ( @@ -542,7 +545,7 @@ def _make_container(self): .to(self.storing_device) ) if self._VERBOSE: - print("locking") + logging.info("locking") if self._sync: self._tensordict_out.lock_() self._tensordict_out_unbind = self._tensordict_out.unbind(0) @@ -553,11 +556,11 @@ def _make_container(self): for td in self._tensordict_out: td.lock_() if self._VERBOSE: - print("storage created:") - print("shutting down...") + logging.info("storage created:") + logging.info("shutting down...") pseudo_collector.shutdown() if self._VERBOSE: - print("dummy collector shut down!") + logging.info("dummy collector shut down!") del pseudo_collector def _init_worker_dist_submitit(self, executor, i): @@ -640,7 +643,7 @@ def _init_workers(self): else: IPAddr = "localhost" if self._VERBOSE: - print("Server IP address:", IPAddr) + logging.info("Server IP address:", IPAddr) self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -656,20 +659,20 @@ def _init_workers(self): else: for i in range(self.num_workers): if self._VERBOSE: - print("Submitting job") + logging.info("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) if self._VERBOSE: - print("job id", job.job_id) # ID of your job + logging.info("job id", job.job_id) # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) if self._VERBOSE: - print("job launched") + logging.info("job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) @@ -678,13 +681,13 @@ def iterator(self): def _iterator_dist(self): if self._VERBOSE: - print("iterating...") + logging.info("iterating...") total_frames = 0 if not self._sync: for rank in range(1, self.num_workers + 1): if self._VERBOSE: - print(f"sending 'continue' to {rank}") + logging.info(f"sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): @@ -729,7 +732,7 @@ def _next_sync(self, total_frames): if total_frames < self.total_frames: for rank in range(1, self.num_workers + 1): if self._VERBOSE: - print(f"sending 'continue' to {rank}") + logging.info(f"sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): @@ -763,7 +766,7 @@ def _next_async(self, total_frames, trackers): total_frames += data.numel() if total_frames < self.total_frames: if self._VERBOSE: - print(f"sending 'continue' to {rank}") + logging.info(f"sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers[i] = self._tensordict_out[i].irecv( src=i + 1, return_premature=True @@ -786,7 +789,7 @@ def update_policy_weights_(self, worker_rank=None) -> None: for i in workers: rank = i + 1 if self._VERBOSE: - print(f"updating weights of {rank}") + logging.info(f"updating weights of {rank}") self._store.set(f"NODE_{rank}_in", b"update_weights") if self._sync: self.policy_weights.send(rank) @@ -822,12 +825,12 @@ def shutdown(self): for i in range(self.num_workers): rank = i + 1 if self._VERBOSE: - print(f"shutting down node with rank={rank}") + logging.info(f"shutting down node with rank={rank}") self._store.set(f"NODE_{rank}_in", b"shutdown") for i in range(self.num_workers): rank = i + 1 if self._VERBOSE: - print(f"getting status of node {rank}", end="\t") + logging.info(f"getting status of node {rank}", end="\t") status = self._store.get(f"NODE_{rank}_out") if status != b"down": raise RuntimeError(f"Expected 'down' but got status {status}.") @@ -842,4 +845,4 @@ def shutdown(self): elif self.launcher == "submitit_delayed": pass if self._VERBOSE: - print("collector shut down") + logging.info("collector shut down") diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index c05da8c5a0f..11f94e4ea64 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -64,7 +64,7 @@ def print_remote_collector_info(self): f"{get_node_ip_address()} using gpus {ray.get_gpu_ids()}" ) # logger.warning(s) - print(s) + logging.info(s) @classmethod diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 5fef2dd1666..98228d15f7b 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -5,6 +5,7 @@ r"""Generic distributed data-collector using torch.distributed.rpc backend.""" import collections +import logging import os import socket import time @@ -74,7 +75,7 @@ def _rpc_init_collection_node( **tensorpipe_options, ) if verbose: - print( + logging.info( f"init rpc with master addr: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" ) rpc.init_rpc( @@ -344,7 +345,7 @@ def _init_master_rpc( f"COLLECTOR_NODE_{rank}", {0: self.visible_devices[i]} ) if self._VERBOSE: - print("init rpc") + logging.info("init rpc") rpc.init_rpc( "TRAINER_NODE", rank=0, @@ -375,7 +376,7 @@ def _start_workers( time.sleep(time_interval) try: if self._VERBOSE: - print(f"trying to connect to collector node {i + 1}") + logging.info(f"trying to connect to collector node {i + 1}") collector_info = rpc.get_worker_info(f"COLLECTOR_NODE_{i + 1}") break except RuntimeError as err: @@ -390,7 +391,7 @@ def _start_workers( if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) if self._VERBOSE: - print("Making collector in remote node") + logging.info("Making collector in remote node") collector_rref = rpc.remote( collector_infos[i], collector_class, @@ -414,7 +415,7 @@ def _start_workers( if not self._sync: for i in range(num_workers): if self._VERBOSE: - print("Asking for the first batch") + logging.info("Asking for the first batch") future = rpc.rpc_async( collector_infos[i], collector_class.next, @@ -444,7 +445,7 @@ def _init_worker_rpc(self, executor, i): self._VERBOSE, ) if self._VERBOSE: - print("job id", job.job_id) # ID of your job + logging.info("job id", job.job_id) # ID of your job return job elif self.launcher == "mp": job = _ProcessNoWarn( @@ -488,7 +489,7 @@ def _init(self): self.jobs = [] for i in range(self.num_workers): if self._VERBOSE: - print(f"Submitting job {i}") + logging.info(f"Submitting job {i}") job = self._init_worker_rpc( executor, i, @@ -545,7 +546,7 @@ def update_policy_weights_(self, workers=None, wait=True) -> None: futures = [] for i in workers: if self._VERBOSE: - print(f"calling update on worker {i}") + logging.info(f"calling update on worker {i}") futures.append( rpc.rpc_async( self.collector_infos[i], @@ -556,14 +557,14 @@ def update_policy_weights_(self, workers=None, wait=True) -> None: if wait: for i in workers: if self._VERBOSE: - print(f"waiting for worker {i}") + logging.info(f"waiting for worker {i}") futures[i].wait() if self._VERBOSE: - print("got it!") + logging.info("got it!") def _next_async_rpc(self): if self._VERBOSE: - print("next async") + logging.info("next async") if not len(self.futures): raise StopIteration( f"The queue is empty, the collector has ran out of data after {self._collected_frames} collected frames." @@ -574,7 +575,7 @@ def _next_async_rpc(self): if self.update_after_each_batch: self.update_policy_weights_(workers=(i,), wait=False) if self._VERBOSE: - print(f"future {i} is done") + logging.info(f"future {i} is done") data = future.value() self._collected_frames += data.numel() if self._collected_frames < self.total_frames: @@ -589,7 +590,7 @@ def _next_async_rpc(self): def _next_sync_rpc(self): if self._VERBOSE: - print("next sync: futures") + logging.info("next sync: futures") if self.update_after_each_batch: self.update_policy_weights_() for i in range(self.num_workers): @@ -606,7 +607,7 @@ def _next_sync_rpc(self): if future.done(): data += [future.value()] if self._VERBOSE: - print( + logging.info( f"got data from {i} // data has len {len(data)} / {self.num_workers}" ) else: @@ -637,15 +638,15 @@ def shutdown(self): if self._shutdown: return if self._VERBOSE: - print("shutting down") + logging.info("shutting down") for future, i in self.futures: # clear the futures while future is not None and not future.done(): - print(f"waiting for proc {i} to clear") + logging.info(f"waiting for proc {i} to clear") future.wait() for i in range(self.num_workers): if self._VERBOSE: - print(f"shutting down {i}") + logging.info(f"shutting down {i}") rpc.rpc_sync( self.collector_infos[i], self.collector_class.shutdown, @@ -653,7 +654,7 @@ def shutdown(self): timeout=int(IDLE_TIMEOUT), ) if self._VERBOSE: - print("rpc shutdown") + logging.info("rpc shutdown") rpc.shutdown(timeout=int(IDLE_TIMEOUT)) if self.launcher == "mp": for job in self.jobs: diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 66e55318832..8d3afa488d4 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -5,6 +5,7 @@ r"""Generic distributed data-collector using torch.distributed backend.""" +import logging import os import socket from copy import copy, deepcopy @@ -63,7 +64,9 @@ def _distributed_init_collection_node( os.environ["MASTER_PORT"] = str(tcpport) if verbose: - print(f"node with rank {rank} -- creating collector of type {collector_class}") + logging.info( + f"node with rank {rank} -- creating collector of type {collector_class}" + ) if not issubclass(collector_class, SyncDataCollector): env_make = [env_make] * num_workers else: @@ -96,9 +99,9 @@ def _distributed_init_collection_node( **collector_kwargs, ) - print("IP address:", rank0_ip, "\ttcp port:", tcpport) + logging.info("IP address:", rank0_ip, "\ttcp port:", tcpport) if verbose: - print(f"node with rank {rank} -- launching distributed") + logging.info(f"node with rank {rank} -- launching distributed") torch.distributed.init_process_group( backend, rank=rank, @@ -107,9 +110,9 @@ def _distributed_init_collection_node( # init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - print(f"node with rank {rank} -- creating store") + logging.info(f"node with rank {rank} -- creating store") if verbose: - print(f"node with rank {rank} -- loop") + logging.info(f"node with rank {rank} -- loop") policy_weights.irecv(0) frames = 0 for i, data in enumerate(collector): @@ -329,7 +332,7 @@ def _init_master_dist( backend, ): TCP_PORT = self.tcp_port - print("init master...", end="\t") + logging.info("init master...", end="\t") torch.distributed.init_process_group( backend, rank=0, @@ -337,7 +340,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - print("done") + logging.info("done") def _make_container(self): env_constructor = self.env_constructors[0] @@ -422,7 +425,7 @@ def _init_workers(self): hostname = socket.gethostname() IPAddr = socket.gethostbyname(hostname) - print("Server IP address:", IPAddr) + logging.info("Server IP address:", IPAddr) self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -434,18 +437,18 @@ def _init_workers(self): executor = submitit.AutoExecutor(folder="log_test") executor.update_parameters(**self.slurm_kwargs) for i in range(self.num_workers): - print("Submitting job") + logging.info("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - print("job id", job.job_id) # ID of your job + logging.info("job id", job.job_id) # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - print("job launched") + logging.info("job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index 9559101e38e..24444fc171d 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -1,3 +1,4 @@ +import logging import subprocess import time @@ -96,7 +97,7 @@ def exec_fun(): executor.update_parameters(**self.submitit_main_conf) main_job = executor.submit(main_func) # listen to output file looking for IP address - print(f"job id: {main_job.job_id}") + logging.info(f"job id: {main_job.job_id}") time.sleep(2.0) node = None while not node: @@ -107,11 +108,11 @@ def exec_fun(): except ValueError: time.sleep(0.5) continue - print(f"node: {node}") + logging.info(f"node: {node}") # by default, sinfo will truncate the node name at char 20, we increase this to 200 cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1" rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip() - print(f"IP: {rank0_ip}") + logging.info(f"IP: {rank0_ip}") world_size = self.num_jobs + 1 # submit jobs diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index a7f4fb0b198..3afa680c88d 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -5,6 +5,8 @@ from __future__ import annotations import importlib + +import logging import os import tempfile import urllib @@ -449,7 +451,7 @@ def _shift_reward_done(self, dataset): def _download_dataset_from_url(dataset_url, dataset_path): dataset_filepath = _filepath_from_url(dataset_url, dataset_path) if not os.path.exists(dataset_filepath): - print("Downloading dataset:", dataset_url, "to", dataset_filepath) + logging.info("Downloading dataset:", dataset_url, "to", dataset_filepath) urllib.request.urlretrieve(dataset_url, dataset_filepath) if not os.path.exists(dataset_filepath): raise IOError("Failed to download dataset from %s" % dataset_url) @@ -473,7 +475,7 @@ def _filepath_from_url(dataset_url, dataset_path): if __name__ == "__main__": data = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128) - print(data) + logging.info(data) for sample in data: - print(sample) + logging.info(sample) break diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 9566c1eff10..8e20ebc12da 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -6,6 +6,7 @@ import importlib.util import json +import logging import os.path import shutil import tempfile @@ -107,7 +108,7 @@ class MinariExperienceReplay(TensorDictReplayBuffer): >>> from torchrl.data.datasets.minari_data import MinariExperienceReplay >>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force") >>> for sample in data: - ... print(sample) + ... logging.info(sample) ... break TensorDict( fields={ @@ -249,7 +250,7 @@ def _download_and_preproc(self): td_data = TensorDict({}, []) total_steps = 0 - print("first read through data to create data structure...") + logging.info("first read through data to create data structure...") h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") # populate the tensordict episode_dict = {} @@ -288,11 +289,13 @@ def _download_and_preproc(self): td_data["done"] = td_data["truncated"] | td_data["terminated"] td_data = td_data.expand(total_steps) # save to designated location - print(f"creating tensordict data in {self.data_path_root}: ", end="\t") + logging.info( + f"creating tensordict data in {self.data_path_root}: ", end="\t" + ) td_data = td_data.memmap_like(self.data_path_root) - print("tensordict structure:", td_data) + logging.info("tensordict structure:", td_data) - print(f"Reading data from {max(*episode_dict) + 1} episodes") + logging.info(f"Reading data from {max(*episode_dict) + 1} episodes") index = 0 with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: # iterate over episodes and populate the tensordict diff --git a/torchrl/data/datasets/roboset.py b/torchrl/data/datasets/roboset.py index 62e4e41e982..bcbb12a4891 100644 --- a/torchrl/data/datasets/roboset.py +++ b/torchrl/data/datasets/roboset.py @@ -5,6 +5,7 @@ from __future__ import annotations import importlib.util +import logging import os.path import shutil import tempfile @@ -91,11 +92,11 @@ class RobosetExperienceReplay(TensorDictReplayBuffer): >>> for batch in d: ... break >>> # data is organised by seed and episode, but stored contiguously - >>> print(batch["seed"], batch["episode"]) + >>> logging.info(batch["seed"], batch["episode"]) tensor([2, 1, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1, 2, 0, 2, 0, 2, 2, 1, 0, 2, 0, 0, 1, 1, 2, 1]) tensor([17, 20, 18, 9, 6, 1, 12, 6, 2, 6, 8, 15, 8, 21, 17, 3, 9, 20, 23, 12, 3, 16, 19, 16, 16, 4, 4, 12, 1, 2, 15, 24]) - >>> print(batch) + >>> logging.info(batch) TensorDict( fields={ action: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float64, is_shared=False), @@ -240,13 +241,13 @@ def _download_and_preproc(self): def _preproc_h5(self, h5_data_files): td_data = TensorDict({}, []) total_steps = 0 - print( + logging.info( f"first read through data files {h5_data_files} to create data structure..." ) episode_dict = {} h5_datas = [] for seed, h5_data_name in enumerate(h5_data_files): - print("\nReading", h5_data_name) + logging.info("\nReading", h5_data_name) h5_data = PersistentTensorDict.from_h5(h5_data_name) h5_datas.append(h5_data) for i, (episode_key, episode) in enumerate(h5_data.items()): @@ -255,7 +256,7 @@ def _preproc_h5(self, h5_data_files): episode_dict[(seed, episode_num)] = (episode_key, episode_len) # Get the total number of steps for the dataset total_steps += episode_len - print("total_steps", total_steps, end="\t") + logging.info("total_steps", total_steps, end="\t") if i == 0 and seed == 0: td_data.set("episode", 0) td_data.set("seed", 0) @@ -278,12 +279,14 @@ def _preproc_h5(self, h5_data_files): td_data = td_data.expand(total_steps) # save to designated location - print(f"creating tensordict data in {self.data_path_root}: ", end="\t") + logging.info(f"creating tensordict data in {self.data_path_root}: ", end="\t") td_data = td_data.memmap_like(self.data_path_root) - # print("tensordict structure:", td_data) - print("Local dataset structure:", print_directory_tree(self.data_path_root)) + # logging.info("tensordict structure:", td_data) + logging.info( + "Local dataset structure:", print_directory_tree(self.data_path_root) + ) - print(f"Reading data from {len(episode_dict)} episodes") + logging.info(f"Reading data from {len(episode_dict)} episodes") index = 0 if _has_tqdm: from tqdm import tqdm diff --git a/torchrl/data/datasets/vd4rl.py b/torchrl/data/datasets/vd4rl.py index b3313ef8812..a6e79f9b266 100644 --- a/torchrl/data/datasets/vd4rl.py +++ b/torchrl/data/datasets/vd4rl.py @@ -6,6 +6,7 @@ import importlib import json +import logging import os import pathlib import shutil @@ -280,7 +281,7 @@ def _download_and_preproc(cls, dataset_id, data_path): zip(paths_to_proc, files_to_proc), ) files = list(files) - print("Downloaded, processing files") + logging.info("Downloaded, processing files") if _has_tqdm: import tqdm @@ -308,7 +309,7 @@ def _download_and_preproc(cls, dataset_id, data_path): # From this point, the local paths are non needed anymore td_save = td_save.expand(total_steps).memmap_like(data_path, num_threads=32) - print("Saved tensordict:", td_save) + logging.info("Saved tensordict:", td_save) idx0 = 0 idx1 = 0 while len(files): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e2085e8da97..6134a820dbf 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -5,6 +5,7 @@ import abc import json +import logging import os import textwrap import warnings @@ -640,7 +641,7 @@ def __init__(self, max_size, device="cpu"): def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if VERBOSE: - print("Creating a TensorStorage...") + logging.info("Creating a TensorStorage...") if self.device == "auto": self.device = data.device if isinstance(data, torch.Tensor): @@ -800,7 +801,7 @@ def load_state_dict(self, state_dict): def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if VERBOSE: - print("Creating a MemmapStorage...") + logging.info("Creating a MemmapStorage...") if self.device == "auto": self.device = data.device if self.device.type != "cpu": @@ -819,7 +820,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: ): if VERBOSE: filesize = os.path.getsize(tensor.filename) / 1024 / 1024 - print( + logging.info( f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." ) else: @@ -834,7 +835,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: ) if VERBOSE: filesize = os.path.getsize(out.filename) / 1024 / 1024 - print( + logging.info( f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." ) self._storage = out diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 3d8f7fa6de1..35f4e99914c 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -5,6 +5,7 @@ from __future__ import annotations import importlib.util +import logging import os from pathlib import Path @@ -140,7 +141,7 @@ def load(self): data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0] data_dir_total = data_dir / split / str(max_length) # search for data - print("Looking for data in", data_dir_total) + logging.info("Looking for data in", data_dir_total) if os.path.exists(data_dir_total): dataset = TensorDict.load_memmap(data_dir_total) return dataset diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 082fa9ea50e..03262fcdd1d 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -5,6 +5,8 @@ from __future__ import annotations +import logging + import os from collections import OrderedDict from copy import deepcopy @@ -489,7 +491,7 @@ def close(self) -> None: if self.is_closed: raise RuntimeError("trying to close a closed environment") if self._verbose: - print(f"closing {self.__class__.__name__}") + logging.info(f"closing {self.__class__.__name__}") self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None @@ -783,7 +785,7 @@ def _start_workers(self) -> None: with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: - print(f"initiating worker {idx}") + logging.info(f"initiating worker {idx}") # No certainty which module multiprocessing_context is parent_pipe, child_pipe = ctx.Pipe() env_fun = self.create_env_fn[idx] @@ -1036,7 +1038,7 @@ def _shutdown_workers(self) -> None: ) for i, channel in enumerate(self.parent_channels): if self._verbose: - print(f"closing {i}") + logging.info(f"closing {i}") channel.send(("close", None)) self._events[i].wait() self._events[i].clear() @@ -1193,7 +1195,7 @@ def _run_worker_pipe_shared_mem( elif cmd == "init": if verbose: - print(f"initializing {pid}") + logging.info(f"initializing {pid}") if initialized: raise RuntimeError("worker already initialized") i = 0 @@ -1209,7 +1211,7 @@ def _run_worker_pipe_shared_mem( elif cmd == "reset": if verbose: - print(f"resetting worker {pid}") + logging.info(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") cur_td = env.reset(tensordict=data) @@ -1255,7 +1257,7 @@ def _run_worker_pipe_shared_mem( mp_event.set() child_pipe.close() if verbose: - print(f"{pid} closed") + logging.info(f"{pid} closed") break elif cmd == "load_state_dict": diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 6dff69e8d5f..9053b42f7f6 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -5,6 +5,7 @@ from __future__ import annotations +import logging from collections import OrderedDict from typing import Callable, Dict, Optional, Union @@ -97,7 +98,7 @@ def share_memory(self, state_dict: OrderedDict) -> None: if not item.is_shared(): item.share_memory_() else: - print( + logging.info( f"{self.env_type}: {item} is already shared" ) # , deleting key') del state_dict[key] diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 89b402bd904..371234e6df8 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -7,6 +7,7 @@ import collections import importlib +import logging import os from typing import Any, Dict, Optional, Tuple, Union @@ -32,7 +33,7 @@ n = torch.cuda.device_count() - 1 os.environ["EGL_DEVICE_ID"] = str(1 + (os.getpid() % n)) if VERBOSE: - print("EGL_DEVICE_ID: ", os.environ["EGL_DEVICE_ID"]) + logging.info("EGL_DEVICE_ID: ", os.environ["EGL_DEVICE_ID"]) _has_dmc = _has_dm_control = importlib.util.find_spec("dm_control") is not None diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index d78eb3ca1b5..1a83840fb19 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -269,7 +269,7 @@ def _treevalue_to_dict( def _set_seed(self, seed: Optional[int]): if seed is not None: - print( + logging.info( "MultiThreadedEnvWrapper._set_seed ignored, as setting seed in an existing envorinment is not\ supported by envpool. Please create a new environment, passing the seed to the constructor." ) diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index e32a3632c4f..5cb038b699a 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import importlib +import logging import os import subprocess from functools import partial @@ -236,12 +237,11 @@ def install_vc_models(cls, auto_exit=False): try: from vc_models import models # noqa: F401 - print("vc_models found, no need to install.") + logging.info("vc_models found, no need to install.") except ModuleNotFoundError: HOME = os.environ.get("HOME") vcdir = HOME + "/.cache/torchrl/eai-vc" parentdir = os.path.dirname(os.path.abspath(vcdir)) - print(parentdir) os.makedirs(parentdir, exist_ok=True) try: from git import Repo diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e4f0ba87ebf..6605301ed3b 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -7,6 +7,7 @@ import contextlib import importlib.util +import logging import os import re from enum import Enum @@ -517,7 +518,7 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): f"spec check failed at root for spec {name}={spec} and data {td}." ) - print("check_env_specs succeeded!") + logging.info("check_env_specs succeeded!") def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1): diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 4f7ae47606a..6db921f3201 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -86,7 +86,6 @@ def __init__(self, exp_name: str, log_dir: Optional[str] = None) -> None: super().__init__(exp_name=exp_name, log_dir=log_dir) self._has_imported_moviepy = False - print(f"self.log_dir: {self.experiment.log_dir}") def _create_experiment(self) -> "CSVExperiment": """Creates a CSV experiment.""" diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 582dace8ab9..265be40b785 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging from copy import copy from dataclasses import dataclass, field as dataclass_field from typing import Any, Callable, Optional, Sequence, Tuple, Union @@ -393,7 +394,7 @@ def get_stats_random_rollout( )() if VERBOSE: - print("computing state stats") + logging.info("computing state stats") if not hasattr(cfg, "init_env_steps"): raise AttributeError("init_env_steps missing from arguments.") @@ -426,7 +427,7 @@ def get_stats_random_rollout( s[s == 0] = 1.0 if VERBOSE: - print( + logging.info( f"stats computed for {val_stats.numel()} steps. Got: \n" f"loc = {m}, \n" f"scale = {s}" diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index a2764df2912..13d5ae4c968 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging from dataclasses import dataclass from typing import List, Optional, Union from warnings import warn @@ -173,7 +174,7 @@ def make_trainer( raise NotImplementedError(f"lr scheduler {cfg.lr_scheduler}") if VERBOSE: - print( + logging.info( f"collector = {collector}; \n" f"loss_module = {loss_module}; \n" f"recorder = {recorder}; \n" diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index fead31f742d..6985037d17c 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import logging import pathlib import warnings from collections import defaultdict, OrderedDict @@ -475,7 +476,7 @@ def __del__(self): def shutdown(self): if VERBOSE: - print("shutting down collector") + logging.info("shutting down collector") self.collector.shutdown() def optim_steps(self, batch: TensorDictBase) -> None: