From 049a5ba54071409c4d62642e908d59fa03139c08 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 12 Nov 2024 12:46:33 +0000 Subject: [PATCH] [Feature] DDPG compatibility with compile ghstack-source-id: 9d88c127091ad6e711ef49e9f48aaedad15b1c05 Pull Request resolved: https://github.com/pytorch/rl/pull/2555 --- sota-implementations/ddpg/config.yaml | 5 +- sota-implementations/ddpg/ddpg.py | 125 +++++++++++++++----------- sota-implementations/ddpg/utils.py | 11 ++- 3 files changed, 85 insertions(+), 56 deletions(-) diff --git a/sota-implementations/ddpg/config.yaml b/sota-implementations/ddpg/config.yaml index 43cb5093c09..ec90e59787f 100644 --- a/sota-implementations/ddpg/config.yaml +++ b/sota-implementations/ddpg/config.yaml @@ -13,7 +13,7 @@ collector: frames_per_batch: 1000 init_env_steps: 1000 reset_at_each_iter: False - device: cpu + device: env_per_collector: 1 @@ -39,6 +39,9 @@ network: hidden_sizes: [256, 256] activation: relu noise_type: "ou" # ou or gaussian + compile: False + compile_mode: + cudagraphs: False # logging logger: diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index cebc3685625..cc7663ea9f6 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -10,7 +10,7 @@ The helper functions are coded in the utils.py associated with this script. """ -import time +import warnings import hydra @@ -18,9 +18,13 @@ import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( dump_video, @@ -73,8 +77,24 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create DDPG loss loss_module, target_net_updater = make_loss_module(cfg, model) + compile_mode = None + if cfg.network.compile: + if cfg.network.compile_mode not in (None, ""): + compile_mode = cfg.network.compile_mode + elif cfg.network.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create off-policy collector - collector = make_collector(cfg, train_env, exploration_policy) + collector = make_collector( + cfg, + train_env, + exploration_policy, + compile=cfg.network.compile, + compile_mode=compile_mode, + cudagraph=cfg.network.cudagraphs, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -87,9 +107,29 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) + optimizer = group_optimizers(optimizer_actor, optimizer_critic) + + def update(sampled_tensordict): + optimizer.zero_grad(set_to_none=True) + + td_loss: TensorDict = loss_module(sampled_tensordict) + td_loss.sum(reduce=True).backward() + optimizer.step() + + # Update qnet_target params + target_net_updater.step() + return td_loss.detach() + + if cfg.loss.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -104,63 +144,42 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.env.max_episode_steps - sampling_start = time.time() - for _, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + tensordict = next(c_iter) # Update exploration policy exploration_policy[1].step(tensordict.numel()) # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) - - tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + pbar.update(current_frames) + # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("rb - extend"): + tensordict = tensordict.reshape(-1) + replay_buffer.extend(tensordict) + collected_frames += current_frames # Optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - ( - actor_losses, - q_losses, - ) = ([], []) + tds = [] for _ in range(num_updates): # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Update critic - q_loss, *_ = loss_module.loss_value(sampled_tensordict) - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - - # Update actor - actor_loss, *_ = loss_module.loss_actor(sampled_tensordict) - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - q_losses.append(q_loss.item()) - actor_losses.append(actor_loss.item()) - - # Update qnet_target params - target_net_updater.step() + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample().to(device) + with timeit("update"): + td_loss = update(sampled_tensordict) + tds.append(td_loss.clone()) # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) + tds = torch.stack(tds) - training_time = time.time() - training_start episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -178,15 +197,14 @@ def main(cfg: "DictConfig"): # noqa: F821 ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) - metrics_to_log["train/a_loss"] = np.mean(actor_losses) - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + tds = TensorDict(train=tds).flatten_keys("/").mean() + metrics_to_log.update(tds.to_dict()) # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy, @@ -194,22 +212,21 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + if i % 20 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 338081a7e8d..0a1a4a25f8a 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -113,7 +113,14 @@ def make_environment(cfg, logger): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector( + cfg, + train_env, + actor_model_explore, + compile=False, + compile_mode=None, + cudagraph=False, +): """Make collector.""" collector = SyncDataCollector( train_env, @@ -123,6 +130,8 @@ def make_collector(cfg, train_env, actor_model_explore): reset_at_each_iter=cfg.collector.reset_at_each_iter, total_frames=cfg.collector.total_frames, device=cfg.collector.device, + compile_policy={"mode": compile_mode} if compile else False, + cudagraph_policy=cudagraph, ) collector.set_seed(cfg.env.seed) return collector