-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add KeOps MMD detector #548
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Additional note: Need to check suitable error is raised when passed to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please refer to my comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally looks very nice!
Just a few minor comments on my end, plus I have not investigated the final task of testing sigma_mean
in batch and non-batch settings.
batch_size_permutations | ||
KeOps computes the n_permutations of the MMD^2 statistics in chunks of batch_size_permutations. | ||
Only relevant for 'keops' backend. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should open an issue for this?
elif backend == 'pytorch' and has_pytorch: | ||
pop_kwargs += ['batch_size_permutations'] | ||
detector = MMDDriftTorch | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mistakingly opened a new issue #576 (comment) for this as thought these "if statements" were already present. Since they are actually added here (for the new pop_kwargs
bit), it would make more sense to fix here.
The has_tensorflow
and has_pytorch
are unnecessary as BackendValidator
should have already raised an error if backend='tensorflow'
and has_tensorflow=False
, or the PyTorch equivalent.
@@ -89,7 +89,7 @@ def __init__( | |||
# initialize kernel | |||
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment] | |||
np.ndarray) else None | |||
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel | |||
self.kernel = kernel(sigma).to(self.device) if kernel == GaussianRBF else kernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure we actually fixed this?? I have opened an issue (#586) and will fix it tomorrow.
return self.log_sigma.exp() | ||
|
||
def forward(self, x: LazyTensor, y: LazyTensor, infer_sigma: bool = False) -> LazyTensor: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with opening an issue since this applies to all kernels (even if the issue is just to review docstring conventions with the conclusion being keep-as-is!).
Codecov Report
@@ Coverage Diff @@
## master #548 +/- ##
=========================================
Coverage ? 83.51%
=========================================
Files ? 207
Lines ? 13777
Branches ? 0
=========================================
Hits ? 11506
Misses ? 2271
Partials ? 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Add MMD detector using the KeOps (PyTorch) backend to further accelerate drift detection and scale up to larger datasets. This PR needs to be made compatible with the optional dependency management (incl. #538 and related).
This PR includes:
sigma_mean
vs.sigma_median
and make foolproof.infer_sigma
check.torch
andtensorflow
._mmd2
-> check if results match that of the PyTorch implementation.sigma_mean
for both "usual" (non-batch) and batch setting (unusual and should probably use the first batch entry since it corresponds to the original(x, y)
).Once this PR is merged, it will be followed up by a similar implementation for the Learned (Deep) Kernel detector.