Skip to content

Commit

Permalink
[Feature] DDPG compatibility with compile
Browse files: browse the repository at this point in the history
ghstack-source-id: 25f076bfb94d0a22dc6f9222faf893e515dc9291
Pull Request resolved: #2555
  • Loading branch information
vmoens committed Nov 12, 2024
1 parent 672ebe4 commit 8b1c094
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 112 deletions.
5 changes: 4 additions & 1 deletion sota-implementations/ddpg/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
reset_at_each_iter: False
device: cpu
device:
env_per_collector: 1


Expand All @@ -39,6 +39,9 @@ network:
hidden_sizes: [256, 256]
activation: relu
noise_type: "ou" # ou or gaussian
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
Expand Down
125 changes: 71 additions & 54 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,21 @@
The helper functions are coded in the utils.py associated with this script.
"""
import time
import warnings

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
Expand Down Expand Up @@ -73,8 +77,24 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create DDPG loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.network.compile:
if cfg.network.compile_mode not in (None, ""):
compile_mode = cfg.network.compile_mode
elif cfg.network.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg,
train_env,
exploration_policy,
compile=cfg.network.compile,
compile_mode=compile_mode,
cudagraph=cfg.network.cudagraphs,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
Expand All @@ -87,9 +107,29 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
optimizer = group_optimizers(optimizer_actor, optimizer_critic)

# Run one optimization step on a sampled batch: a single backward pass over
# the summed loss TensorDict, one step of the grouped actor/critic optimizer,
# then a target-network update. Returns the detached loss TensorDict.
def update(sampled_tensordict):
# Clear the gradients held by the grouped optimizer before the new pass.
optimizer.zero_grad(set_to_none=True)

# The loss module yields a TensorDict of loss terms for this batch.
td_loss: TensorDict = loss_module(sampled_tensordict)
# Collapse the loss entries into one value and backpropagate through it.
td_loss.sum(reduce=True).backward()
optimizer.step()

# Update qnet_target params
target_net_updater.step()
# Return a detached copy of the losses for the caller to collect.
return td_loss.detach()

if cfg.network.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.network.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

Expand All @@ -104,63 +144,42 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
tensordict = next(c_iter)
# Update exploration policy
exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
pbar.update(current_frames)

# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("rb - extend"):
tensordict = tensordict.reshape(-1)
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Update critic
q_loss, *_ = loss_module.loss_value(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Update actor
actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())

# Update qnet_target params
target_net_updater.step()
with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
td_loss = update(sampled_tensordict)
tds.append(td_loss.clone())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds)

training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
Expand All @@ -178,38 +197,36 @@ def main(cfg: "DictConfig"): # noqa: F821
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
tds = TensorDict(train=tds).flatten_keys("/").mean()
metrics_to_log.update(tds.to_dict())

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
Expand Down
30 changes: 17 additions & 13 deletions sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
Expand All @@ -30,8 +30,6 @@
AdditiveGaussianModule,
MLP,
OrnsteinUhlenbeckProcessModule,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
Expand Down Expand Up @@ -113,7 +111,14 @@ def make_environment(cfg, logger):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraph=False,
):
"""Make collector."""
collector = SyncDataCollector(
train_env,
Expand All @@ -123,6 +128,8 @@ def make_collector(cfg, train_env, actor_model_explore):
reset_at_each_iter=cfg.collector.reset_at_each_iter,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down Expand Up @@ -172,9 +179,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
"""Make DDPG agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.single_action_spec
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
Expand All @@ -184,19 +189,16 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
actor_net = MLP(**actor_net_kwargs)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
out_keys=["action"],
spec=action_spec,
),
)

Expand Down Expand Up @@ -234,6 +236,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
OrnsteinUhlenbeckProcessModule(
spec=action_spec,
annealing_num_steps=1_000_000,
safe=False,
).to(device),
)
elif cfg.network.noise_type == "gaussian":
Expand All @@ -245,6 +248,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
sigma_init=1.0,
mean=0.0,
std=0.1,
safe=False,
).to(device),
)
else:
Expand Down
8 changes: 8 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,7 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
spec.type_check(val)

def is_in(self, value) -> bool:
raise RuntimeError
if self.dim == 0 and not hasattr(value, "unbind"):
# We don't use unbind because value could be a tuple or a nested tensor
return all(
Expand Down Expand Up @@ -1796,6 +1797,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -2246,6 +2248,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
val_shape = _remove_neg_shapes(tensordict.utils._shape(val))
shape = torch.broadcast_shapes(self._safe_shape, val_shape)
shape = list(shape)
Expand Down Expand Up @@ -2443,6 +2446,7 @@ def one(self, shape=None):
)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return (
isinstance(val, NonTensorData)
Expand Down Expand Up @@ -2635,6 +2639,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return val.shape == shape and val.dtype == self.dtype

Expand Down Expand Up @@ -2983,6 +2988,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
return torch.cat(out, -1)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
vals = self._split(val)
if vals is None:
return False
Expand Down Expand Up @@ -3328,6 +3334,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -3953,6 +3960,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val.squeeze(0) if val_is_scalar else val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is not None:
vals = val.unbind(-1)
splits = self._split_self()
Expand Down
3 changes: 2 additions & 1 deletion torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
keys = [out_key]
values = [spec]
else:
keys = list(spec.keys(True, True))
# Make dynamo happy with the list creation
keys = [key for key in spec.keys(True, True)] # noqa: C416
values = [spec[key] for key in keys]
for _spec, _key in zip(values, keys):
if _spec is None:
Expand Down
Loading

0 comments on commit 8b1c094

Please sign in to comment.