Commit

cedricvincentcuaz committed Nov 6, 2024
2 parents 63477c2 + 38922c0 commit a94c6ac

Showing 176 changed files with 24,767 additions and 8,698 deletions.

31 changes: 16 additions & 15 deletions .github/CONTRIBUTING.md
@@ -23,6 +23,15 @@ GitHub, clone, and develop on a branch. Steps:
$ cd POT
```

2. Install pre-commit hooks to ensure that your code is properly formatted:

```bash
$ pip install pre-commit
$ pre-commit install
```

This will install the pre-commit hooks that will run on every commit. If the hooks fail, the commit will be aborted.

3. Create a ``feature`` branch to hold your development changes:

```bash
@@ -56,7 +65,7 @@ Pull Request Checklist
We recommend that your contribution complies with the
following rules before you submit a pull request:

- Follow the PEP8 Guidelines.
- Follow the PEP8 Guidelines, which should be handled automatically by pre-commit.

- If your pull request addresses an issue, please use the pull request title
to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is
@@ -101,27 +110,19 @@ following rules before you submit a pull request:
You can also check for common programming errors with the following
tools:


- No pyflakes warnings, check with:
- All lint checks pass. You can run the following command to check:

```bash
$ pip install pyflakes
$ pyflakes path/to/module.py
$ pre-commit run --all-files
```

- No PEP8 warnings, check with:
This will run the pre-commit checks on all files in the repository.
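
    If you only want to re-run a single check locally, `pre-commit` also accepts a hook id (a minimal example, using the `ruff` and `codespell` hook ids defined in the `.pre-commit-config.yaml` added by this commit):

    ```bash
    $ pre-commit run ruff --all-files
    $ pre-commit run codespell --all-files
    ```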

```bash
$ pip install pep8
$ pep8 path/to/module.py
```

- AutoPEP8 can help you fix some of the easy redundant errors:
- All tests pass. You can run the following command to check:

```bash
$ pip install autopep8
$ autopep8 path/to/pep8.py
```
$ pytest --durations=20 -v test/ --doctest-modules
```
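
    While developing, it can be convenient to run only part of the test suite; pytest supports selecting test files or keywords (a sketch, assuming a test module such as `test/test_ot.py` exists):

    ```bash
    $ pytest test/test_ot.py -v
    $ pytest test/ -k "unbalanced" -v
    ```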

Bonus points for contributions that include a performance analysis with
a benchmark script and profiling output (please report on the mailing
46 changes: 25 additions & 21 deletions .github/workflows/build_tests.yml
@@ -15,14 +15,38 @@ on:
- '**'

jobs:

Lint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
defaults:
run:
shell: bash -l {0}
steps:


- name: Checking Out Repository
uses: actions/checkout@v2
# Install Python & Packages
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- run: which python
- name: Lint with pre-commit
run: |
pip install pre-commit
pre-commit install --install-hooks
pre-commit run --all-files
linux:

runs-on: ubuntu-latest
if: "!contains(github.event.head_commit.message, 'no ci')"
strategy:
max-parallel: 4
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
@@ -44,26 +68,6 @@ jobs:
- name: Upload coverage reports to Codecov with GitHub Action
uses: codecov/codecov-action@v3

pep8:
runs-on: ubuntu-latest
if: "!contains(github.event.head_commit.message, 'no pep8')"
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
pip install flake8
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics
linux-minimal-deps:

runs-on: ubuntu-latest
51 changes: 51 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,51 @@
repos:
# Ruff
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2
hooks:
- id: ruff
name: ruff lint
args: ["--fix"]
files: ^ot/
- id: ruff
name: ruff lint preview
args: ["--fix", "--preview", "--select=NPY201"]
files: ^ot/
- id: ruff
name: ruff lint doc, tutorials, tests and examples
# D103: missing docstring in public function
# D400: docstring first line must end with period
args: ["--ignore=D103,D400", "--fix"]
files: ^docs/|^examples/|^test/
- id: ruff-format
files: ^ot/|^docs/|^examples/

# Codespell
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
- id: codespell
additional_dependencies:
- tomli
files: ^ot/|^docs/|^examples/
types_or: [python, bib, rst, inc]
args: [
"--ignore-words",
"ignore-words.txt",
]

# yamllint
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.35.1
hooks:
- id: yamllint
# args: [--strict]

# # rstcheck
# - repo: https://github.com/rstcheck/rstcheck.git
# rev: v6.2.0
# hooks:
# - id: rstcheck
# additional_dependencies:
# - tomli
# files: ^docs/source/.*\.(rst|inc)$
10 changes: 10 additions & 0 deletions .yamllint.yml
@@ -0,0 +1,10 @@
extends: default

ignore: |
.github/workflows/*.yml
.circleci/config.yml
codecov.yml
rules:
line-length: disable
document-start: disable
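
To run the same YAML checks locally against this configuration (a minimal example; the `yamllint` pre-commit hook above invokes the same tool):

```bash
$ pip install yamllint
$ yamllint -c .yamllint.yml .
```
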
19 changes: 16 additions & 3 deletions README.md
@@ -52,6 +52,9 @@ POT provides the following generic OT solvers (links to examples):
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59].
* Gaussian Mixture Model OT [69]
* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
* Fused unbalanced Gromov-Wasserstein [70].

POT provides the following Machine Learning related solvers:

@@ -62,7 +65,7 @@ POT provides the following Machine Learning related solvers:
* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8].
* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt).
* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27].
* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53]

Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html).

@@ -198,7 +201,7 @@ This toolbox has been created by
* [Rémi Flamary](https://remi.flamary.com/)
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)

It is currently maintained by

* [Rémi Flamary](https://remi.flamary.com/)
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
@@ -370,4 +373,14 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing.

[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970.

[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene & B. Thirion (2022). [Aligning individual brains with Fused Unbalanced Gromov-Wasserstein](https://proceedings.neurips.cc/paper_files/paper/2022/file/8906cac4ca58dcaf17e97a0486ad57ca-Paper-Conference.pdf). Neural Information Processing Systems (NeurIPS).

[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on
Artificial Intelligence.

[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).

[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
32 changes: 20 additions & 12 deletions RELEASES.md
@@ -2,16 +2,24 @@

## 0.9.5dev

#### Breaking change
- Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G`, the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving the speed of solvers whose regularizer is a quadratic polynomial, such as the Gromov-Wasserstein loss (PR #663). A minimal sketch of the new signature is given below.
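
A minimal sketch of a conforming callable is shown here; the fixed step and the return convention (step size, number of inner function evaluations or `None`, objective value at the new point, as in `ot.optim.line_search_armijo`) are assumptions for illustration, not POT's actual implementation.

```python
# Hypothetical custom line search matching the new 0.9.5 signature (illustrative only).
def my_line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
    # cost: objective function, G: current transport plan, deltaG: descent direction,
    # Mi: linearized cost matrix, cost_G: objective value at G,
    # df_G: gradient of the regularizer at G (the argument added in this release).
    alpha = 1.0  # fixed step size, purely for illustration
    fc = None  # number of cost evaluations, if tracked
    f_val = cost(G + alpha * deltaG)
    return alpha, fc, f_val

# It would then be passed as, e.g.:
# ot.optim.generic_conditional_gradient(..., line_search=my_line_search, ...)
```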

#### New features
- Add feature `mass=True` for `nx.kl_div` (PR #654)
- Gaussian Mixture Model OT `ot.gmm` (PR #649)
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
- Add initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659)
- New linter based on pre-commit using ruff, codespell and yamllint (PR #681)
- Added feature `mass=True` for `nx.kl_div` (PR #654)
- Implemented Gaussian Mixture Model OT `ot.gmm` (PR #649)
- Added feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
- Added initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659)
- Improved `ot.plot.plot1D_mat` (PR #649)
- Added `nx.det` (PR #649)
- `nx.sqrtm` is now broadcastable and accepts `(..., d, d)` inputs (PR #649); see the sketch after this list
- restructure `ot.unbalanced` module (PR #658)
- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
- Restructured `ot.unbalanced` module (PR #658)
- Added `ot.unbalanced.lbfgsb_unbalanced2` and added a flexible reference measure `c` to all unbalanced solvers (PR #658)
- Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677)
- Added notes ahead of deprecating the partial Gromov-Wasserstein functions in `ot.partial`, which moved to `ot.gromov` (PR #663)
- Created `ot.gromov._partial`, added new features `loss_fun="kl_loss"` and `symmetry=False` to all solvers while increasing speed, and updated `ot.solvers` accordingly (PR #663)
- Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
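
As an illustration of the broadcastable `nx.sqrtm` entry above (a minimal sketch, assuming the NumPy backend is instantiated directly via `ot.backend.NumpyBackend`):

```python
import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
# Batch of two SPD matrices, shape (2, 3, 3): sqrtm now accepts (..., d, d) inputs.
A = np.stack([4.0 * np.eye(3), 9.0 * np.eye(3)])
roots = nx.sqrtm(A)
print(roots.shape)  # (2, 3, 3); diagonals are 2.0 and 3.0 respectively
```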

#### Closed issues
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
@@ -72,7 +80,7 @@ xs, xt = np.random.randn(100, 2), np.random.randn(50, 2)

# Solve OT problem with empirical samples
sol = ot.solve_sample(xs, xt) # Exact OT between samples with uniform weights
sol = ot.solve_sample(xs, xt, wa, wb) # Exact OT with weights given by user

sol = ot.solve_sample(xs, xt, reg=1, metric='euclidean') # Sinkhorn with Euclidean metric

@@ -84,15 +92,15 @@ sol = ot.solve_sample(x,x2, method='lowrank', rank=10) # compute lowrank sinkhor

value_bw = ot.solve_sample(xs, xt, method='gaussian').value # Bures-Wasserstein distance

# Solve GW problem
Cs, Ct = ot.dist(xs, xs), ot.dist(xt, xt) # compute cost matrices
sol = ot.solve_gromov(Cs,Ct) # Exact GW between samples with uniform weights

# Solve FGW problem
M = ot.dist(xs, xt) # compute cost matrix

# Exact FGW between samples with uniform weights
sol = ot.solve_gromov(Cs, Ct, M, loss='KL', alpha=0.7) # FGW with KL data fitting


# recover solutions objects
@@ -102,14 +110,14 @@ value = sol.value # OT value

# for GW and FGW
value_linear = sol.value_linear # linear part of the loss
value_quad = sol.value_quad # quadratic part of the loss

```

Users are encouraged to use the new API (it is much simpler), but it might still be subject to small changes before the release of POT 1.0.


We also fixed a number of issues, the most pressing being a GPU memory allocation problem when PyTorch is installed, which no longer occurs thanks to lazy initialization of the backends. It is now also possible to deactivate some backends using environment variables, which prevents POT from importing them and can lead to a large import speedup.


#### New features
@@ -143,7 +151,7 @@ We also fixed a number of issues, the most pressing being a problem of GPU memor
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566)
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
- Create `ot/bregman/` repository (Issue #567, PR #569)
- Fix matrix feature shape in `entropic_fused_gromov_barycenters` (Issue #574, PR #573)
- Fix (fused) gromov-wasserstein barycenter solvers to support `kl_loss` (PR #576)


12 changes: 6 additions & 6 deletions benchmarks/benchmark.py
@@ -8,10 +8,12 @@
def setup_backends():
if jax:
from jax.config import config

config.update("jax_enable_x64", True)

if tf:
from tensorflow.python.ops.numpy_ops import np_config

np_config.enable_numpy_behavior()


@@ -36,10 +38,7 @@ def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs):
print(nx, param_list[i])
args = inputs[i]
results_nx = nx._bench(
tested_function,
*args,
n_runs=n_runs,
warmup_runs=warmup_runs
tested_function, *args, n_runs=n_runs, warmup_runs=warmup_runs
)
gc.collect()
results_nx_with_param_in_key = dict()
@@ -64,10 +63,11 @@ def convert_to_html_table(results, param_name, main_title=None, comments=None):
assert cpus_cols + gpus_cols == len(devices_names)

if main_title is not None:
string += f'<tr><th align="center" colspan="{length}">{str(main_title)}</th></tr>\n'
string += (
f'<tr><th align="center" colspan="{length}">{str(main_title)}</th></tr>\n'
)

for i, bitsize in enumerate(bitsizes):

if i != 0:
string += f'<tr><td colspan="{length}">&nbsp;</td></tr>\n'

20 changes: 9 additions & 11 deletions benchmarks/emd.py
@@ -3,11 +3,7 @@

import numpy as np
import ot
from .benchmark import (
setup_backends,
exec_bench,
convert_to_html_table
)
from .benchmark import setup_backends, exec_bench, convert_to_html_table


def setup(n_samples):
@@ -31,10 +27,12 @@ def setup(n_samples):
tested_function=lambda a, M: ot.emd(a, a, M),
param_list=param_list,
n_runs=n_runs,
warmup_runs=warmup_runs
warmup_runs=warmup_runs,
)
print(
convert_to_html_table(
results,
param_name="Sample size",
main_title=f"EMD - Averaged on {n_runs} runs",
)
)
print(convert_to_html_table(
results,
param_name="Sample size",
main_title=f"EMD - Averaged on {n_runs} runs"
))