Huggingface Integration (#292)
* initial commit

* pre-commit

* Add hub integration

* pre-commit

* use CommitOperation

* Fix pre-commit

* refactor

* push changes

* refactor

* fix pre-commit

* pre-commit

* close the env and writer after eval

* support dqn jax

* pre-commit

* Update cleanrl_utils/huggingface.py

Co-authored-by: Lucain <[email protected]>

* address comments

* update docs

* support dqn_atari_jax

* bug fix and docs

* Add cleanrl to the hf's `metadata`

* include huggingface integration

* test for enjoy.py

* bump version, pip install extra hack

python-poetry/poetry#4842 (comment)

* Update cleanrl_utils/huggingface.py

Co-authored-by: Lucain <[email protected]>

* Update cleanrl_utils/huggingface.py

Co-authored-by: Lucain <[email protected]>

* Update cleanrl_utils/huggingface.py

Co-authored-by: Lucain <[email protected]>

* Update cleanrl_utils/huggingface.py

Co-authored-by: Lucain <[email protected]>

* Update cleanrl_utils/huggingface.py

Co-authored-by: Lucain <[email protected]>

* Update cleanrl_utils/huggingface.py

Co-authored-by: Lucain <[email protected]>

* update docs

* update pre-commit

* quick fix

* bug fix

* lazy load modules to avoid dependency issues

* Add huggingface shields

* Add emoji

* Update docs

* pre-commit

* Update docs

* Update docs

* fix: use `algorithm_variant_filename` in model card reproduction script

* typo fix

* feat: add hf support for c51

* formatting fix

* support pulling variant dependencies directly

* support model saving for `ppo_atari_envpool_xla_jax_scan`

* support `ppo_atari_envpool_xla_jax_scan`

* quick change

* support 'c51_jax'

* formatting fix

* support capture video

* Add notebook

* update docs

* support `c51_atari` and `c51_atari_jax`

* typo fix

* add c51 to zoo docs

* add colab badge

* fix broken colab svg

* pypi release

* typo fix

* update pre-commit

* remove hf-integration reference

Co-authored-by: Lucain <[email protected]>
Co-authored-by: Kinal <[email protected]>
Co-authored-by: Kinal Mehta <[email protected]>
4 people authored Jan 12, 2023
1 parent 3f5535c commit 30381ee
Showing 40 changed files with 7,042 additions and 134 deletions.
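In brief: each supported training script gains three flags, `--save-model`, `--upload-model`, and `--hf-entity`. With `--save-model`, the script writes a `.cleanrl_model` checkpoint to `runs/{run_name}` after training and evaluates it for 10 episodes; with `--upload-model`, it pushes the checkpoint, evaluation results, and videos to the Hugging Face Hub. Note that `--upload-model` relies on the checkpoint and evaluation produced by `--save-model`. A hypothetical invocation (the entity name is a placeholder):

python cleanrl/dqn.py --env-id CartPole-v1 --save-model --upload-model --hf-entity your-hf-username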
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
runs
balance_bot.xml
cleanrl/ppo_continuous_action_isaacgym/isaacgym/examples
cleanrl/ppo_continuous_action_isaacgym/isaacgym/isaacgym
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -36,7 +36,7 @@ repos:
- id: codespell
args:
- --ignore-words-list=nd,reacher,thist,ths,magent
- - --skip=docs/css/termynal.css,docs/js/termynal.js
+ - --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb
- repo: https://github.com/python-poetry/poetry
rev: 1.2.1
hooks:
3 changes: 2 additions & 1 deletion README.md
@@ -8,7 +8,8 @@
[<img src="https://img.shields.io/youtube/channel/views/UCDdC6BIFRI0jvcwuhi3aI6w?style=social">](https://www.youtube.com/channel/UCDdC6BIFRI0jvcwuhi3aI6w/videos)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)

[<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-Huggingface-F8D521">](https://huggingface.co/cleanrl)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vwxyzjn/cleanrl/blob/master/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb)


CleanRL is a Deep Reinforcement Learning library that provides high-quality single-file implementation with research-friendly features. The implementation is clean and simple, yet we can scale it to run thousands of experiments using AWS Batch. The highlight features of CleanRL are:
36 changes: 36 additions & 0 deletions cleanrl/c51.py
@@ -33,6 +33,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
@@ -238,5 +244,35 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
model_data = {
"model_weights": q_network.state_dict(),
"args": vars(args),
}
torch.save(model_data, model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.c51_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
device=device,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "C51", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
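For reference, the checkpoint written above bundles the network weights with the training arguments, so it can be restored later. A minimal sketch, assuming the `QNetwork(envs, n_atoms=..., v_min=..., v_max=...)` constructor from this script and a vectorized `envs` built the same way as in training (the path is a placeholder):

import torch

data = torch.load("runs/<run_name>/c51.cleanrl_model", map_location="cpu")
args, weights = data["args"], data["model_weights"]  # args is the vars(args) dict saved above
q_network = QNetwork(envs, n_atoms=args["n_atoms"], v_min=args["v_min"], v_max=args["v_max"])
q_network.load_state_dict(weights)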
36 changes: 36 additions & 0 deletions cleanrl/c51_atari.py
@@ -40,6 +40,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
@@ -260,5 +266,35 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
model_data = {
"model_weights": q_network.state_dict(),
"args": vars(args),
}
torch.save(model_data, model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.c51_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
device=device,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "C51", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
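As a concrete example of the repo naming above: with `--env-id BreakoutNoFrameskip-v4 --seed 1 --hf-entity cleanrl` and the default `exp_name` (the script's filename), `push_to_hub` would target the repo `cleanrl/BreakoutNoFrameskip-v4-c51_atari-seed1`.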
42 changes: 37 additions & 5 deletions cleanrl/c51_atari_jax.py
@@ -35,18 +35,20 @@ def parse_args():
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="weather to capture videos of the agent performances (check out `videos` folder)")
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
@@ -301,5 +303,35 @@ def get_action(q_state, obs):
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
model_data = {
"model_weights": q_state.params,
"args": vars(args),
}
with open(model_path, "wb") as f:
f.write(flax.serialization.to_bytes(model_data))
print(f"model saved to {model_path}")
from cleanrl_utils.evals.c51_jax_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "C51", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
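The Flax checkpoint written above is a plain msgpack blob holding both the parameter pytree and the training arguments, so it can be restored without a target structure. A minimal sketch (the path is a placeholder):

import flax.serialization

with open("runs/<run_name>/c51_atari_jax.cleanrl_model", "rb") as f:
    model_data = flax.serialization.msgpack_restore(f.read())
params = model_data["model_weights"]  # Flax parameter pytree
train_args = model_data["args"]       # dict of the training arguments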
42 changes: 37 additions & 5 deletions cleanrl/c51_jax.py
@@ -24,18 +24,20 @@ def parse_args():
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="weather to capture videos of the agent performances (check out `videos` folder)")
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
@@ -268,5 +270,35 @@ def loss(q_params, observations, actions, target_pmfs):
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
model_data = {
"model_weights": q_state.params,
"args": vars(args),
}
with open(model_path, "wb") as f:
f.write(flax.serialization.to_bytes(model_data))
print(f"model saved to {model_path}")
from cleanrl_utils.evals.c51_jax_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "C51", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
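The new `evaluate` helper can also be called outside of training. A sketch mirroring the call above, assuming `QNetwork` and `make_env` are importable from this script (the module path and checkpoint path are placeholders):

from cleanrl.c51_jax import QNetwork, make_env
from cleanrl_utils.evals.c51_jax_eval import evaluate

episodic_returns = evaluate(
    "runs/<run_name>/c51_jax.cleanrl_model",
    make_env,
    "CartPole-v1",
    eval_episodes=10,
    run_name="c51-jax-eval-demo",
    Model=QNetwork,
    epsilon=0.05,
)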
32 changes: 32 additions & 0 deletions cleanrl/dqn.py
@@ -34,6 +34,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
@@ -206,5 +212,31 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
torch.save(q_network.state_dict(), model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.dqn_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
device=device,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
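Unlike the C51 scripts above, this checkpoint is the bare `state_dict`, so restoring it only needs the network class. A minimal sketch, assuming a vectorized `envs` built the same way as in training (the path is a placeholder):

import torch

q_network = QNetwork(envs)
q_network.load_state_dict(torch.load("runs/<run_name>/dqn.cleanrl_model", map_location="cpu"))
q_network.eval()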
32 changes: 32 additions & 0 deletions cleanrl/dqn_atari.py
@@ -41,6 +41,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
@@ -228,5 +234,31 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
torch.save(q_network.state_dict(), model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.dqn_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
device=device,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
32 changes: 32 additions & 0 deletions cleanrl/dqn_atari_jax.py
@@ -43,6 +43,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
@@ -258,5 +264,31 @@ def mse_loss(params):
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
with open(model_path, "wb") as f:
f.write(flax.serialization.to_bytes(q_state.params))
print(f"model saved to {model_path}")
from cleanrl_utils.evals.dqn_jax_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=QNetwork,
epsilon=0.05,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
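Similarly, this JAX variant serializes only the raw parameter pytree (no args dict), so restoring it reduces to the following (the path is a placeholder):

import flax.serialization

with open("runs/<run_name>/dqn_atari_jax.cleanrl_model", "rb") as f:
    params = flax.serialization.msgpack_restore(f.read())  # Flax parameter pytree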

1 comment on commit 30381ee

@vercel bot commented on 30381ee on Jan 12, 2023


Successfully deployed to the following URLs:

cleanrl – ./

cleanrl.vercel.app
cleanrl-git-master-vwxyzjn.vercel.app
docs.cleanrl.dev
cleanrl-vwxyzjn.vercel.app
