beam-search implementation for more exhaustive sampling #35

Merged 6 commits on May 13, 2023
5 changes: 0 additions & 5 deletions examples/sagemaker-training/dynamic_pricing/flight_sales.py
@@ -95,7 +95,6 @@ def render(self):
pass

def step(self, action):

self.freight_price = self.config["freight_price"] + np.random.random()
self.freight_price = np.round(self.freight_price, decimals=1)

@@ -113,9 +112,7 @@ def fsigmoid(x, a, b, c):

for i in range(self.visitors):
if seats_left > 0:

if np.random.random() < fsigmoid([action], *self.params)[0]:

seats_left -= 1
tickets += 1

@@ -149,7 +146,6 @@ def fsigmoid(x, a, b, c):
return state, reward, done, {}

def context(self):

return wi.WiDataFrame(
self.history.fillna(method="ffill"),
states=["season", "freight_price"],
@@ -158,7 +154,6 @@ def context(self):
)

def reset(self):

self.config = config
self.day = 1
self.max_time = self.config["max_time"]
2 changes: 0 additions & 2 deletions examples/underfloor_heating/underfloor_heating_gym_env.py
@@ -140,7 +140,6 @@ class UnderfloorEnv(gym.Env):
"""

def __init__(self, env_config: UnderfloorEnvConfig):

if not isinstance(env_config, UnderfloorEnvConfig):
raise ValueError(f"Config must be of type UnderfloorEnvConfig, not {type(env_config)}")

@@ -228,7 +227,6 @@ def reset(self, **kwargs) -> np.ndarray | tuple[np.ndarray, dict]:
return self.state

def step(self, action: list[int]) -> tuple[np.ndarray, float, bool, dict]:

state_action = np.concatenate((self.state, action), axis=None)
# print(f"{self.state=}")
# print(f"{state_action=}")
5 changes: 4 additions & 1 deletion notebooks/dataframe.ipynb
@@ -239,7 +239,10 @@
"\n",
"\n",
"# Split the chiller df into 5 series (i.e., 1 for each column)\n",
"sers = [ser for _, ser in df.iteritems()]\n",
"if pd.__version__ >= '1.5.0':\n",
" sers = [ser for _, ser in df.items()]\n",
"else:\n",
" sers = [ser for _, ser in df.iteritems()]\n",
"assert_same_sar(sers)\n",
"\n",
"# Scale the states and rewards\n",
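
For context, the guard above works around pandas deprecating `DataFrame.iteritems()` in favour of `DataFrame.items()`. The short sketch below shows the same pattern in a self-contained form; it is illustrative only (not part of this PR) and assumes the `packaging` helper is available for a version comparison that is more robust than comparing raw version strings.

# Illustrative sketch of the version-guarded iteration used in the notebook (not part of this PR).
import pandas as pd
from packaging import version  # assumption: packaging is installed alongside pandas

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

if version.parse(pd.__version__) >= version.parse("1.5.0"):
    # pandas 1.5.0 deprecated iteritems(); items() yields the same (column, Series) pairs.
    sers = [ser for _, ser in df.items()]
else:
    # Fall back to iteritems() on older pandas, matching the notebook's guard.
    sers = [ser for _, ser in df.iteritems()]

assert len(sers) == df.shape[1]  # one Series per column
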
156 changes: 147 additions & 9 deletions notebooks/planner_byo_example.ipynb
@@ -214,12 +214,14 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define two planners\n",
"## Define three planners\n",
"\n",
"Here we define two planner classes as examples to illustrate how to Bring Your Own planner to work with the `A2RL` simulator. We will add more planners (e.g. `BeamSearchPlanner`, etc.) as needed as per your feedback."
"Here we define three planner classes as examples to illustrate how to Bring Your Own planner to work with the `A2RL` simulator.\n",
"<!-- We will add more planners (e.g. `BeamSearchPlanner`, etc.) as needed as per your feedback. -->"
]
},
{
@@ -435,7 +437,112 @@
"\n",
" q_accum_cost_list = q_accum_cost_list.transpose()\n",
" q_non_accum_cost_list = q_non_accum_cost_list.transpose()\n",
" return [q_non_accum_cost_list, q_accum_cost_list]\n"
" return [q_non_accum_cost_list, q_accum_cost_list]\n",
"\n",
"\n",
"class BeamSearchQPlanner(A2RLPLanner):\n",
" \"\"\"\n",
" This planner has similar logic to the QPlanner, only it uses `a2rl.Simulator.beam_search_n_steps`\n",
" to obtain all the actions and rewards in one go.\n",
" The actions are still chosen with the highest / lowest sum_reward (immediate_reward + reward-to-go), \n",
" and take that action to the next step.\n",
" \"\"\"\n",
"\n",
" def __init__(self, simulator: Simulator, beam_width: int, beam_random: bool, objective: str = 'min') -> None:\n",
" super().__init__(simulator)\n",
"\n",
" self.beam_width = beam_width\n",
" self.beam_random = beam_random\n",
"\n",
" if objective.lower() not in ['min', 'max']:\n",
" raise ValueError('objective must be either min or max')\n",
" if 'min' == objective:\n",
" self.obj_op = np.argmin\n",
" else:\n",
" self.obj_op = np.argmax\n",
"\n",
" def rollout(self, horizon: int = 20, nb_runs: int = 3) -> List[np.array]:\n",
" if nb_runs != 1:\n",
" print(\"WARN: multiple runs in beam search is implemented as a loop and not vectorized and performance may be slow\")\n",
"\n",
" if nb_runs != 1 and not self.beam_random:\n",
" raise ValueError(\"'beam_random' should be True when using multiple runs\")\n",
"\n",
" dataframe_per_run = []\n",
" non_accum_rewards_list = []\n",
" accum_rewards_list = []\n",
"\n",
" initial_context = self.tokenizer.df_tokenized.iloc[0, : self.tokenizer.state_dim].values\n",
"\n",
" for i_run in range(nb_runs):\n",
" non_accum_rewards = []\n",
"\n",
" if initial_context.ndim != 1:\n",
" raise NotImplementedError(\"batching not implemented\")\n",
"\n",
" # Overwite some tokens here if you need\n",
" overwrite_valid_tokens = {}\n",
"\n",
" # Generate A+R+S tokens each time\n",
" context = initial_context\n",
" n_steps = self.tokenizer.action_dim + self.tokenizer.reward_dim + self.tokenizer.state_dim\n",
"\n",
" for i in tqdm(range(horizon)):\n",
" new_context, accum_logprobs = self.simulator.beam_search_n_steps(\n",
" seq=context,\n",
" n_steps=n_steps,\n",
" beam_width=self.beam_width,\n",
" randomness=self.beam_random,\n",
" overwrite_valid_tokens=overwrite_valid_tokens,\n",
" return_logprobs=True,\n",
" )\n",
"\n",
" ars_tokens = new_context[:, len(context) :]\n",
" df_ars = wi.WiDataFrame(\n",
" ars_tokens,\n",
" **self.tokenizer.df_tokenized.sar_d,\n",
" columns=[\n",
" *self.tokenizer.action_columns,\n",
" *self.tokenizer.reward_columns,\n",
" *self.tokenizer.state_columns,\n",
" ],\n",
" )\n",
"\n",
" df_sar = df_ars[df_ars.sar]\n",
" df_sar = self.tokenizer.field_tokenizer.inverse_transform(df_sar)\n",
"\n",
" rewards = df_sar[self.tokenizer.reward_columns].values\n",
" best_idx = self.obj_op(rewards.sum(axis=1))\n",
" non_accum_rewards.append(rewards[best_idx, 0])\n",
"\n",
" context = new_context[best_idx]\n",
"\n",
" # Uncomment the following if you want to record a dataframe per run\n",
" # widf_searched = wi.WiDataFrame(\n",
" # context[len(initial_context) :].reshape(horizon, -1),\n",
" # **self.tokenizer.df_tokenized.sar_d,\n",
" # columns=[\n",
" # *self.tokenizer.df_tokenized.actions,\n",
" # *self.tokenizer.df_tokenized.rewards,\n",
" # *self.tokenizer.df_tokenized.states,\n",
" # ],\n",
" # )\n",
" # widf_searched = widf_searched[widf_searched.sar]\n",
" # widf_searched = self.tokenizer.field_tokenizer.inverse_transform(widf_searched)\n",
" # widf_searched[\"nb_run\"] = i_run\n",
" # widf_searched[\"timestep\"] = range(1, len(widf_searched) + 1)\n",
" # dataframe_per_run.append(widf_searched)\n",
"\n",
" non_accum_rewards = np.array(non_accum_rewards)\n",
" accum_rewards = np.cumsum(non_accum_rewards, axis=0)\n",
"\n",
" non_accum_rewards_list.append(non_accum_rewards)\n",
" accum_rewards_list.append(accum_rewards)\n",
"\n",
" non_accum_rewards_list = np.array(non_accum_rewards_list)\n",
" accum_rewards_list = np.array(accum_rewards_list)\n",
"\n",
" return [non_accum_rewards_list, accum_rewards_list]\n"
]
},
{
@@ -473,12 +580,33 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare the costs (`system_power_consumption`) between two planners\n",
"### Create and run the `BeamSearchQPlanner` "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bsqp = BeamSearchQPlanner(simulator, beam_width=8, beam_random=True)\n",
"bsq_non_accum_cost_list, bsq_accum_cost_list = bsqp.rollout(horizon, nb_runs)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare the costs (`system_power_consumption`) between three planners\n",
"\n",
"On average (in the sense of **expected** outcome), the `Q-value Maximisation` planner (`QPlanner` for short) produces relatively lower `system_power_consumption`. However, the `Bahaviour Clone` actions may occasionally perform equally well. This is due to the non-deterministic nature of both the *Simulator* when performing `simulator.lookahead()` and the randomness associated with `simulator.sample()`. Moreover, the GPT model associated with the *Simulator* in this example was not trained sufficiently in terms of both the number of epochs and the size of the training data.\n",
"\n",
"On average (in the sense of **expected** outcome), the `Q-value Maximisation` planner (`QPlanner` for short) produces relatively lower `system_power_consumption`. However, the `Bahaviour Clone` actions may occasionally perform equally well. This is due to the non-deterministic nature of both the *Simulator* when performing `simulator.lookahead()` and the randomness associated with `simulator.sample()`. Moreover, the GPT model associated with the *Simulator* in this example was not trained sufficiently in terms of both the number of epochs and the size of the training data."
"The beam search planner should demonstrate a performance between behaviour cloning and Q-planner, since the idea of beam search is to create a better simulation and ask the planner not to be over-confident about the results."
]
},
{
@@ -513,7 +641,12 @@
" step_list.append(j)\n",
" acc_cost.append(q_accum_cost_list[i][j])\n",
" inst_cost.append(q_non_accum_cost_list[i][j])\n",
" policy_list.append(\"q-value\")"
" policy_list.append(\"q-value\")\n",
"\n",
" step_list.append(j)\n",
" acc_cost.append(bsq_accum_cost_list[i][j])\n",
" inst_cost.append(bsq_non_accum_cost_list[i][j])\n",
" policy_list.append(\"beam-search\")"
]
},
{
@@ -550,6 +683,9 @@
"sns.lineplot(\n",
" data=df_result[df_result.policy == \"q-value\"], x=\"step\", y=\"step_cost\", label=\"Q-value optimal\"\n",
")\n",
"sns.lineplot(\n",
" data=df_result[df_result.policy == \"beam-search\"], x=\"step\", y=\"step_cost\", label=\"Beam search\"\n",
")\n",
"plt.legend(fontsize=14)\n",
"plt.grid(ls=\"--\")\n",
"plt.xlabel(\"Step\", fontsize=16)\n",
@@ -568,9 +704,11 @@
"source": [
"data1 = df_result[(df_result.policy == \"behaviour\")]\n",
"data2 = df_result[(df_result.policy == \"q-value\")]\n",
"data3 = df_result[(df_result.policy == \"beam-search\")]\n",
"\n",
"sns.lineplot(data=data1, x=\"step\", y=\"acc_cost\", label=\"Behaviour clone\")\n",
"sns.lineplot(data=data2, x=\"step\", y=\"acc_cost\", label=\"Q-value optimal\")\n",
"sns.lineplot(data=data3, x=\"step\", y=\"acc_cost\", label=\"Beam search\")\n",
"plt.legend(fontsize=14)\n",
"plt.grid(ls=\"--\")\n",
"plt.xlabel(\"Step\", fontsize=16)\n",
@@ -581,7 +719,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('a2rl')",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -595,12 +733,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.11.0"
},
"toc-autonumbering": true,
"vscode": {
"interpreter": {
"hash": "62263fd135fd753cfd7c1bf88d5e743cb8b5f0e0f18aad3aa6722c0590b39cdb"
"hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a"
}
}
},
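
The heart of the new planner is the single `a2rl.Simulator.beam_search_n_steps` call made on every planning step. The sketch below condenses one such step from the `BeamSearchQPlanner.rollout()` code above; it assumes a fitted `simulator` (an `a2rl.Simulator`) and its `tokenizer` already exist as in the earlier notebook cells, and the beam width and min-objective are illustrative choices.

# One beam-search planning step, condensed from BeamSearchQPlanner.rollout() above.
# Assumes `simulator` and `tokenizer` are already built, as in the notebook.
import numpy as np
import a2rl as wi

context = tokenizer.df_tokenized.iloc[0, : tokenizer.state_dim].values   # initial state tokens
n_steps = tokenizer.action_dim + tokenizer.reward_dim + tokenizer.state_dim  # generate A+R+S

new_context, accum_logprobs = simulator.beam_search_n_steps(
    seq=context,
    n_steps=n_steps,
    beam_width=8,               # illustrative width
    randomness=True,            # the planner's `beam_random`; must be True for multiple distinct runs
    overwrite_valid_tokens={},  # optionally pin selected columns to fixed token values
    return_logprobs=True,
)

# Each beam appended n_steps new tokens; decode them back into an action/reward/state dataframe.
ars_tokens = new_context[:, len(context) :]
df_ars = wi.WiDataFrame(
    ars_tokens,
    **tokenizer.df_tokenized.sar_d,
    columns=[*tokenizer.action_columns, *tokenizer.reward_columns, *tokenizer.state_columns],
)
df_sar = tokenizer.field_tokenizer.inverse_transform(df_ars[df_ars.sar])

# Keep the beam with the lowest immediate reward (the notebook's 'min' objective) and
# feed it back in as the context for the next planning step.
best_idx = np.argmin(df_sar[tokenizer.reward_columns].values.sum(axis=1))
context = new_context[best_idx]
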
4 changes: 3 additions & 1 deletion requirements.txt
@@ -7,7 +7,7 @@ torch
tqdm>=4.64.1
PyYaml>=5.1
typing_extensions
typeguard
typeguard>=3.0.0
nptyping
loguru

@@ -17,3 +17,5 @@ gym>=0.23.1,<0.26.0
seaborn
cloudpickle
pytorch-lightning>=1.5.0

tensorboardX
13 changes: 7 additions & 6 deletions src/a2rl/_io.py
@@ -187,12 +187,12 @@ class Metadata:
tags: dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
check_type("states", self.states, List[str])
check_type("actions", self.actions, List[str])
check_type("rewards", self.rewards, List[str])
check_type("forced_categories", self.forced_categories, Optional[List[str]])
check_type("frequency", self.frequency, Optional[str])
check_type("tags", self.tags, Dict[str, Any])
check_type(self.states, List[str])
check_type(self.actions, List[str])
check_type(self.rewards, List[str])
check_type(self.forced_categories, Optional[List[str]])
check_type(self.frequency, Optional[str])
check_type(self.tags, Dict[str, Any])


def read_metadata(yaml_file: str | Path) -> Metadata:
@@ -337,6 +337,7 @@ def save_metadata(
tags: {}
<BLANKLINE>
"""

# Based on https://github.com/yaml/pyyaml/issues/127#issuecomment-525800484
class BlankLiner(yaml.SafeDumper):
def write_line_break(self, data=None):
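
The `src/a2rl/_io.py` change above goes hand in hand with the `typeguard>=3.0.0` bump in `requirements.txt`: typeguard 3.x dropped the leading argument-name parameter from `check_type`, so calls now pass only the value and the expected type. The snippet below is a minimal sketch of the new call style; the example values are illustrative and not taken from the PR.

# Minimal sketch of typeguard 3.x-style check_type calls (illustrative values only).
from typing import Any, Dict, List, Optional

from typeguard import check_type

states: List[str] = ["season", "freight_price"]
tags: Dict[str, Any] = {}
frequency: Optional[str] = None

# typeguard 3.x signature: check_type(value, expected_type) -- no argument-name string.
check_type(states, List[str])
check_type(tags, Dict[str, Any])
check_type(frequency, Optional[str])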