Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KeOps MMD detector #548

Merged
merged 53 commits into from
Aug 19, 2022
Merged

Add KeOps MMD detector #548

merged 53 commits into from
Aug 19, 2022

Conversation

arnaudvl
Copy link
Contributor

@arnaudvl arnaudvl commented Jul 6, 2022

Add MMD detector using the KeOps (PyTorch) backend to further accelerate drift detection and scale up to larger datasets. This PR needs to be made compatible with the optional dependency management (incl. #538 and related).

This PR includes:

  • MMD detector implementation using KeOps
  • GaussianRBF kernel using KeOps
  • Tests
  • Docs
  • Basic benchmarking example vs. PyTorch MMD
  • Add a note to docs regarding lack of Windows support.
  • Investigate segfault with MacOS, or drop support for now.
  • Document sigma_mean vs. sigma_median and make foolproof.
  • Update keops infer_sigma check.
  • Update docstrings keops kernels to clarify various dims options + clarify within the forward pass.
  • Clarify GPU requirements and prettify example.
  • Document logic keops kernels more explicitly.
  • Fully compatible tests with torch and tensorflow.
  • Unit test _mmd2 -> check if results match that of the PyTorch implementation.
  • Exception -> error type in keops test.
  • Test sigma_mean for both "usual" (non-batch) and batch setting (unusual and should probably use the first batch entry since it corresponds to the original (x, y)).

Once this PR is merged, it will be followed up by a similar implementation for the Learned (Deep) Kernel detector.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@arnaudvl arnaudvl changed the title Add KeOps MMD detector WIP: Add KeOps MMD detector Jul 8, 2022
@arnaudvl arnaudvl requested review from ascillitoe and jklaise July 8, 2022 13:50
@arnaudvl arnaudvl changed the title WIP: Add KeOps MMD detector [WIP] Add KeOps MMD detector Jul 8, 2022
@arnaudvl arnaudvl changed the title [WIP] Add KeOps MMD detector Add KeOps MMD detector Jul 8, 2022
@ascillitoe
Copy link
Contributor

ascillitoe commented Jul 14, 2022

@arnaudvl I shall resolve these conflicts and then review once we have #537 merged.

I'm also adding @mauicv for review specifically to check my additions wrt to optional dependency handling (once I've added!)

@ascillitoe ascillitoe requested a review from mauicv July 14, 2022 09:17
@ascillitoe
Copy link
Contributor

Additional note: Need to check suitable error is raised when passed to save_detector. Implementing save/load functionality can be left to a future PR.

Copy link
Contributor

@jklaise jklaise left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer to my comments.

Copy link
Contributor

@ascillitoe ascillitoe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally looks very nice!

Just a few minor comments on my end, plus I have not investigated the final task of testing sigma_mean in batch and non-batch settings.

README.md Show resolved Hide resolved
doc/source/cd/methods/mmddrift.ipynb Outdated Show resolved Hide resolved
alibi_detect/cd/keops/tests/test_mmd_keops.py Show resolved Hide resolved
Comment on lines +73 to +75
batch_size_permutations
KeOps computes the n_permutations of the MMD^2 statistics in chunks of batch_size_permutations.
Only relevant for 'keops' backend.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should open an issue for this?

elif backend == 'pytorch' and has_pytorch:
pop_kwargs += ['batch_size_permutations']
detector = MMDDriftTorch
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mistakingly opened a new issue #576 (comment) for this as thought these "if statements" were already present. Since they are actually added here (for the new pop_kwargs bit), it would make more sense to fix here.

The has_tensorflow and has_pytorch are unnecessary as BackendValidator should have already raised an error if backend='tensorflow' and has_tensorflow=False, or the PyTorch equivalent.

@@ -89,7 +89,7 @@ def __init__(
# initialize kernel
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
np.ndarray) else None
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel
self.kernel = kernel(sigma).to(self.device) if kernel == GaussianRBF else kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we actually fixed this?? I have opened an issue (#586) and will fix it tomorrow.

alibi_detect/utils/frameworks.py Show resolved Hide resolved
return self.log_sigma.exp()

def forward(self, x: LazyTensor, y: LazyTensor, infer_sigma: bool = False) -> LazyTensor:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with opening an issue since this applies to all kernels (even if the issue is just to review docstring conventions with the conclusion being keep-as-is!).

@codecov-commenter
Copy link

codecov-commenter commented Aug 16, 2022

Codecov Report

❗ No coverage uploaded for pull request base (master@ed519e3). Click here to learn what that means.
The diff coverage is n/a.

Impacted file tree graph

@@            Coverage Diff            @@
##             master     #548   +/-   ##
=========================================
  Coverage          ?   83.51%           
=========================================
  Files             ?      207           
  Lines             ?    13777           
  Branches          ?        0           
=========================================
  Hits              ?    11506           
  Misses            ?     2271           
  Partials          ?        0           

Copy link
Contributor

@jklaise jklaise left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@arnaudvl arnaudvl merged commit 705b718 into SeldonIO:master Aug 19, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants