Skip to content

Commit

Permalink
[Feature] DDPG compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 5b4f0e90f4aa7ded0128f3729f30cbc69e3e22fa
Pull Request resolved: #2555
  • Loading branch information
vmoens committed Nov 12, 2024
1 parent ce5628a commit 9a5229c
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 118 deletions.
5 changes: 4 additions & 1 deletion sota-implementations/ddpg/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
reset_at_each_iter: False
device: cpu
device:
env_per_collector: 1


Expand All @@ -39,6 +39,9 @@ network:
hidden_sizes: [256, 256]
activation: relu
noise_type: "ou" # ou or gaussian
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
Expand Down
135 changes: 81 additions & 54 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,21 @@
The helper functions are coded in the utils.py associated with this script.
"""
import time
import warnings

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
Expand All @@ -44,6 +48,14 @@ def main(cfg: "DictConfig"): # noqa: F821
device = "cpu"
device = torch.device(device)

collector_device = cfg.collector.device
if collector_device in ("", None):
if torch.cuda.is_available():
collector_device = "cuda:0"
else:
collector_device = "cpu"
collector_device = torch.device(collector_device)

# Create logger
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
logger = None
Expand Down Expand Up @@ -73,8 +85,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create DDPG loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.network.compile:
if cfg.network.compile_mode not in (None, ""):
compile_mode = cfg.network.compile_mode
elif cfg.network.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg,
train_env,
exploration_policy,
compile=cfg.network.compile,
compile_mode=compile_mode,
cudagraph=cfg.network.cudagraphs,
device=collector_device,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
Expand All @@ -87,9 +116,29 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
optimizer = group_optimizers(optimizer_actor, optimizer_critic)

def update(sampled_tensordict):
optimizer.zero_grad(set_to_none=True)

td_loss: TensorDict = loss_module(sampled_tensordict)
td_loss.sum(reduce=True).backward()
optimizer.step()

# Update qnet_target params
target_net_updater.step()
return td_loss.detach()

if cfg.network.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.network.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

Expand All @@ -104,63 +153,43 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
tensordict = next(c_iter)
# Update exploration policy
exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
pbar.update(current_frames)

# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("rb - extend"):
tensordict = tensordict.reshape(-1)
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Update critic
q_loss, *_ = loss_module.loss_value(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Update actor
actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())

# Update qnet_target params
target_net_updater.step()
with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
td_loss = update(sampled_tensordict)
tds.append(td_loss.clone())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds)

training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
Expand All @@ -178,38 +207,36 @@ def main(cfg: "DictConfig"): # noqa: F821
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
tds = TensorDict(train=tds).flatten_keys("/").mean()
metrics_to_log.update(tds.to_dict())

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
Expand Down
33 changes: 19 additions & 14 deletions sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
Expand All @@ -30,8 +30,6 @@
AdditiveGaussianModule,
MLP,
OrnsteinUhlenbeckProcessModule,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
Expand Down Expand Up @@ -113,7 +111,15 @@ def make_environment(cfg, logger):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraph=False,
device: torch.device|None=None,
):
"""Make collector."""
collector = SyncDataCollector(
train_env,
Expand All @@ -122,7 +128,9 @@ def make_collector(cfg, train_env, actor_model_explore):
init_random_frames=cfg.collector.init_random_frames,
reset_at_each_iter=cfg.collector.reset_at_each_iter,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down Expand Up @@ -172,9 +180,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
"""Make DDPG agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.single_action_spec
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
Expand All @@ -184,19 +190,16 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
actor_net = MLP(**actor_net_kwargs)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
out_keys=["action"],
spec=action_spec,
),
)

Expand Down Expand Up @@ -234,6 +237,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
OrnsteinUhlenbeckProcessModule(
spec=action_spec,
annealing_num_steps=1_000_000,
safe=False,
).to(device),
)
elif cfg.network.noise_type == "gaussian":
Expand All @@ -245,6 +249,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
sigma_init=1.0,
mean=0.0,
std=0.1,
safe=False,
).to(device),
)
else:
Expand Down
7 changes: 7 additions & 0 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@
RandomPolicy,
set_exploration_type,
)
try:
from torch.compiler import cudagraph_mark_step_begin
except ImportError:
def cudagraph_mark_step_begin():
"""Placeholder when cudagraph_mark_step_begin is missing."""
...

_TIMEOUT = 1.0
INSTANTIATE_TIMEOUT = 20
Expand Down Expand Up @@ -1145,6 +1151,7 @@ def rollout(self) -> TensorDictBase:
else:
policy_input = self._shuttle
# we still do the assignment for security
cudagraph_mark_step_begin()
policy_output = self.policy(policy_input)
if self._shuttle is not policy_output:
# ad-hoc update shuttle
Expand Down
Loading

0 comments on commit 9a5229c

Please sign in to comment.