Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Huggingface Integration #292

Merged
merged 68 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
1b585d6
initial commit
vwxyzjn Oct 13, 2022
fa82356
pre-commit
vwxyzjn Oct 13, 2022
4074eee
Add hub integration
vwxyzjn Oct 13, 2022
4436ce4
pre-commit
vwxyzjn Oct 14, 2022
df41e3d
use CommitOperation
vwxyzjn Oct 18, 2022
a98383d
Fix pre-commit
vwxyzjn Oct 18, 2022
b430540
refactor
vwxyzjn Oct 18, 2022
dd8ee86
Merge branch 'master' into hf-integration
vwxyzjn Oct 18, 2022
8144562
push changes
vwxyzjn Oct 27, 2022
2f20e17
refactor
vwxyzjn Oct 27, 2022
fdfc2a5
fix pre-commit
vwxyzjn Nov 16, 2022
56413f8
pre-commit
vwxyzjn Nov 16, 2022
b1b1dbd
Merge branch 'master' into hf-integration
vwxyzjn Nov 16, 2022
f6865d4
close the env and writer after eval
vwxyzjn Nov 16, 2022
fbe986c
support dqn jax
vwxyzjn Nov 17, 2022
83aa010
pre-commit
vwxyzjn Nov 17, 2022
ba1bfdb
Update cleanrl_utils/huggingface.py
vwxyzjn Nov 17, 2022
aee6809
address comments
vwxyzjn Nov 17, 2022
80a460f
update docs
vwxyzjn Nov 17, 2022
40be7d8
support dqn_atari_jax
vwxyzjn Dec 10, 2022
65ded2a
bug fix and docs
vwxyzjn Dec 13, 2022
133e6bd
Add cleanrl to the hf's `metadata`
vwxyzjn Dec 13, 2022
10d0b79
Merge branch 'master' into hf-integration
vwxyzjn Dec 15, 2022
ca60f24
include huggingface integration
vwxyzjn Dec 15, 2022
b165e35
test for enjoy.py
vwxyzjn Dec 15, 2022
7163d0d
bump version, pip install extra hack
vwxyzjn Dec 15, 2022
27d9b3d
Update cleanrl_utils/huggingface.py
vwxyzjn Dec 16, 2022
2a2208f
Update cleanrl_utils/huggingface.py
vwxyzjn Dec 16, 2022
4ac5631
Update cleanrl_utils/huggingface.py
vwxyzjn Dec 16, 2022
40358b1
Update cleanrl_utils/huggingface.py
vwxyzjn Dec 16, 2022
df68d57
Update cleanrl_utils/huggingface.py
vwxyzjn Dec 16, 2022
7dddfbd
Update cleanrl_utils/huggingface.py
vwxyzjn Dec 16, 2022
954723f
update docs
vwxyzjn Dec 16, 2022
fb858ae
update pre-commit
vwxyzjn Dec 16, 2022
b508f66
quick fix
vwxyzjn Dec 16, 2022
7d5193b
bug fix
vwxyzjn Dec 16, 2022
c390b8d
lazy load modules to avoid dependency issues
vwxyzjn Dec 20, 2022
cc456d6
Add huggingface shields
vwxyzjn Dec 20, 2022
fd5a737
Add emoji
vwxyzjn Dec 20, 2022
3b0af25
Update docs
vwxyzjn Dec 20, 2022
ff0be11
pre-commit
vwxyzjn Dec 20, 2022
9bd034e
Update docs
vwxyzjn Dec 20, 2022
78022d7
Update docs
vwxyzjn Dec 20, 2022
aae8d4d
Merge branch 'master' into hf-integration
kinalmehta Dec 30, 2022
1c2cd40
fix: use `algorithm_variant_filename` in model card reproduction script
kinalmehta Dec 31, 2022
e172a0c
typo fix
kinalmehta Dec 31, 2022
c733514
feat: add hf support for c51
kinalmehta Dec 31, 2022
15be698
formatting fix
kinalmehta Dec 31, 2022
8fac8e3
support pulling variant depdencies directly
vwxyzjn Dec 31, 2022
35d6fc7
support model saving for `ppo_atari_envpool_xla_jax_scan`
vwxyzjn Dec 31, 2022
1ce42c9
Merge branch 'master' into hf-integration
vwxyzjn Dec 31, 2022
8990794
support `ppo_atari_envpool_xla_jax_scan`
vwxyzjn Jan 1, 2023
ea4a71d
quick change
vwxyzjn Jan 1, 2023
7493ae4
support 'c51_jax'
kinalmehta Jan 1, 2023
fe34419
formatting fix
kinalmehta Jan 1, 2023
4a1f72a
support capture video
vwxyzjn Jan 3, 2023
7f22c25
Add notebook
vwxyzjn Jan 3, 2023
5331287
update docs
vwxyzjn Jan 3, 2023
9aec97e
support `c51_atari` and `c51_atari_jax`
kinalmehta Jan 4, 2023
bc8c014
Merge remote-tracking branch 'origin/hf-integration' into hf-integration
kinalmehta Jan 4, 2023
b202985
typo fix
kinalmehta Jan 4, 2023
54fd64a
add c51 to zoo docs
kinalmehta Jan 4, 2023
9e5841b
add colab badge
vwxyzjn Jan 4, 2023
9178763
fix broken colab svg
vwxyzjn Jan 4, 2023
07961f4
pypi release
vwxyzjn Jan 4, 2023
c09a80d
typo fix
vwxyzjn Jan 4, 2023
a18ffdb
update pre-commit
vwxyzjn Jan 4, 2023
ba7053a
remove hf-integration reference
vwxyzjn Jan 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions cleanrl/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
Expand Down Expand Up @@ -206,5 +212,31 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
torch.save(q_network.state_dict(), model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.dqn_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
device=device,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
32 changes: 32 additions & 0 deletions cleanrl/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
Expand Down Expand Up @@ -228,5 +234,31 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
torch.save(q_network.state_dict(), model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.dqn_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
device=device,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
32 changes: 32 additions & 0 deletions cleanrl/dqn_atari_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
Expand Down Expand Up @@ -258,5 +264,31 @@ def mse_loss(params):
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
with open(model_path, "wb") as f:
f.write(flax.serialization.to_bytes(q_state.params))
print(f"model saved to {model_path}")
from cleanrl_utils.evals.dqn_jax_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
32 changes: 32 additions & 0 deletions cleanrl/dqn_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
Expand Down Expand Up @@ -230,5 +236,31 @@ def mse_loss(params):
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
with open(model_path, "wb") as f:
f.write(flax.serialization.to_bytes(q_state.params))
print(f"model saved to {model_path}")
from cleanrl_utils.evals.dqn_jax_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
Empty file added cleanrl_utils/evals/__init__.py
Empty file.
73 changes: 73 additions & 0 deletions cleanrl_utils/evals/dqn_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import random
from typing import Callable

import gym
import numpy as np
import torch


def evaluate(
model_path: str,
make_env: Callable,
env_id: str,
eval_episodes: int,
run_name: str,
Model: torch.nn.Module,
device: torch.device = torch.device("cpu"),
epsilon: float = 0.05,
capture_video: bool = True,
):
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
model = Model(envs).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

obs = envs.reset()
episodic_returns = []
while len(episodic_returns) < eval_episodes:
if random.random() < epsilon:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
q_values = model(torch.Tensor(obs).to(device))
actions = torch.argmax(q_values, dim=1).cpu().numpy()
next_obs, _, _, infos = envs.step(actions)
for info in infos:
if "episode" in info.keys():
print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
episodic_returns += [info["episode"]["r"]]
obs = next_obs

return episodic_returns


if __name__ == "__main__":
from huggingface_hub import hf_hub_download

from cleanrl.dqn import QNetwork, make_env

model_path = hf_hub_download(repo_id="cleanrl/CartPole-v1-dqn-seed1", filename="q_network.pth")
vwxyzjn marked this conversation as resolved.
Show resolved Hide resolved
evaluate(
model_path,
make_env,
"CartPole-v1",
eval_episodes=10,
run_name=f"eval",
Model=QNetwork,
device="cpu",
capture_video=False,
)

# from cleanrl.dqn_atari import QNetwork, make_env

# model_path = hf_hub_download(repo_id="vwxyzjn/BreakoutNoFrameskip-v4-dqn_atari-seed1", filename="q_network.pth")
# evaluate(
# model_path,
# make_env,
# "BreakoutNoFrameskip-v4",
# eval_episodes=10,
# run_name=f"eval",
# Model=QNetwork,
# device="cpu",
# epsilon=0.05,
# capture_video=False,
# )
78 changes: 78 additions & 0 deletions cleanrl_utils/evals/dqn_jax_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import random
from typing import Callable

import flax
import flax.linen as nn
import gym
import jax
import numpy as np


def evaluate(
model_path: str,
make_env: Callable,
env_id: str,
eval_episodes: int,
run_name: str,
Model: nn.Module,
epsilon: float = 0.05,
capture_video: bool = True,
seed=1,
):
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
obs = envs.reset()
model = Model(action_dim=envs.single_action_space.n)
q_key = jax.random.PRNGKey(seed)
params = model.init(q_key, obs)
with open(model_path, "rb") as f:
params = flax.serialization.from_bytes(params, f.read())
model.apply = jax.jit(model.apply)

episodic_returns = []
while len(episodic_returns) < eval_episodes:
if random.random() < epsilon:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
q_values = model.apply(params, obs)
actions = q_values.argmax(axis=-1)
actions = jax.device_get(actions)
next_obs, _, _, infos = envs.step(actions)
for info in infos:
if "episode" in info.keys():
print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
episodic_returns += [info["episode"]["r"]]
obs = next_obs

return episodic_returns


if __name__ == "__main__":
from huggingface_hub import hf_hub_download

from cleanrl.dqn_jax import QNetwork, make_env

model_path = hf_hub_download(repo_id="vwxyzjn/CartPole-v1-dqn_jax-seed1", filename="dqn_jax.cleanrl_model")
evaluate(
model_path,
make_env,
"CartPole-v1",
eval_episodes=10,
run_name=f"eval",
Model=QNetwork,
capture_video=False,
)

# from cleanrl.dqn_atari import QNetwork, make_env

# model_path = hf_hub_download(repo_id="vwxyzjn/BreakoutNoFrameskip-v4-dqn_atari-seed1", filename="q_network.pth")
# evaluate(
# model_path,
# make_env,
# "BreakoutNoFrameskip-v4",
# eval_episodes=10,
# run_name=f"eval",
# Model=QNetwork,
# device="cpu",
# epsilon=0.05,
# capture_video=False,
# )
vwxyzjn marked this conversation as resolved.
Show resolved Hide resolved
Loading