Skip to content

Commit

Permalink
Merge pull request #18 from awslabs/remove-stable-baselines3
Browse files Browse the repository at this point in the history
Remove stable-baselines3 dependency, and move the gym data generator to example.
  • Loading branch information
verdimrc authored Sep 13, 2022
2 parents 2e19f76 + af2910a commit 4a3815b
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 195 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
"python": ("https://docs.python.org/3/", None),
"sklearn": ("http://scikit-learn.org/stable/", None),
"stable_baselines3": ("https://stable-baselines3.readthedocs.io/en/master/", None),
"gym": ("https://gymlibrary.dev/", None),
"pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
}
Expand Down
2 changes: 1 addition & 1 deletion docs/references/datastruct.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ Should you need the documentations for the base ``pandas`` functionalities, plea

WiDataFrame
WiSeries
WhatifWrapper
TransitionRecorder
3 changes: 1 addition & 2 deletions docs/references/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Utility functions.
backtest
better_than_random
conditional_entropy
data_generator_gym
data_generator_simple
entropy
force_assert
Expand All @@ -26,4 +25,4 @@ Utility functions.
reward_function
set_seed
stationary_policy
tokenize
tokenize
2 changes: 1 addition & 1 deletion examples/sagemaker-training/dynamic_pricing/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def agent(simulator, ctx):
model = trainer(policy="MlpPolicy", env=env, verbose=False) # type: ignore[call-arg,arg-type]
model.learn(total_timesteps=gym_timesteps[0])

cap_env = wi.WhatifWrapper(env)
cap_env = wi.TransitionRecorder(env)
model.set_env(cap_env)
model.learn(total_timesteps=gym_timesteps[1])

Expand Down
2 changes: 1 addition & 1 deletion examples/underfloor_heating/underfloor_heating_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import a2rl as wi


class WhatifWrapperUnderfloor(wi.WhatifWrapper):
class WhatifWrapperUnderfloor(wi.TransitionRecorder):
"""This is data collector helper class.
When agent is interacting with the env, it will store the states/actions/reward
Expand Down
147 changes: 99 additions & 48 deletions notebooks/data_properties.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,33 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Properties"
"# Data Properties\n",
"\n",
"\n",
"For many sequential decision making problems we look for some key patterns in the data\n",
"\n",
"* Markov property\n",
"* A consistent reward or cost\n",
"* Actions being effective in contributing to the reward or affecting the Environment\n",
"* Seeing if there is a consistent way that actions are picked\n",
"\n",
"We have a few helper visualisations to help these are `markovian_matrix` and\n",
"`normalized_markovian_matrix`.\n",
"\n",
"**Pre-requisite**: this example requires stable-baselines3. To quickly install this library, you may\n",
"uncoment and execute the next cell. Note that the term `'gym>=...'` prevents `stable-baselines3`\n",
"from downgrading `gym` to a version incompatible with `a2rl`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# %pip install stable-baselines3 'gym>=0.23.1,<0.26.0'"
]
},
{
Expand All @@ -20,67 +46,33 @@
"%autoreload 2\n",
"\n",
"import my_nb_path # isort: skip\n",
"import a2rl as wi # isort: skip\n",
"import os\n",
"\n",
"import gym\n",
"from IPython.display import Markdown\n",
"from stable_baselines3 import A2C, DQN, SAC\n",
"from stable_baselines3.common.base_class import BaseAlgorithm\n",
"\n",
"import a2rl as wi\n",
"from a2rl.nbtools import pprint, print # Enable color outputs when rich is installed.\n",
"from a2rl.utils import (\n",
" NotMDPDataError,\n",
" assert_mdp,\n",
" data_generator_gym,\n",
" data_generator_simple,\n",
" plot_information,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"For many sequential decision making problems we look for some key patterns in the data\n",
"\n",
"* Markov property\n",
"\n",
"* A consistent reward or cost\n",
"\n",
"* Actions being effective in contributing to the reward or affecting the Environment\n",
"\n",
"* Seeing if there is a consistent way that actions are picked\n",
"\n",
"\n",
"We have a few helper visualisations to help these are markovian_matrix and normalized_markovian_matrix\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Inspection\n",
"\n",
"In the offline setting we are restricted only to data. `whatif` offers three ways to generate some:\n",
"\n",
"1. The load-and-discretize workflow <- The main one. See `discretized_sample_dataset()`.\n",
"\n",
"2. `data_generator_gym` to load data interations between a trained agent and a gym environment <- This is for testing and research\n",
"\n",
"3. `data_generator_simple` to generate sample data with different properties <- Also for testing and research\n"
"import a2rl.nbtools # Enable color outputs when rich is installed.\n",
"from a2rl.utils import NotMDPDataError, assert_mdp, data_generator_simple, plot_information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"tags": []
},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def discretized_sample_dataset(dataset_name: str, n_bins=50) -> wi.WiDataFrame:\n",
Expand All @@ -99,7 +91,51 @@
" dirname = wi.sample_dataset_path(dataset_name)\n",
" tokeniser = wi.DiscreteTokenizer(n_bins=n_bins)\n",
" df = tokeniser.fit_transform(wi.read_csv_dataset(dirname))\n",
" return df"
" return df\n",
"\n",
"\n",
"def data_generator_gym(\n",
" env_name: str = \"Taxi-v3\",\n",
" trainer: type[BaseAlgorithm] = A2C,\n",
" training_steps: int = 10000,\n",
" capture_steps: int = 1000,\n",
") -> wi.WiDataFrame:\n",
" \"\"\"Generate a :class:`a2rl.WiDataFrame` from any well-defined OpenAi gym.\n",
" An agent is trained first for ``training_steps``. Then, capture ``capture_steps`` from the\n",
" trained agent.\n",
" Args:\n",
" env_name: Name of the gym environment.\n",
" trainer: An underlying generator algorithm that supports discrete actions, such as\n",
" :class:`stable_baselines3.dqn.DQN` or :class:`stable_baselines3.a2c.A2C`. Raise an error\n",
" when passing a trainer that does not support discrete actions, such as\n",
" :class:`stable_baselines3.sac.SAC`.\n",
" training_steps: The number of steps to train the generator.\n",
" capture_steps: The number of steps to capture.\n",
" Returns:\n",
" A2RL data frame.\n",
" \"\"\"\n",
" env = gym.make(env_name, render_mode=None)\n",
" model = trainer(policy=\"MlpPolicy\", env=env, verbose=False)\n",
" model.learn(total_timesteps=training_steps)\n",
"\n",
" cap_env = wi.TransitionRecorder(env)\n",
" model.set_env(cap_env)\n",
" model.learn(total_timesteps=capture_steps)\n",
"\n",
" tokeniser = wi.DiscreteTokenizer(n_bins=50)\n",
" df = tokeniser.fit_transform(cap_env.df)\n",
"\n",
" return df\n",
"\n",
"\n",
"def test_gym_generator():\n",
" import pytest\n",
"\n",
" gym_data = data_generator_gym(env_name=\"Taxi-v3\", trainer=DQN)\n",
" assert isinstance(gym_data, wi.WiDataFrame)\n",
"\n",
" with pytest.raises(AssertionError, match=r\"Discrete(.*) was provided\"):\n",
" gym_data = data_generator_gym(env_name=\"MountainCar-v0\", trainer=SAC)"
]
},
{
Expand All @@ -124,6 +160,21 @@
"################################################################################"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Inspection\n",
"\n",
"In the offline setting we are restricted only to data. A2RL offers three ways to generate some:\n",
"\n",
"1. The load-and-discretize workflow <- The main one. See `discretized_sample_dataset()`.\n",
"\n",
"2. `data_generator_gym` to load data interations between a trained agent and a gym environment <- This is for testing and research\n",
"\n",
"3. `data_generator_simple` to generate sample data with different properties <- Also for testing and research"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
3 changes: 3 additions & 0 deletions requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ rich
# third party package in example
psychrolib
smopy
stable-baselines3
# HACK: don't let stable-baselines3 downgrade to old gym. MUST be in-sync with requirements.txt.
gym>=0.23.1,<0.26.0
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ loguru
# See: breaking changes in https://github.com/openai/gym/releases/tag/0.26.0
gym>=0.23.1,<0.26.0

stable_baselines3
seaborn
cloudpickle
pytorch-lightning>=1.5.0
2 changes: 1 addition & 1 deletion src/a2rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _warn(message, category=None, stacklevel=1, source=None):

# flake8: noqa: E402
from . import utils
from ._dataframe import WhatifWrapper, WiDataFrame, WiSeries
from ._dataframe import TransitionRecorder, WiDataFrame, WiSeries
from ._io import (
Metadata,
list_sample_datasets,
Expand Down
Loading

0 comments on commit 4a3815b

Please sign in to comment.