Skip to content

Commit

Permalink
test for get_results and categorical fix (#241)
Browse files Browse the repository at this point in the history
* test for get_results and categorical fix for #236
  • Loading branch information
quaquel authored Apr 3, 2023
1 parent 76d64c8 commit 4463a7f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 7 deletions.
14 changes: 14 additions & 0 deletions ema_workbench/em_framework/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ def __init__(
columns.append(name)
dtypes.append("object")

self.columns = columns
self.dtypes = dtypes

index = np.arange(nr_experiments)
column_dict = {
name: pd.Series(dtype=dtype, index=index) for name, dtype in zip(columns, dtypes)
Expand Down Expand Up @@ -306,6 +309,17 @@ def get_results(self):
_logger.warning("some experiments have failed, returning masked result arrays")
results[k] = v

# we want to ensure the dtypes for the columns in the experiments dataframe match
# the type of uncertainty. The exception is needed in case their are missing values (i.e. nans).
# nans can only ever be a float.
for name, dtype in zip(self.columns, self.dtypes):
try:
if dtype == "object":
dtype = "category"
self.cases[name] = self.cases[name].astype(dtype)
except Exception:
pass

return self.cases, results

def _setup_outcomes_array(self, shape, dtype):
Expand Down
49 changes: 42 additions & 7 deletions test/test_em_framework/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
.. codeauthor:: jhkwakkel <j.h.kwakkel (at) tudelft (dot) nl>
"""
import random

import pandas as pd
import pytest

import numpy as np
Expand Down Expand Up @@ -153,6 +155,7 @@ def test_store_cases():
case = {unc.name: random.random() for unc in uncs}
case["c"] = int(round(case["c"] * 2))
case["d"] = int(round(case["d"]))
case["e"] = True

model = NamedObject("test")
policy = Policy("policy")
Expand All @@ -166,15 +169,15 @@ def test_store_cases():
callback(experiment, model_outcomes)

experiments, _ = callback.get_results()
design = case
design["policy"] = policy.name
design["model"] = model.name
design["scenario"] = scenario.name
# design = case
case["policy"] = policy.name
case["model"] = model.name
case["scenario"] = scenario.name

names = experiments.columns.values.tolist()
for name in names:
entry_a = experiments[name][0]
entry_b = design[name]
entry_b = case[name]

assert entry_a == entry_b, "failed for " + name

Expand Down Expand Up @@ -213,25 +216,57 @@ def test_get_results(mocker):
nr_experiments = 3
uncs = [
RealParameter("a", 0, 1),
CategoricalParameter("b", ["0", "1", "2", "3"]),
IntegerParameter("c", 0, 5),
BooleanParameter("d"),
]
outcomes = [ScalarOutcome("other_test")]
outcomes[0].shape = (1,)

callback = DefaultCallback(
uncs, [], outcomes, nr_experiments=nr_experiments, reporting_interval=1
)

# test warning
mock = mocker.patch("ema_workbench.em_framework.callbacks._logger.warning")
callback.get_results()
assert mock.call_count == 1

# test without warning
callback = DefaultCallback(
uncs, [], outcomes, nr_experiments=nr_experiments, reporting_interval=1
)

cases = []
for i in range(nr_experiments):
model = NamedObject("test")
policy = Policy("policy")
case = {"a": i * 0.15, "b": f"{i}", "c": i, "d": True if i % 2 == 0 else False}
scenario = Scenario(**case)
experiment = Experiment(0, model.name, policy, scenario, i)
model_outcomes = {outcomes[0].name: i * 1.25}
callback(experiment, model_outcomes)
cases.append(case)

mock = mocker.patch("ema_workbench.em_framework.callbacks._logger.warning")
callback.results = {k: v.data for k, v in callback.results.items()}
callback.get_results()
experiments, results = callback.get_results()
assert mock.call_count == 0

# check if experiments dataframe contains the experiments correctly
data = pd.DataFrame.from_dict(cases)
assert np.all(data == experiments.loc[:, data.columns])

# check data types of columns in experiments dataframe
dtype_mapping = {
RealParameter: float,
CategoricalParameter: "category",
IntegerParameter: int,
BooleanParameter: bool,
}

for u in uncs:
assert experiments.loc[:, u.name].dtype == dtype_mapping[u.__class__]


def test_filebasedcallback(mocker):
# only most basic assertions are checked
Expand Down

0 comments on commit 4463a7f

Please sign in to comment.