commit 5ee6947 (0 parents)
Showing 1,052 changed files with 177,110 additions and 0 deletions.
.buildinfo
@@ -0,0 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: dee68c8b7c7bb3dc5c0c7db1a6eb2393
tags: 645f666f9bcd5a90fca523b33c5a78b7
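The rebuild check this file enables is easy to sketch. The following is a hypothetical illustration of the idea, not Sphinx's actual implementation: hash the build configuration, compare with the stored value, and trigger a full rebuild on mismatch.

# Hypothetical sketch of the .buildinfo check; not Sphinx's actual code.
import hashlib

def config_hash(config: dict) -> str:
    # Hash a stable representation of the build configuration.
    return hashlib.md5(repr(sorted(config.items())).encode()).hexdigest()

stored = "dee68c8b7c7bb3dc5c0c7db1a6eb2393"  # the config hash stored above
current = config_hash({"html_theme": "alabaster"})  # assumed example config
if current != stored:
    print("configuration changed or .buildinfo missing: full rebuild")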
Binary file added: +11.8 KB .doctrees/auto_examples/demo_agents/video_plot_rs_kernel_ucbvi.doctree
Binary file added: +22 KB .doctrees/auto_examples/demo_bandits/plot_compare_index_bandits.doctree
Binary file added: +18.7 KB .doctrees/auto_examples/demo_env/example_atari_atlantis_vectorized_ppo.doctree
Binary file added: +18.7 KB .doctrees/auto_examples/demo_env/example_atari_breakout_vectorized_ppo.doctree
Binary file added: +12.6 KB ...trees/auto_examples/demo_env/video_plot_old_gym_compatibility_wrapper_old_acrobot.doctree
Binary file added: +124 KB .doctrees/generated/rlberry.agents.stable_baselines.StableBaselinesAgent.doctree
Binary file added: +112 KB .doctrees/generated/rlberry.wrappers.discretize_state.DiscretizeStateWrapper.doctree
Binary file added: +114 KB .doctrees/generated/rlberry.wrappers.gym_utils.OldGymCompatibilityWrapper.doctree
Empty file.
29 changes: 29 additions & 0 deletions
_downloads/0c468b13b91ea3c7663a48c76c453e96/video_plot_mountain_car.py
@@ -0,0 +1,29 @@
"""
=================================
A demo of MountainCar environment
=================================
Illustration of the MountainCar environment.

.. video:: ../../video_plot_montain_car.mp4
   :width: 600
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_montain_car.jpg'

from rlberry_scool.agents.mbqvi import MBQVIAgent
from rlberry_research.envs.classic_control import MountainCar
from rlberry.wrappers import DiscretizeStateWrapper

_env = MountainCar()
env = DiscretizeStateWrapper(_env, 20)
agent = MBQVIAgent(env, n_samples=40, gamma=0.99)
agent.fit()

env.enable_rendering()
observation, info = env.reset()
for tt in range(200):
    action = agent.policy(observation)
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

video = env.save_video("_video/video_plot_montain_car.mp4")
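A note on the wrapper used above: DiscretizeStateWrapper(_env, 20) makes MountainCar's continuous observations usable by the tabular MBQVI agent by binning each state dimension. A minimal sketch of the idea, assuming uniform binning over the classic MountainCar bounds (the wrapper's actual implementation may differ):

# Hedged sketch of uniform state discretization; not DiscretizeStateWrapper's actual code.
import numpy as np

def discretize(obs, low, high, n_bins=20):
    # Map each continuous coordinate to one of n_bins cells, then flatten
    # the per-dimension indices into a single integer state id.
    ratios = (np.asarray(obs) - low) / (high - low)
    idx = np.clip((ratios * n_bins).astype(int), 0, n_bins - 1)
    return int(np.ravel_multi_index(idx, (n_bins,) * len(idx)))

# Classic MountainCar bounds: position in [-1.2, 0.6], velocity in [-0.07, 0.07].
low, high = np.array([-1.2, -0.07]), np.array([0.6, 0.07])
print(discretize([-0.5, 0.0], low, high))  # one of the 20 * 20 discrete states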
43 changes: 43 additions & 0 deletions
_downloads/0ffc6bd397932d17ce61dd19fa2aeddf/video_plot_gridworld.ipynb
@@ -0,0 +1,43 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# A demo of Gridworld environment with ValueIterationAgent\nIllustration of the training and video rendering of ValueIteration Agent in\nGridworld environment.\n\n.. video:: ../../video_plot_gridworld.mp4\n   :width: 600\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from rlberry_scool.agents.dynprog import ValueIterationAgent\nfrom rlberry_scool.envs.finite import GridWorld\n\n\nenv = GridWorld(7, 10, walls=((2, 2), (3, 3)))\n\nagent = ValueIterationAgent(env, gamma=0.95)\ninfo = agent.fit()\nprint(info)\n\nenv.enable_rendering()\nobservation, info = env.reset()\nfor tt in range(50):\n    action = agent.policy(observation)\n    observation, reward, terminated, truncated, info = env.step(action)\n    done = terminated or truncated\n    if done:\n        # Warning: this will never happen in the present case because there is no terminal state.\n        # See the doc of GridWorld for more information on the default parameters of GridWorld.\n        break\n# Save the video\nenv.save_video(\"_video/video_plot_gridworld.mp4\", framerate=10)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
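The comment in the loop above notes that done never triggers because the default GridWorld has no terminal state. To make the episode actually terminate, the environment can be given a rewarded terminal cell; the snippet below is a hedged sketch that assumes GridWorld's documented reward_at and terminal_states keyword arguments:

# Hedged sketch: assumes GridWorld accepts reward_at and terminal_states,
# as described in rlberry's GridWorld documentation.
from rlberry_scool.envs.finite import GridWorld

env = GridWorld(
    7,
    10,
    walls=((2, 2), (3, 3)),
    reward_at={(6, 9): 1.0},     # reward for reaching the bottom-right cell
    terminal_states=((6, 9),),   # the episode ends there, so done can become True
)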
180 changes: 180 additions & 0 deletions
_downloads/10b71a51b8ae10571824280d3b92d89b/plot_mirror_bandit.py
@@ -0,0 +1,180 @@
"""
========================================================
A demo of Bandit BAI on a real dataset to select mirrors
========================================================
In this example we use a sequential halving agent to find the best server
to download Ubuntu from, among a choice of 8 French servers.
The quirk of this application is that pinging a server may time out. We
handle this by using the median instead of the mean in sequential halving's
objective.
The code is in three parts: definition of the environment, definition of the
agent, and finally definition of the experiment.
"""
import numpy as np

from rlberry.manager import ExperimentManager, read_writer_data
from rlberry.envs.interface import Model
from rlberry_research.agents.bandits import BanditWithSimplePolicy
import rlberry.spaces as spaces

import requests
import matplotlib.pyplot as plt


import rlberry

logger = rlberry.logger

# Environment definition


TIMEOUT = 2

mirrors_ubuntu = np.array(
    [
        "https://ubuntu.lafibre.info/ubuntu/",
        "https://mirror.ubuntu.ikoula.com/",
        "http://ubuntu.mirrors.ovh.net/ubuntu/",
        "http://miroir.univ-lorraine.fr/ubuntu/",
        "http://ubuntu.univ-nantes.fr/ubuntu/",
        "https://ftp.u-picardie.fr/mirror/ubuntu/ubuntu/",
        "http://ubuntu.univ-reims.fr/ubuntu/",
        "http://www-ftp.lip6.fr/pub/linux/distributions/Ubuntu/archive/",
    ]
)


def get_time(url):
    try:
        resp = requests.get(url, timeout=TIMEOUT)
        return resp.elapsed.total_seconds()
    except requests.exceptions.RequestException:
        return np.inf  # timeout


class MirrorBandit(Model):
    """
    Real environment for bandit problems.
    The reward is based on the response time of French Ubuntu mirror servers:
    on action i, it gives the negative waiting time to reach url i in
    mirrors_ubuntu.

    WARNING: if the query to the mirror times out, the reward is negative
    infinity.

    Parameters
    ----------
    url_ids : list of int or None
        List of ids used to select a subset of the url list provided in the
        source. If None, all the urls are selected (i.e. an 8-armed bandit).
    """

    name = "MirrorEnv"

    def __init__(self, url_ids=None, **kwargs):
        Model.__init__(self, **kwargs)
        if url_ids:
            self.url_list = mirrors_ubuntu[url_ids]
        else:
            self.url_list = mirrors_ubuntu

        self.n_arms = len(self.url_list)
        self.action_space = spaces.Discrete(self.n_arms)

    def step(self, action):
        """
        Sample the reward associated with the action.
        """
        # check that the action exists
        assert action < self.n_arms

        reward = -get_time(self.url_list[action])
        terminated = True
        truncated = False
        return 0, reward, terminated, truncated, {}

    def reset(self, seed=None):
        """
        Reset the environment to a default state.
        """
        return 0, {}


env_ctor = MirrorBandit
env_kwargs = {}

# BAI Agent definition


class SeqHalvAgent(BanditWithSimplePolicy):
    """
    Sequential Halving Agent
    """

    name = "SeqHalvAgent"

    def __init__(self, env, **kwargs):
        BanditWithSimplePolicy.__init__(
            self, env, writer_extra="action_and_reward", **kwargs
        )

    def fit(self, budget=None, **kwargs):
        horizon = budget
        rewards = []
        actions = []
        active_set = np.arange(self.n_arms)

        logk = int(np.ceil(np.log2(self.n_arms)))
        ep = 0

        for r in range(logk):
            tr = np.floor(horizon / (len(active_set) * logk))
            for _ in range(int(tr)):
                for k in active_set:
                    action = k
                    actions += [action]
                    observation, reward, terminated, truncated, info = self.env.step(
                        action
                    )
                    rewards += [reward]
                    ep += 1
            # Convert actions to an array so the boolean mask selects the
            # rewards of arm k.
            reward_est = [
                np.median(np.array(rewards)[np.array(actions) == k])
                for k in active_set
            ]
            # We estimate the reward using the median instead of the mean to
            # handle timeouts.
            half_len = int(np.ceil(len(active_set) / 2))
            active_set = active_set[np.argsort(reward_est)[-half_len:]]

        self.optimal_action = active_set[0]
        self.writer.add_scalar("optimal_action", self.optimal_action, ep)

        return actions


# Experiment

xp_manager = ExperimentManager(
    SeqHalvAgent,
    (env_ctor, env_kwargs),
    fit_budget=100,  # we use only 100 iterations for a faster example run in the doc.
    n_fit=1,
    agent_name="SH",
)
xp_manager.fit()

rewards = read_writer_data([xp_manager], preprocess_tag="reward")["value"]
actions = read_writer_data([xp_manager], preprocess_tag="action")["value"]


# one box per server (8 arms)
plt.boxplot([-rewards[actions == a] for a in range(len(mirrors_ubuntu))])
plt.xlabel("Server")
plt.ylabel("Waiting time (in s)")
plt.show()

print(
    "The optimal action (fastest server) is server number ",
    xp_manager.agent_handlers[0].optimal_action + 1,
)
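To see how fit spends its budget with the settings above (8 arms, fit_budget=100), it helps to work through the halving schedule; the short computation below mirrors the loop in SeqHalvAgent.fit:

# Reproduce the budget split performed by SeqHalvAgent.fit for 8 arms, budget 100.
import numpy as np

budget, n_arms = 100, 8
logk = int(np.ceil(np.log2(n_arms)))  # 3 halving rounds
active, total = n_arms, 0
for r in range(logk):
    tr = int(np.floor(budget / (active * logk)))  # pulls per active arm this round
    print(f"round {r + 1}: {active} arms x {tr} pulls = {active * tr}")
    total += active * tr
    active = int(np.ceil(active / 2))
print("total pulls:", total)  # 96 of the 100-pull budget; the remainder is unused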
43 changes: 43 additions & 0 deletions
_downloads/127146715fc2fd3e933915742ab6105d/video_plot_rs_kernel_ucbvi.ipynb
@@ -0,0 +1,43 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# A demo of RSKernelUCBVIAgent algorithm in Acrobot environment\nIllustration of how to set up a RSKernelUCBVI algorithm in rlberry.\nThe environment chosen here is the Acrobot environment.\n\n.. video:: ../../video_plot_rs_kernel_ucbvi.mp4\n   :width: 600\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from rlberry_research.envs import Acrobot\nfrom rlberry_research.agents import RSKernelUCBVIAgent\nfrom rlberry.wrappers import RescaleRewardWrapper\n\nenv = Acrobot()\n# rescale rewards to [0, 1]\nenv = RescaleRewardWrapper(env, (0.0, 1.0))\n\nagent = RSKernelUCBVIAgent(\n    env,\n    gamma=0.99,\n    horizon=300,\n    bonus_scale_factor=0.01,\n    min_dist=0.2,\n    bandwidth=0.05,\n    beta=1.0,\n    kernel_type=\"gaussian\",\n)\nagent.fit(budget=500)\n\nenv.enable_rendering()\nobservation, info = env.reset()\n\ntime_before_done = 0\nended = False\nfor tt in range(2 * agent.horizon):\n    action = agent.policy(observation)\n    observation, reward, terminated, truncated, info = env.step(action)\n    done = terminated or truncated\n    if not done and not ended:\n        time_before_done += 1\n    if done:\n        ended = True\n\nprint(\"steps to achieve the goal for the first time = \", time_before_done)\nvideo = env.save_video(\"_video/video_plot_rs_kernel_ucbvi.mp4\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
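A note on the RescaleRewardWrapper call in the cell above: it maps Acrobot's rewards into [0, 1], which the exploration bonus of RSKernelUCBVI expects. A minimal sketch of the underlying linear map, assuming rescaling from a known reward range (the actual wrapper may treat unbounded ranges differently):

# Hedged sketch of linear reward rescaling; not RescaleRewardWrapper's actual code.
def rescale(reward, r_min, r_max, lo=0.0, hi=1.0):
    # Affine map sending [r_min, r_max] onto [lo, hi].
    return lo + (hi - lo) * (reward - r_min) / (r_max - r_min)

# Acrobot's classic per-step reward is -1 until the goal (reward 0) is reached.
print(rescale(-1.0, r_min=-1.0, r_max=0.0))  # -> 0.0
print(rescale(0.0, r_min=-1.0, r_max=0.0))   # -> 1.0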
43 changes: 43 additions & 0 deletions
_downloads/173ff6e8bc2ce53868f709d27fee17db/demo_SAC.ipynb
@@ -0,0 +1,43 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# SAC (Soft Actor-Critic)\n\nThis script shows how to train a SAC agent on a Pendulum environment.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import time\n\nimport gymnasium as gym\nfrom rlberry_research.agents.torch.sac import SACAgent\nfrom rlberry_research.envs import Pendulum\nfrom rlberry.manager import ExperimentManager\n\n\ndef env_ctor(env, wrap_spaces=True):\n    return env\n\n\n# Setup agent parameters\nenv_name = \"Pendulum\"\nfit_budget = int(2e5)\nagent_name = f\"{env_name}_{fit_budget}_{int(time.time())}\"\n\n# Setup environment parameters\nenv = Pendulum()\nenv = gym.wrappers.TimeLimit(env, max_episode_steps=200)\nenv = gym.wrappers.RecordEpisodeStatistics(env)\nenv_kwargs = dict(env=env)\n\n# Create agent instance\nxp_manager = ExperimentManager(\n    SACAgent,\n    (env_ctor, env_kwargs),\n    fit_budget=fit_budget,\n    n_fit=1,\n    enable_tensorboard=True,\n    agent_name=agent_name,\n)\n\n# Start training\nxp_manager.fit()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
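Once xp_manager.fit() has run, the trained agent can be evaluated and compared; a hedged usage sketch, assuming rlberry's evaluate_agents helper (check the signature in your rlberry version):

# Hedged sketch: assumes rlberry.manager.evaluate_agents is available.
from rlberry.manager import evaluate_agents

# Run Monte-Carlo evaluation episodes for each fitted instance and plot the result.
evaluate_agents([xp_manager], n_simulations=10, show=True)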
43 changes: 43 additions & 0 deletions
_downloads/17eaa60d740f201a5f49ef97a8999352/plot_TS_bandit.ipynb
@@ -0,0 +1,43 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Comparison of Thompson sampling and UCB on Bernoulli and Gaussian bandits\n\nThis script shows how to use Thompson sampling on two examples: Bernoulli and Gaussian bandits.\n\nIn the Bernoulli case, we use Thompson sampling with a Beta prior. We compare it to a UCB for\nbounded rewards with support in [0,1].\nFor the Gaussian case, we use a Gaussian prior and compare it to a sub-Gaussian UCB.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom rlberry_research.envs.bandits import BernoulliBandit, NormalBandit\nfrom rlberry_research.agents.bandits import (\n    IndexAgent,\n    TSAgent,\n    makeBoundedUCBIndex,\n    makeSubgaussianUCBIndex,\n    makeBetaPrior,\n    makeGaussianPrior,\n)\nfrom rlberry.manager import ExperimentManager, plot_writer_data\n\n\n# Bernoulli\n\n# Agents definition\n\n\nclass BernoulliTSAgent(TSAgent):\n    \"\"\"Thompson sampling for Bernoulli bandit\"\"\"\n\n    name = \"Bernoulli TS Agent\"\n\n    def __init__(self, env, **kwargs):\n        prior, _ = makeBetaPrior()\n        TSAgent.__init__(self, env, prior, writer_extra=\"action\", **kwargs)\n\n\nclass BoundedUCBAgent(IndexAgent):\n    \"\"\"UCB agent for bounded bandits\"\"\"\n\n    name = \"Bounded UCB Agent\"\n\n    def __init__(self, env, **kwargs):\n        index, _ = makeBoundedUCBIndex(0, 1)\n        IndexAgent.__init__(self, env, index, writer_extra=\"action\", **kwargs)\n\n\n# Parameters of the problem\nmeans = np.array([0.8, 0.8, 0.9, 1])  # means of the arms\nA = len(means)\nT = 2000  # Horizon\nM = 10  # number of MC simulations\n\n# Construction of the experiment\n\nenv_ctor = BernoulliBandit\nenv_kwargs = {\"p\": means}\n\nagents = [\n    ExperimentManager(\n        Agent,\n        (env_ctor, env_kwargs),\n        fit_budget=T,\n        n_fit=M,\n    )\n    for Agent in [BoundedUCBAgent, BernoulliTSAgent]\n]\n\n# Agent training\n\nfor agent in agents:\n    agent.fit()\n\n\n# Compute and plot (pseudo-)regret\ndef compute_pseudo_regret(actions):\n    return np.cumsum(np.max(means) - means[actions.astype(int)])\n\n\noutput = plot_writer_data(\n    agents,\n    tag=\"action\",\n    preprocess_func=compute_pseudo_regret,\n    title=\"Cumulative Pseudo-Regret\",\n)\n\n\n# Gaussian\n\n\nclass GaussianTSAgent(TSAgent):\n    \"\"\"Thompson sampling for Gaussian bandit\"\"\"\n\n    name = \"Gaussian TS Agent\"\n\n    def __init__(self, env, sigma=1.0, **kwargs):\n        prior, _ = makeGaussianPrior(sigma)\n        TSAgent.__init__(self, env, prior, writer_extra=\"action\", **kwargs)\n\n\nclass GaussianUCBAgent(IndexAgent):\n    \"\"\"UCB agent for Gaussian bandits\"\"\"\n\n    name = \"Gaussian UCB Agent\"\n\n    def __init__(self, env, sigma=1.0, **kwargs):\n        index, _ = makeSubgaussianUCBIndex(sigma)\n        IndexAgent.__init__(self, env, index, writer_extra=\"action\", **kwargs)\n\n\n# Parameters of the problem\nmeans = np.array([0.3, 0.5])  # means of the arms\nsigma = 1.0  # std of the arms\nA = len(means)\nT = 2000  # Horizon\nM = 10  # number of MC simulations\n\n# Construction of the experiment\n\nenv_ctor = NormalBandit\nenv_kwargs = {\"means\": means, \"stds\": sigma * np.ones(A)}\n\nagents = [\n    ExperimentManager(\n        Agent,\n        (env_ctor, env_kwargs),\n        fit_budget=T,\n        n_fit=M,\n    )\n    for Agent in [GaussianUCBAgent, GaussianTSAgent]\n]\n\n# Agent training\n\nfor agent in agents:\n    agent.fit()\n\n\n# Compute and plot (pseudo-)regret\ndef compute_pseudo_regret(actions):\n    return np.cumsum(np.max(means) - means[actions.astype(int)])\n\n\noutput = plot_writer_data(\n    agents,\n    tag=\"action\",\n    preprocess_func=compute_pseudo_regret,\n    title=\"Cumulative Pseudo-Regret\",\n)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
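The Beta prior returned by makeBetaPrior is used internally by TSAgent, but the principle is short enough to state on its own. Below is a self-contained sketch of Thompson sampling with a Beta(1, 1) prior on Bernoulli arms, independent of rlberry's implementation:

# Self-contained Thompson sampling with a Beta(1, 1) prior on Bernoulli arms.
# Illustrates the principle behind makeBetaPrior; not rlberry's implementation.
import numpy as np

rng = np.random.default_rng(0)
means = np.array([0.8, 0.8, 0.9, 1.0])  # same arms as the Bernoulli example above
successes = np.ones(len(means))  # Beta alpha parameters
failures = np.ones(len(means))   # Beta beta parameters

for t in range(2000):
    samples = rng.beta(successes, failures)  # one posterior sample per arm
    arm = int(np.argmax(samples))            # play the arm with the best sample
    reward = rng.random() < means[arm]       # Bernoulli draw
    successes[arm] += reward                 # conjugate posterior update
    failures[arm] += 1 - reward

print("empirical best arm:", int(np.argmax(successes / (successes + failures))))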