Skip to content

Commit

Permalink
Fix bugs & log loading
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertTLange committed Feb 20, 2022
1 parent c6252b0 commit dce1742
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 13 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
## [v0.0.6] - [02/19/2022]
## [v0.0.7] - [02/20/2022]

### Added
- Log reloading helper for post-processing.

### Fixed
- Fixed dependency imports in `mle-search`: the current working directory is now appended to `sys.path` before the main function module is loaded.
- Fixed the cleaning of nested dictionaries: fixed parameters are now removed key by key after flattening, so that an entire sub-dictionary is no longer deleted.

## [v0.0.6] - [02/20/2022]
### Added

- Adds a command line interface for running a sequential search given a python script `<script>.py` containing a function `main(config)`, a default configuration file `<base>.yaml` & a search configuration `<search>.yaml`. The `main` function should return a single scalar performance score. You can then start the search via:
Expand Down
2 changes: 1 addition & 1 deletion mle_hyperopt/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.6"
__version__ = "0.0.7"
10 changes: 7 additions & 3 deletions mle_hyperopt/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os
import sys
import importlib
from mle_logging import load_config
from .strategies import Strategies
Expand All @@ -19,7 +20,7 @@ def get_search_args() -> None:
"-base",
"--base_config",
type=str,
default="base.yaml",
default=None,
help="Filename to load base configuration from.",
)
parser.add_argument(
Expand Down Expand Up @@ -88,14 +89,16 @@ def search() -> None:
if "categorical" in search_config.search_config.keys()
else None
)
if base_config is not None:
base_config = base_config.toDict()

strategy = Strategies[search_config.search_type](
real,
integer,
categorical,
search_config.search_config,
search_config.maximize_objective,
fixed_params=base_config.toDict(),
fixed_params=base_config,
verbose=search_config.verbose,
)

Expand All @@ -112,7 +115,8 @@ def search() -> None:
else search_config.num_iters
)

# Load the main function module
# Append path for correct imports & load the main function module
sys.path.append(os.getcwd())
spec = importlib.util.spec_from_file_location(
"main", os.path.join(os.getcwd(), args.exec_fname)
)
Expand Down
3 changes: 3 additions & 0 deletions mle_hyperopt/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(
self.categorical_names = list(self.categorical.keys())
else:
self.categorical_names = []
self.variable_names = (
self.real_names + self.integer_names + self.categorical_names
)

def check(self):
"""Check that all inputs are provided correctly."""
Expand Down
15 changes: 11 additions & 4 deletions mle_hyperopt/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import numbers
from .utils import (
merge_config_dicts,
load_log,
save_log,
load_strategy,
Expand Down Expand Up @@ -127,7 +128,9 @@ def ask(
if self.fixed_params is not None:
for i in range(len(param_batch)):
# Important that param_batch 2nd - overwrites fixed k,v!
param_batch[i] = {**self.fixed_params, **param_batch[i]}
param_batch[i] = dict(
merge_config_dicts(self.fixed_params, param_batch[i])
)

# If string for storage is given: Save configs as .yaml
if store:
Expand Down Expand Up @@ -253,11 +256,15 @@ def clean_data(
proposal_clean = proposal_clean["params"]
else:
extra_data = None
if self.fixed_params is not None:
for k in self.fixed_params.keys():
del proposal_clean[k]

# After extra/fixed parameter clean up - flatten remaining params
proposal_clean = flatten_config(proposal_clean)
if self.fixed_params is not None:
fixed_flat = flatten_config(self.fixed_params)
for k in fixed_flat.keys():
if k not in self.space.variable_names:
if k in proposal_clean.keys():
del proposal_clean[k]

if (
proposal_clean in self.all_evaluated_params
Expand Down
7 changes: 6 additions & 1 deletion mle_hyperopt/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .helpers import (
save_yaml,
load_log,
save_log,
load_strategy,
save_strategy,
write_configs,
flatten_config,
unflatten_config,
merge_config_dicts,
)
from .plotting import visualize_2D_grid
from .plotting import visualize_2D_grid, load_search_log
from .comms import (
welcome_message,
update_message,
Expand All @@ -23,14 +25,17 @@


__all__ = [
"save_yaml",
"load_log",
"save_log",
"load_strategy",
"save_strategy",
"write_configs",
"flatten_config",
"unflatten_config",
"merge_config_dicts",
"visualize_2D_grid",
"load_search_log",
"welcome_message",
"update_message",
"ranking_message",
Expand Down
26 changes: 26 additions & 0 deletions mle_hyperopt/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,29 @@ def flatten_config(dictionary, parent_key="", sep="/") -> dict:
else:
items.append((new_key, v))
return dict(items)


def merge_config_dicts(dict1: dict, dict2: dict):
    """Merge two potentially nested dictionaries.

    Important: dict2 overwrites dict1 in case of shared entries. Keys are
    yielded in a deterministic order: first all keys of dict1 (in their
    insertion order), then the keys that only exist in dict2. This mirrors
    the ordering of `{**dict1, **dict2}` and does not depend on set/hash
    ordering.

    Args:
        dict1 (dict): Fixed parameter dictionary.
        dict2 (dict): New hyperparameters to evaluate.

    Yields:
        tuple: `(key, value)` pairs of the merged dictionary. This function
            is a generator - wrap the call with `dict(...)` outside of it.
    """
    for k, v1 in dict1.items():
        if k not in dict2:
            # Entry only exists in the first dictionary - keep it as is.
            yield (k, v1)
            continue
        v2 = dict2[k]
        if isinstance(v1, dict) and isinstance(v2, dict):
            # Both sides are dictionaries - merge them recursively.
            yield (k, dict(merge_config_dicts(v1, v2)))
        else:
            # If one of the values is not a dict, merging can't continue.
            # Value from second dict overrides the one in the first.
            yield (k, v2)
    for k, v2 in dict2.items():
        if k not in dict1:
            # Entry only exists in the second dictionary.
            yield (k, v2)
29 changes: 26 additions & 3 deletions mle_hyperopt/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,31 @@
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import List, Union

import seaborn as sns
from mle_logging import load_config


def load_search_log(log_fname: str) -> pd.core.frame.DataFrame:
    """Read a stored search log file back into a flat pandas dataframe.

    Args:
        log_fname (str): Filename of the log handed to `load_config`.

    Returns:
        pd.core.frame.DataFrame: Reloaded log with one row per log entry.
            Nested fields become columns named after their innermost key.
    """
    stored_log = load_config(log_fname)
    # Collect the individual log entries in the order of the stored keys
    entries = [stored_log[entry_id] for entry_id in stored_log.keys()]
    # json_normalize expands nested dictionaries into dot-separated columns
    frame = pd.json_normalize(entries)
    # Keep only the part after the last dot, e.g. 'params.x' becomes 'x'
    frame.columns = [column.split(".")[-1] for column in frame.columns]
    return frame


# Set overall plots appearance sns style
sns.set(
Expand All @@ -22,11 +45,11 @@ def visualize_2D_grid(
hyper_df: pd.core.frame.DataFrame,
fixed_params: Union[None, dict] = None,
params_to_plot: list = [],
target_to_plot: str = "target",
target_to_plot: str = "objective",
plot_title: str = "Temp Title",
plot_subtitle: Union[None, str] = None,
xy_labels: Union[None, List[str]] = ["x-label", "y-label"],
variable_name: Union[None, str] = "Var Label",
variable_name: Union[None, str] = "Performance",
every_nth_tick: int = 1,
plot_colorbar: bool = True,
text_in_cell: bool = False,
Expand Down

0 comments on commit dce1742

Please sign in to comment.