Skip to content

Commit

Permalink
Update lib versions and remove warnings from CI (#476)
Browse files Browse the repository at this point in the history
* init commit
* labels -> tick_labels
* warn -> warning
* regex string -> raw string
* update to remove "divide by 0" warning
* update to remove "optuna" warnings
* switch docopt to maintained fork docopt-ng
* update changelog
  • Loading branch information
JulienT01 authored Oct 25, 2024
1 parent 7b78b70 commit 2ce012a
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 14 deletions.
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ Changelog
Dev version
-----------



*PR #476*

* Update lib versions and remove warnings from tests/CI : https://github.com/rlberry-py/rlberry/issues/471

*PR #474*

* Create a new tool to load data from tensorboard logs : https://github.com/rlberry-py/rlberry/issues/472
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pygame-ce = "*"
matplotlib = "*"
gymnasium = {version="^0.29.1", extras=["atari", "accept-rom-license"]}
dill = "*"
docopt = "*"
docopt-ng = "*"
pyyaml = "*"
tqdm = "*"
adastop = "*"
Expand All @@ -36,7 +36,7 @@ optuna ={version="*", optional=true}
ffmpeg-python = {version="*", optional=true}
opencv-python = {version="*", optional=true}
ale-py = {version="*", optional=true}
stable-baselines3 = {version="*", optional=true}
stable-baselines3 = {version=">=2.3", optional=true}
tensorboard = {version="*", optional=true}
torch = [
{ version = ">=2.0.0,<2.3", platform = "darwin", optional=true },
Expand Down
2 changes: 1 addition & 1 deletion rlberry/manager/env_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __func_to_script(func):
name = fun_source.split("\n")[1] # skip decorator

m = re.search(
"(?<= )\w+", name
r"(?<= )\w+", name
) # isolate the name of function to use as script name

source = "\n"
Expand Down
2 changes: 1 addition & 1 deletion rlberry/manager/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def evaluate_agents(
# plot
if plot:
plt.figure(fignum)
plt.boxplot(output.values, labels=output.columns)
plt.boxplot(output.values, tick_labels=output.columns)
plt.xlabel("agent")
plt.ylabel("evaluation output")
if show:
Expand Down
4 changes: 3 additions & 1 deletion rlberry/manager/experiment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,9 @@ def eval_agents(
def clear_output_dir(self):
"""Delete output_dir and all its data."""
if self.optuna_study:
optuna.delete_study(self.optuna_study.study_name, self.optuna_storage_url)
optuna.delete_study(
study_name=self.optuna_study.study_name, storage=self.optuna_storage_url
)
try:
shutil.rmtree(self.output_dir_)
except FileNotFoundError:
Expand Down
12 changes: 7 additions & 5 deletions rlberry/manager/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ def process(df):
# with cross validation bandwidth selection
if not isinstance(smoothing_bandwidth, numbers.Number):
if smoothing_bandwidth is None:
bandwidth = np.linspace(min_bandwidth_x, (max_x - min_x) / 100, 10)
bandwidth = np.linspace(
min_bandwidth_x, max((max_x - min_x) / 100, min_bandwidth_x * 3), 10
)
else:
bandwidth = smoothing_bandwidth
nw = SmoothingParameterSearch(
Expand All @@ -358,7 +360,7 @@ def process(df):
except:
raise ValueError("non-finite (or non float) data detected.")
if not np.all(np.isfinite(X)):
logger.warn(
logger.warning(
"Some of the values are not finite. Not plotting the associated curves."
)
Xhat[f] = np.nan
Expand Down Expand Up @@ -419,7 +421,7 @@ def process(df):

elif error_representation == "cb":
if n_tot_simu < 1 / (1 - level):
logger.warn(
logger.warning(
"Computing a cb that cannot achieve the level prescribed because there are not enough seeds."
)

Expand All @@ -445,7 +447,7 @@ def process(df):
)
else:
y_err = np.zeros(len(xplot))
logger.warn(
logger.warning(
"The variance of the curve was 0, the confidence bound is very biased"
)

Expand Down Expand Up @@ -546,7 +548,7 @@ def plot_synchronized_curves(
float
)
if len(x_simu) != len(x_simu_0):
logger.warn("x axis is not the same for all the runs, truncating.")
logger.warning("x axis is not the same for all the runs, truncating.")
x_simu_0 = np.intersect1d(x_simu_0, x_simu)
df_name = df_name.loc[df_name[xlabel].apply(lambda x: x in x_simu_0)]
assert (
Expand Down
10 changes: 8 additions & 2 deletions rlberry/manager/tests/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
import pandas as pd


it = 1


class DummyAgent(AgentWithSimplePolicy):
def __init__(self, env, eval_val=0, **kwargs):
global it
AgentWithSimplePolicy.__init__(self, env, **kwargs)
self.name = "DummyAgent"
self.fitted = False
self.eval_val = eval_val
# self.eval_val = eval_val
self.eval_val = it
it += 1

self.total_budget = 0.0

Expand Down Expand Up @@ -69,7 +75,7 @@ def test_compare(method, source):
data_source = pd.DataFrame(
{
"agent": (["Agent 1"] * 10) + (["Agent 2"] * 10),
"mean_eval": ([0] * 10) + ([10] * 10),
"mean_eval": range(0, 20),
}
)

Expand Down
2 changes: 1 addition & 1 deletion rlberry/manager/tests/test_experiment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def test_not_pickle(compress, agent):


def test_fitbudget_exception():
msg = "\[ExperimentManager\] fit_budget missing in __init__\(\)\." # /!\ regex : need to escape some char.
msg = r"\[ExperimentManager\] fit_budget missing in __init__\(\)\." # /!\ regex : need to escape some char.
with pytest.raises(ValueError, match=msg):
# Define train and evaluation envs
train_env = (GridWorld, {})
Expand Down
2 changes: 1 addition & 1 deletion rlberry/utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_memory(pid=None):
command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True
).stdout
m = re.findall(
"\| *[0-9] *" + str(pid) + " *C *.*python.*? +([0-9]+).*\|",
r"\| *[0-9] *" + str(pid) + r" *C *.*python.*? +([0-9]+).*\|",
result,
re.MULTILINE,
)
Expand Down

0 comments on commit 2ce012a

Please sign in to comment.