
Support for serialising detectors with keops backends #681

Merged 9 commits on Jan 3, 2023
Changes from 6 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@
- **New feature** MMD drift detector has been extended with a [KeOps](https://www.kernel-operations.io/keops/index.html) backend to scale and speed up the detector.
See the [documentation](https://docs.seldon.io/projects/alibi-detect/en/latest/cd/methods/mmddrift.html) and [example notebook](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_mmd_keops.html) for more info ([#548](https://github.com/SeldonIO/alibi-detect/pull/548)).
- **New feature** Added support for serializing detectors with PyTorch backends, and detectors containing PyTorch models in their preprocessing functions ([#656](https://github.com/SeldonIO/alibi-detect/pull/656)).
- **New feature** Added support for serializing detectors with KeOps backends ([#681](https://github.com/SeldonIO/alibi-detect/pull/681)).
- **New feature** Added a PyTorch version of the `UAE` preprocessing utility function ([#656](https://github.com/SeldonIO/alibi-detect/pull/656)).
- If a `categories_per_feature` dictionary is not passed to `TabularDrift`, a warning is now raised to inform the user that all features are assumed to be numerical ([#606](https://github.com/SeldonIO/alibi-detect/pull/606)).
- For the `ClassifierDrift` and `SpotTheDiffDrift` detectors, we can also return the out-of-fold instances of the reference and test sets. When using `train_size` for training the detector, this allows the returned prediction probabilities to be associated with the correct instances.
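For context, the serialization support this entry describes amounts to a save/load round-trip for KeOps-backed detectors. A minimal sketch of such a round-trip (data, path and keyword arguments are illustrative, not taken from this PR):

import numpy as np
from alibi_detect.cd import MMDDrift
from alibi_detect.saving import save_detector, load_detector

x_ref = np.random.randn(100, 10).astype(np.float32)  # illustrative reference data
cd = MMDDrift(x_ref, backend='keops', p_val=.05)

save_detector(cd, './keops_detector/')   # raised NotImplementedError for keops before this PR
cd_loaded = load_detector('./keops_detector/')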
9 changes: 9 additions & 0 deletions alibi_detect/saving/_keops/__init__.py
@@ -0,0 +1,9 @@
from alibi_detect.utils.missing_optional_dependency import import_optional

load_kernel_config_ke = import_optional(
    'alibi_detect.saving._keops.loading',
    names=['load_kernel_config'])

__all__ = [
    "load_kernel_config_ke",
]
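The `import_optional` utility is what keeps keops an optional dependency here: it returns the requested attributes when the import succeeds, and placeholders that raise only on use when it fails. A rough sketch of the pattern (an illustration of the idea, not alibi-detect's actual implementation):

from importlib import import_module


def import_optional(module_name, names):
    # Try the real import; on failure, return placeholders that only raise when used.
    try:
        module = import_module(module_name)
        objs = tuple(getattr(module, name) for name in names)
    except ImportError as err:
        class MissingDependency:
            def __init__(self, name, err):
                self.name, self.err = name, err

            def __call__(self, *args, **kwargs):
                raise ImportError(f'`{self.name}` requires a missing optional dependency.') from self.err

        objs = tuple(MissingDependency(name, err) for name in names)
    return objs if len(objs) > 1 else objs[0]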
37 changes: 37 additions & 0 deletions alibi_detect/saving/_keops/loading.py
@@ -0,0 +1,37 @@
from typing import Callable
from alibi_detect.utils.keops.kernels import DeepKernel


def load_kernel_config(cfg: dict) -> Callable:
    """
    Loads a kernel from a kernel config dict.

    Parameters
    ----------
    cfg
        A kernel config dict (see the pydantic schemas).

    Returns
    -------
    The kernel.
    """
    if 'src' in cfg:  # Standard kernel config
        kernel = cfg.pop('src')
        if hasattr(kernel, 'from_config'):
            kernel = kernel.from_config(cfg)

    elif 'proj' in cfg:  # DeepKernel config
        # Kernels a and b
        kernel_a = cfg['kernel_a']
        kernel_b = cfg['kernel_b']
        if kernel_a != 'rbf':
            cfg['kernel_a'] = load_kernel_config(kernel_a)
        if kernel_b != 'rbf':
            cfg['kernel_b'] = load_kernel_config(kernel_b)
        # Assemble deep kernel
        kernel = DeepKernel.from_config(cfg)

    else:
        raise ValueError('Unable to process kernel. The kernel config dict must either be a `KernelConfig` with a '
                         '`src` field, or a `DeepKernelConfig` with a `proj` field.')
    return kernel
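To make the two branches concrete: a resolved "standard" kernel config carries the kernel class (or other callable) under `src`, with the remaining keys treated as kernel parameters, while a deep-kernel config carries a projection model under `proj`. A minimal illustration of the first branch, assuming keops is installed (the `sigma` value is made up):

from alibi_detect.saving._keops.loading import load_kernel_config
from alibi_detect.utils.keops.kernels import GaussianRBF

# `src` holds the already-resolved kernel class; remaining keys go to `from_config`.
kernel = load_kernel_config({'src': GaussianRBF, 'sigma': [0.5]})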
10 changes: 4 additions & 6 deletions alibi_detect/saving/loading.py
@@ -15,6 +15,7 @@
    load_model_tf, load_optimizer_tf, prep_model_and_emb_tf, get_tf_dtype
from alibi_detect.saving._pytorch import load_embedding_pt, load_kernel_config_pt, load_model_pt, \
    load_optimizer_pt, prep_model_and_emb_pt, get_pt_dtype
from alibi_detect.saving._keops import load_kernel_config_ke
from alibi_detect.saving._sklearn import load_model_sk
from alibi_detect.saving.validate import validate_config
from alibi_detect.base import Detector, ConfigurableDetector
@@ -135,11 +136,6 @@ def _load_detector_config(filepath: Union[str, os.PathLike]) -> ConfigurableDete
    cfg = validate_config(cfg, resolved=True)
    logger.info('Validated resolved config.')

    # Backend
    backend = cfg.get('backend')
    if backend is not None and backend.lower() not in (Framework.TENSORFLOW, Framework.PYTORCH, Framework.SKLEARN):
        raise NotImplementedError('Loading detectors with keops backend is not yet supported.')

    # Init detector from config
    logger.info('Instantiating detector.')
    detector = _init_detector(cfg)
@@ -186,8 +182,10 @@ def _load_kernel_config(cfg: dict, backend: str = Framework.TENSORFLOW) -> Calla
"""
if backend == Framework.TENSORFLOW:
kernel = load_kernel_config_tf(cfg)
else:
elif backend == Framework.PYTORCH:
kernel = load_kernel_config_pt(cfg)
else: # backend=='keops'
kernel = load_kernel_config_ke(cfg)
    return kernel


10 changes: 9 additions & 1 deletion alibi_detect/saving/registry.py
@@ -35,7 +35,7 @@ def my_function(x: np.ndarray) -> np.ndarray:

import catalogue

from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, has_keops

if has_tensorflow:
    from alibi_detect.cd.tensorflow import \
@@ -52,6 +52,10 @@ def my_function(x: np.ndarray) -> np.ndarray:
        GaussianRBF as GaussianRBF_torch, sigma_median as sigma_median_torch
    from alibi_detect.cd.pytorch.context_aware import _sigma_median_diag as _sigma_median_diag_torch

if has_keops:
    from alibi_detect.utils.keops.kernels import \
        GaussianRBF as GaussianRBF_keops, sigma_mean as sigma_mean_keops

# Create registry
registry = catalogue.create("alibi_detect", "registry")

@@ -68,3 +72,7 @@ def my_function(x: np.ndarray) -> np.ndarray:
    registry.register('utils.pytorch.kernels.sigma_median', func=sigma_median_torch)
    registry.register('cd.pytorch.context_aware._sigma_median_diag', func=_sigma_median_diag_torch)
    registry.register('cd.pytorch.preprocess.preprocess_drift', func=preprocess_drift_torch)

if has_keops:
    registry.register('utils.keops.kernels.GaussianRBF', func=GaussianRBF_keops)
    registry.register('utils.keops.kernels.sigma_mean', func=sigma_mean_keops)
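With these registrations in place, the keops kernel and bandwidth function can be fetched from the registry by the same dotted names used above, mirroring the tensorflow and pytorch entries. A small sketch, assuming keops is installed:

from alibi_detect.saving.registry import registry

# Look up the keops objects registered above by name.
GaussianRBF_ke = registry.get('utils.keops.kernels.GaussianRBF')
sigma_mean_ke = registry.get('utils.keops.kernels.sigma_mean')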
10 changes: 1 addition & 9 deletions alibi_detect/saving/saving.py
@@ -15,7 +15,6 @@
from alibi_detect.saving.registry import registry
from alibi_detect.utils._types import supported_models_all, supported_models_tf, supported_models_torch, \
    supported_models_sklearn
from alibi_detect.utils.frameworks import Framework
from alibi_detect.base import Detector, ConfigurableDetector
from alibi_detect.saving._tensorflow import save_detector_legacy, save_model_config_tf, save_optimizer_config_tf
from alibi_detect.saving._pytorch import save_model_config_pt
@@ -53,9 +52,6 @@ def save_detector(
    if legacy:
        warnings.warn('The `legacy` option will be removed in a future version.', DeprecationWarning)

    if 'backend' in list(detector.meta.keys()) and detector.meta['backend'] == Framework.KEOPS:
        raise NotImplementedError('Saving detectors with keops backend is not yet supported.')

    # TODO: Replace .__args__ w/ typing.get_args() once Python 3.7 dropped (and remove type ignore below)
    detector_name = detector.__class__.__name__
    if detector_name not in [detector for detector in VALID_DETECTORS]:
@@ -129,11 +125,7 @@ def _save_detector_config(detector: ConfigurableDetector, filepath: Union[str, o
    filepath
        File path to save serialized artefacts to.
    """
    # Get backend, input_shape and detector_name
    backend = detector.meta.get('backend')
    if backend not in (None, Framework.TENSORFLOW, Framework.PYTORCH, Framework.SKLEARN):
        raise NotImplementedError("Currently, saving is only supported with backend='tensorflow', 'pytorch', and "
                                  "'sklearn'.")
    # detector name
    detector_name = detector.__class__.__name__

    # Process file paths
2 changes: 1 addition & 1 deletion alibi_detect/saving/schemas.py
@@ -351,7 +351,7 @@ class KernelConfig(CustomBaseModelWithKwargs):
"A string referencing a filepath to a serialized kernel in `.dill` format, or an object registry reference."

# Below kwargs are only passed if kernel == @GaussianRBF
flavour: Literal['tensorflow', 'pytorch']
flavour: Literal['tensorflow', 'pytorch', 'keops']
"""
Whether the kernel is a `tensorflow` or `pytorch` kernel.
"""
34 changes: 23 additions & 11 deletions alibi_detect/saving/tests/models.py
@@ -25,6 +25,10 @@
from alibi_detect.models.tensorflow import TransformerEmbedding as TransformerEmbedding_tf
from alibi_detect.cd.pytorch import HiddenOutput as HiddenOutput_pt
from alibi_detect.cd.tensorflow import HiddenOutput as HiddenOutput_tf
from alibi_detect.utils.frameworks import has_keops
if has_keops:
Contributor commented:
Is this conditional purely because we skip keops on some platforms in CI? Should we add a comment for this? Same question for the other test module.

ascillitoe (Contributor, Author) replied on Nov 24, 2022:
Yep precisely. I'll add notes to both modules now!

Edit: Done in 5bb7314

    from alibi_detect.utils.keops.kernels import GaussianRBF as GaussianRBF_ke
    from alibi_detect.utils.keops.kernels import DeepKernel as DeepKernel_ke

LATENT_DIM = 2 # Must be less than input_dim set in ./datasets.py
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -46,7 +50,7 @@ def encoder_model(backend, current_cases):
            tf.keras.layers.Dense(LATENT_DIM, activation=None)
        ]
    )
    elif backend == 'pytorch':
    elif backend in ('pytorch', 'keops'):
        model = nn.Sequential(nn.Linear(input_dim, 5),
                              nn.ReLU(),
                              nn.Linear(5, LATENT_DIM))
@@ -74,7 +78,7 @@ def encoder_dropout_model(backend, current_cases):
            tf.keras.layers.Dense(LATENT_DIM, activation=None)
        ]
    )
    elif backend == 'pytorch':
    elif backend in ('pytorch', 'keops'):
        model = nn.Sequential(nn.Linear(input_dim, 5),
                              nn.ReLU(),
                              nn.Dropout(0.0),  # 0.0 to ensure determinism
@@ -115,8 +119,12 @@ def kernel(request, backend):
        if sigma is not None and not isinstance(sigma, torch.Tensor):
            sigma = torch.tensor(sigma)
        kernel = GaussianRBF_pt(sigma=sigma, **kernel_cfg)
    elif backend == 'keops':
        if sigma is not None and not isinstance(sigma, torch.Tensor):
            sigma = torch.tensor(sigma)
        kernel = GaussianRBF_ke(sigma=sigma, **kernel_cfg)
    else:
        pytest.skip('`kernel` only implemented for tensorflow and pytorch.')
        pytest.skip('`kernel` only implemented for tensorflow, pytorch and keops.')
    return kernel


@@ -129,8 +137,8 @@ def optimizer(request, backend):
    the optimizer is a `torch.optim.Optimizer` class (NOT instantiated).
    """
    optimizer = request.param  # Get parametrized setting
    if backend not in ('tensorflow', 'pytorch'):
        pytest.skip('`optimizer` only implemented for tensorflow and pytorch.')
    if backend not in ('tensorflow', 'pytorch', 'keops'):
        pytest.skip('`optimizer` only implemented for tensorflow, pytorch and keops.')
    if isinstance(optimizer, str):
        module = 'tensorflow.keras.optimizers' if backend == 'tensorflow' else 'torch.optim'
        try:
@@ -163,6 +171,10 @@ def deep_kernel(request, backend, encoder_model):
        kernel_a = GaussianRBF_pt(**kernel_a) if isinstance(kernel_a, dict) else kernel_a
        kernel_b = GaussianRBF_pt(**kernel_b) if isinstance(kernel_b, dict) else kernel_b
        deep_kernel = DeepKernel_pt(proj, kernel_a=kernel_a, kernel_b=kernel_b, eps=eps)
    elif backend == 'keops':
        kernel_a = GaussianRBF_ke(**kernel_a) if isinstance(kernel_a, dict) else kernel_a
        kernel_b = GaussianRBF_ke(**kernel_b) if isinstance(kernel_b, dict) else kernel_b
        deep_kernel = DeepKernel_ke(proj, kernel_a=kernel_a, kernel_b=kernel_b, eps=eps)
    else:
        pytest.skip('`deep_kernel` only implemented for tensorflow, pytorch and keops.')
    return deep_kernel
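Outside of the fixture machinery, constructing the keops deep kernel directly might look like the following sketch (the projection architecture and `eps` value are illustrative):

import torch.nn as nn
from alibi_detect.utils.keops.kernels import DeepKernel, GaussianRBF

proj = nn.Sequential(nn.Linear(10, 2))  # illustrative projection network
deep_kernel = DeepKernel(proj, kernel_a=GaussianRBF(trainable=True), eps=0.01)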
@@ -182,13 +194,13 @@ def classifier_model(backend, current_cases):
            tf.keras.layers.Dense(2, activation=tf.nn.softmax),
        ]
    )
    elif backend == 'pytorch':
    elif backend in ('pytorch', 'keops'):
        model = nn.Sequential(nn.Linear(input_dim, 2),
                              nn.Softmax(1))
    elif backend == 'sklearn':
        model = RandomForestClassifier()
    else:
        pytest.skip('`classifier_model` only implemented for tensorflow, pytorch, and sklearn.')
        pytest.skip('`classifier_model` only implemented for tensorflow, pytorch, keops and sklearn.')
    return model


@@ -259,12 +271,12 @@ def preprocess_nlp(embedding, tokenizer, max_len, backend):
    if backend == 'tensorflow':
        preprocess_fn = partial(preprocess_drift_tf, model=embedding, tokenizer=tokenizer,
                                max_len=max_len, preprocess_batch_fn=preprocess_simple)
    elif backend == 'pytorch':
    elif backend in ('pytorch', 'keops'):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        preprocess_fn = partial(preprocess_drift_pt, model=embedding, tokenizer=tokenizer, max_len=max_len,
                                preprocess_batch_fn=preprocess_simple, device=device)
    else:
        pytest.skip('`preprocess_nlp` only implemented for tensorflow and pytorch.')
        pytest.skip('`preprocess_nlp` only implemented for tensorflow, pytorch and keops.')
    return preprocess_fn


@@ -279,10 +291,10 @@ def preprocess_hiddenoutput(classifier_model, current_cases, backend):
    if backend == 'tensorflow':
        model = HiddenOutput_tf(classifier_model, layer=-1, input_shape=(None, input_dim))
        preprocess_fn = partial(preprocess_drift_tf, model=model)
    elif backend == 'pytorch':
    elif backend in ('pytorch', 'keops'):
        model = HiddenOutput_pt(classifier_model, layer=-1)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        preprocess_fn = partial(preprocess_drift_pt, model=model, device=device)
    else:
        pytest.skip('`preprocess_hiddenoutput` only implemented for tensorflow and pytorch.')
        pytest.skip('`preprocess_hiddenoutput` only implemented for tensorflow, pytorch and keops.')
    return preprocess_fn