Skip to content

Commit

Permalink
Revert "Plotting and Metric for Local classifier 2 sample test"
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg authored May 21, 2024
1 parent c889c23 commit 4d62631
Show file tree
Hide file tree
Showing 16 changed files with 211 additions and 459 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ numpy = "^1.26.4"
matplotlib = "^3.8.3"
tarp = "^0.1.1"
deprecation = "^2.1.0"
scipy = "1.12.0"


[tool.poetry.group.dev.dependencies]
Expand Down
14 changes: 4 additions & 10 deletions src/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,9 @@ def main():
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
try:
Metrics[metrics_name](model, data, **metrics_args)()
except (NotImplementedError, RuntimeError) as error:
print(f"WARNING - skipping metric {metrics_name} due to error: {error}")
Metrics[metrics_name](model, data, **metrics_args)()

for plot_name, plot_args in plots.items():
try:
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
**plot_args
)
except (NotImplementedError, RuntimeError) as error:
print(f"WARNING - skipping plot {plot_name} due to error: {error}")
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
**plot_args
)
2 changes: 1 addition & 1 deletion src/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ def load_prior(self, prior, prior_kwargs):
return lambda size: choices[prior](**prior_kwargs, size=size)

except KeyError as e:
raise RuntimeError(f"Data missing a prior specification - {e}")
raise RuntimeError(f"Data missing a prior specification - {e}")
3 changes: 1 addition & 2 deletions src/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
class H5Data(Data):
def __init__(self, path: str, simulator: Callable):
super().__init__(path, simulator)
self.theta_true = self.get_theta_true()


def _load(self, path):
assert path.split(".")[-1] == "h5", "File extension must be h5"
loaded_data = {}
Expand Down
5 changes: 1 addition & 4 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from metrics.all_sbc import AllSBC
from metrics.coverage_fraction import CoverageFraction
from metrics.local_two_sample import LocalTwoSampleTest


_all = [CoverageFraction, AllSBC, LocalTwoSampleTest]
Metrics = {m.__name__: m for m in _all}
Metrics = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC}
175 changes: 0 additions & 175 deletions src/metrics/local_two_sample.py

This file was deleted.

4 changes: 1 addition & 3 deletions src/models/sbi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def sample_posterior(self, n_samples: int, y_true): # TODO typing
def predict_posterior(self, data):
posterior_samples = self.sample_posterior(data.y_true)
posterior_predictive_samples = data.simulator(
data.get_theta_true(), posterior_samples
data.theta_true(), posterior_samples
)
return posterior_predictive_samples


9 changes: 6 additions & 3 deletions src/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from plots.cdf_ranks import CDFRanks
from plots.coverage_fraction import CoverageFraction
from plots.ranks import Ranks
from plots.local_two_sample import LocalTwoSampleTest
from plots.tarp import TARP

_all = [CoverageFraction, CDFRanks, Ranks, LocalTwoSampleTest, TARP]
Plots = {m.__name__: m for m in _all}
Plots = {
CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks,
TARP.__name__: TARP,
}
Loading

0 comments on commit 4d62631

Please sign in to comment.