diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3a10c6dc..fa0c784a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 23.10.1 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade diff --git a/ema_workbench/analysis/prim.py b/ema_workbench/analysis/prim.py index 5c6a1c279..2faa75a4a 100644 --- a/ema_workbench/analysis/prim.py +++ b/ema_workbench/analysis/prim.py @@ -362,6 +362,8 @@ def __init__(self, prim, box_lims, indices): "res_dim": pd.Series(dtype=int), "mass": pd.Series(dtype=float), "id": pd.Series(dtype=int), + "n": pd.Series(dtype=int), # items in box + "k": pd.Series(dtype=int), # items of interest in box } self.peeling_trajectory = pd.DataFrame(columns) @@ -796,6 +798,8 @@ def update(self, box_lims, indices): "res_dim": restricted_dims.shape[0], "mass": y.shape[0] / self.prim.n, "id": i, + "n": y.shape[0], + "k": coi, } new_row = pd.DataFrame([data]) # self.peeling_trajectory = self.peeling_trajectory.append( @@ -872,7 +876,7 @@ def show_pairs_scatter( if dims is None: dims = sdutil._determine_restricted_dims(self.box_lims[i], self.prim.box_init) - if diag_kind not in diag_kind.__members__: + if diag_kind not in DiagKind: raise ValueError( f"diag_kind should be one of DiagKind.KDE or DiagKind.CDF, not {diag_kind}" ) @@ -916,10 +920,10 @@ def _calculate_quasi_p(self, i, restricted_dims): box_lim = box_lim[restricted_dims] # total nr. of cases in box - Tbox = self.peeling_trajectory["mass"][i] * self.prim.n + Tbox = self.peeling_trajectory.loc[i, "n"] # total nr. of cases of interest in box - Hbox = self.peeling_trajectory["coverage"][i] * self.prim.t_coi + Hbox = self.peeling_trajectory.loc[i, "k"] x = self.prim.x.loc[self.prim.yi_remaining, restricted_dims] y = self.prim.y[self.prim.yi_remaining] diff --git a/ema_workbench/analysis/prim_util.py b/ema_workbench/analysis/prim_util.py index f6b3c709c..0d321ca42 100644 --- a/ema_workbench/analysis/prim_util.py +++ b/ema_workbench/analysis/prim_util.py @@ -37,6 +37,9 @@ class DiagKind(Enum): CDF = "cdf" """constant for plotting diagonal in pairs_scatter as cdf""" + def __contains__(cls, item): + return item in cls.__members__.values() + def get_quantile(data, quantile): """ diff --git a/test/test_analysis/test_prim.py b/test/test_analysis/test_prim.py index 1c249b856..ffd096af6 100644 --- a/test/test_analysis/test_prim.py +++ b/test/test_analysis/test_prim.py @@ -38,7 +38,7 @@ def test_init(self): prim_obj = prim.setup_prim(results, "y", threshold=0.8) box = PrimBox(prim_obj, prim_obj.box_init, prim_obj.yi) - self.assertEqual(box.peeling_trajectory.shape, (1, 6)) + self.assertEqual(box.peeling_trajectory.shape, (1, 8)) def test_select(self): x = pd.DataFrame([(0, 1, 2), (2, 5, 6), (3, 2, 1)], columns=["a", "b", "c"])