Skip to content

Commit

Permalink
Wasserstein distances (#500)
Browse files Browse the repository at this point in the history
* refactor distance submodule

* cont

* implement wasserstein distances

* fix style

* install cython

* fix2

* add raise test
  • Loading branch information
yannikschaelte authored Oct 26, 2021
1 parent 197fb91 commit 275f7de
Show file tree
Hide file tree
Showing 16 changed files with 959 additions and 26 deletions.
2 changes: 1 addition & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ API reference
.. automodule:: pyabc

.. toctree::
:maxdepth: 2
:maxdepth: 3

api_inference
api_distance
Expand Down
2 changes: 2 additions & 0 deletions doc/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Potentially, further dependencies may be required.
examples/adaptive_distances.ipynb
examples/informative.ipynb
examples/aggregated_distances.ipynb
examples/wasserstein.ipynb
examples/external_simulators.ipynb
examples/data_plots.ipynb
examples/noise.ipynb
Expand All @@ -58,6 +59,7 @@ Download the examples as notebooks
* :download:`Adaptive distances <examples/adaptive_distances.ipynb>`
* :download:`Informative distances and summary statistics <examples/informative.ipynb>`
* :download:`Aggregated distances <examples/aggregated_distances.ipynb>`
* :download:`Wasserstein distances <examples/wasserstein.ipynb>`
* :download:`External simulators <examples/external_simulators.ipynb>`
* :download:`Data plots <examples/data_plots.ipynb>`
* :download:`Measurement noise and exact inference <examples/noise.ipynb>`
Expand Down
4 changes: 2 additions & 2 deletions doc/examples/informative.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "df5d57e9-6819-4d6a-bdbb-8639f315602a",
"metadata": {},
"outputs": [],
"source": [
"# install if not done yet\n",
"!pip install pyabc --quiet"
"!pip install pyabc[plotly] --quiet"
]
},
{
Expand Down
438 changes: 438 additions & 0 deletions doc/examples/wasserstein.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ If you use it in your work, you can cite the paper:


.. toctree::
:maxdepth: 2
:maxdepth: 3
:caption: API reference

api
Expand Down
2 changes: 2 additions & 0 deletions pyabc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
PercentileDistance,
RangeEstimatorDistance,
DistanceWithMeasureList,
WassersteinDistance,
SlicedWassersteinDistance,
StochasticKernel,
NormalKernel,
IndependentNormalKernel,
Expand Down
14 changes: 9 additions & 5 deletions pyabc/distance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Distance functions
==================
Distances
=========
Distance functions measure closeness of observed and sampled data. This
module implements various commonly used distance functions for ABC, featuring
a few advanced concepts.
Distance functions or metrics measure closeness of observed and sampled data.
This module implements various commonly used distance functions for ABC,
featuring a few advanced concepts.
For custom distance functions, either pass a plain function to ABCSMC, or
subclass the pyabc.Distance class.
Expand Down Expand Up @@ -34,6 +34,10 @@
AggregatedDistance,
AdaptiveAggregatedDistance,
)
from .ot import (
WassersteinDistance,
SlicedWassersteinDistance,
)
from .scale import (
median_absolute_deviation,
mad,
Expand Down
6 changes: 1 addition & 5 deletions pyabc/distance/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging

from ..population import Sample

from .base import Distance


Expand All @@ -20,7 +21,6 @@ class DistanceWithMeasureList(Distance, ABC):
Parameters
----------
measures_to_use: Union[str, List[str]].
* If set to "all", all measures are used. This is the default.
* If a list is provided, the measures in the list are used.
Expand Down Expand Up @@ -167,13 +167,11 @@ def lower(parameter_list: List[float]):
Parameters
----------
parameter_list: List[float]
List of values of a parameter.
Returns
-------
lower_margin: float
The lower margin of the range calculated from these parameters
"""
Expand All @@ -185,13 +183,11 @@ def upper(parameter_list: List[float]):
Parameters
----------
parameter_list: List[float]
List of values of a parameter.
Returns
-------
upper_margin: float
The upper margin of the range calculated from these parameters
"""
Expand Down
Loading

0 comments on commit 275f7de

Please sign in to comment.