
Set x_ref_preprocessed=True during legacy loading (v0.10.5 patch) #732

Merged: ascillitoe merged 3 commits into SeldonIO:patch/v0.10.5 from fix/legacy_load_preproc on Jan 25, 2023

Conversation

ascillitoe (Contributor) commented Jan 25, 2023

This PR fixes a second bug in the legacy loading behavior (the first was fixed in an earlier PR), used when loading detectors saved with alibi-detect < v0.10 or with save_detector(..., legacy=True) in newer versions.

In short, the bug was due to the kwargs related to preprocessing of x_ref being changed in v0.10. This change was reflected in the functions that set the state_dict used for legacy saving (e.g. state_ksdrift), so legacy save/load within v0.10 works fine. However, the state_dict of detectors saved in older versions of course contains the old kwargs, which have different meanings. The result was x_ref being incorrectly preprocessed a second time when loading detectors saved in <v0.10.

Background

In the old legacy save/load code, a state_dict is defined in state_ksdrift etc:

state_dict = {
    'args': {
        'x_ref': cd.x_ref
    },
    'kwargs': {
        'p_val': cd.p_val,
        'preprocess_x_ref': False,
        'update_x_ref': cd.update_x_ref,
        'correction': cd.correction,
        'alternative': cd.alternative,
        'n_features': cd.n_features,
        'input_shape': cd.input_shape,
    },
    'other': {
        'n': cd.n,
        'preprocess_x_ref': cd.preprocess_x_ref,
        'load_text_embedding': load_emb,
        'preprocess_fn': preprocess_fn,
        'preprocess_kwargs': preprocess_kwargs
    }
}

At load time, in init_cd_ksdrift etc., the nested kwargs dict is used to instantiate the detector, whilst selected values from the other dict are used to update attributes after instantiation:

preprocess_fn, preprocess_kwargs = init_preprocess(state_dict['other'], model, emb, tokenizer, **kwargs)
if isinstance(preprocess_fn, Callable) and isinstance(preprocess_kwargs, dict):
    state_dict['kwargs'].update({'preprocess_fn': partial(preprocess_fn, **preprocess_kwargs)})
cd = KSDrift(*list(state_dict['args'].values()), **state_dict['kwargs'])
attrs = state_dict['other']
cd.n = attrs['n']
cd.preprocess_x_ref = attrs['preprocess_x_ref']
return cd

In <v0.10 the preprocess_x_ref kwarg (which controls whether x_ref is preprocessed at init) was handled in this way, to ensure that the already preprocessed x_ref was not preprocessed again when a detector was loaded (it is set to False at instantiation and then updated to its real value afterward).

In >=v0.10 this behavior is updated: preprocess_x_ref is renamed to preprocess_at_init, and x_ref_preprocessed is introduced. Whether x_ref has already been preprocessed (and thus shouldn't be preprocessed again) is now controlled by x_ref_preprocessed, while preprocess_at_init only controls whether to preprocess at init or at predict time. Since the preprocessing behavior is now fully controlled by the detector's kwargs (with no ad hoc editing of attributes), the setting of the state_dict is altered in state_ksdrift etc:

state_dict = {
    'args': {
        'x_ref': cd.x_ref
    },
    'kwargs': {
        'p_val': cd.p_val,
        'x_ref_preprocessed': True,
        'preprocess_at_init': cd.preprocess_at_init,
        'update_x_ref': cd.update_x_ref,
        'correction': cd.correction,
        'alternative': cd.alternative,
        'n_features': cd.n_features,
        'input_shape': cd.input_shape,
    },
    'other': {
        'n': cd.n,
        'load_text_embedding': load_emb,
        'preprocess_fn': preprocess_fn,
        'preprocess_kwargs': preprocess_kwargs
    }
}

Then, at load time, the preprocess_x_ref attribute no longer needs to be updated:

preprocess_fn, preprocess_kwargs = init_preprocess(state_dict['other'], model, emb, tokenizer, **kwargs)
if callable(preprocess_fn) and isinstance(preprocess_kwargs, dict):
    state_dict['kwargs'].update({'preprocess_fn': partial(preprocess_fn, **preprocess_kwargs)})
cd = KSDrift(*list(state_dict['args'].values()), **state_dict['kwargs'])
attrs = state_dict['other']
cd.n = attrs['n']
return cd
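
To make the new semantics concrete, the interaction between the two kwargs can be sketched roughly as follows (a simplified illustration only, not the actual detector base-class code):

# Rough sketch of the >=v0.10 preprocessing semantics (illustration only):
def maybe_preprocess_x_ref(x_ref, preprocess_fn, x_ref_preprocessed, preprocess_at_init):
    if preprocess_fn is not None and not x_ref_preprocessed and preprocess_at_init:
        # x_ref has not been preprocessed yet and preprocessing is requested at
        # init time (rather than predict time), so apply preprocess_fn once here.
        return preprocess_fn(x_ref)
    # Either x_ref is already preprocessed, or preprocessing is deferred to predict time.
    return x_ref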

Problem

The state_ksdrift and init_cd_ksdrift functions were updated in v0.10 to reflect the new preprocessing behavior, meaning legacy save/load within >=v0.10 versions works fine (hence the CI passes). However, old state_dicts generated with <v0.10 do not contain an x_ref_preprocessed kwarg, and the preprocess_at_init (still called preprocess_x_ref there) kwarg sits in other instead of kwargs (because it was to be updated after instantiation). The result is that detectors generated with old versions are loaded with the defaults x_ref_preprocessed=False and preprocess_at_init=True, meaning the already preprocessed x_ref is preprocessed again.
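
Concretely, with the sketch above in mind, loading an old state_dict ends up doing something like the following (hypothetical values, for illustration only):

import numpy as np
from alibi_detect.cd import KSDrift

x_ref = np.random.normal(size=(100, 10)).astype(np.float32)  # pretend this is already preprocessed

def preprocess_fn(x):
    # hypothetical preprocessing step (stand-in for e.g. a UAE model)
    return x * 2.0

# kwargs as reconstructed from an old (<v0.10) state_dict: no x_ref_preprocessed,
# and no preprocess_x_ref/preprocess_at_init (that one lived in `other`):
old_kwargs = {'p_val': 0.05, 'preprocess_fn': preprocess_fn}

# KSDrift falls back to its defaults (x_ref_preprocessed=False, preprocess_at_init=True),
# so preprocess_fn is applied to the already-preprocessed x_ref a second time.
cd = KSDrift(x_ref, **old_kwargs)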

Fix

The proposed fix is to first identify whether the state_dict was created by an old version, by checking that x_ref_preprocessed does not exist in its kwargs:

if 'x_ref_preprocessed' not in state_dict['kwargs']:  # if already exists then must have been saved w/ >=v0.10

If the state_dict is "old", then:

    state_dict['kwargs']['x_ref_preprocessed'] = True
    # Move `preprocess_x_ref` from `other` to `kwargs`
    state_dict['kwargs']['preprocess_x_ref'] = state_dict['other']['preprocess_x_ref']
  • x_ref_preprocessed is set to True (as is done in the updated state_ksdrift etc).
  • preprocess_x_ref is moved from other to kwargs so that it is passed as a kwarg at instantiation instead of the attribute being updated afterward. (The name change to preprocess_at_init doesn't matter since we use the decorator @deprecated_alias(preprocess_x_ref='preprocess_at_init') for all relevant detectors.)
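
Putting the check and the two updates together, the fix inside the legacy loading functions amounts to roughly the following (a sketch combining the snippets above; the surrounding init_cd_ksdrift etc. code is omitted):

if 'x_ref_preprocessed' not in state_dict['kwargs']:  # state_dict was saved with < v0.10
    # x_ref stored in a legacy save file has always already been preprocessed
    state_dict['kwargs']['x_ref_preprocessed'] = True
    # Move `preprocess_x_ref` from `other` to `kwargs` so it is passed at instantiation;
    # @deprecated_alias(preprocess_x_ref='preprocess_at_init') maps it onto the new kwarg
    state_dict['kwargs']['preprocess_x_ref'] = state_dict['other']['preprocess_x_ref']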

@ascillitoe ascillitoe requested review from jklaise and mauicv January 25, 2023 16:00

ascillitoe (Contributor, Author) commented:

The test script has been updated from #732 to actually run some sample predictions, because simply checking that loading succeeded missed this bug:

import numpy as np
import alibi_detect
import transformers
from alibi_detect.saving import load_detector
from alibi_detect.utils.fetching.fetching import fetch_detector, fetch_state_dict, fetch_tf_model
from alibi_detect.datasets import fetch_kdd
from alibi_detect.utils.data import create_outlier_batch
from alibi_detect.utils.url import _join_url
import tensorflow as tf
import pytest
import itertools
from cloudpathlib import CloudPath


TESTS = [
    # Outlier detectors
    {'detector_type': 'outlier', 'detector_name': 'IForest', 'dataset': ['kddcup']},
    {'detector_type': 'outlier', 'detector_name': 'LLR', 'dataset': ['fashion_mnist', 'genome']},
    {'detector_type': 'outlier', 'detector_name': 'Mahalanobis', 'dataset': ['kddcup']},
    {'detector_type': 'outlier', 'detector_name': 'OutlierAE', 'dataset': ['cifar10']},
    {'detector_type': 'outlier', 'detector_name': 'OutlierAEGMM', 'dataset': ['kddcup']},
    {'detector_type': 'outlier', 'detector_name': 'OutlierProphet', 'dataset': ['weather']},
    {'detector_type': 'outlier', 'detector_name': 'OutlierSeq2Seq', 'dataset': ['ecg']},
    {'detector_type': 'outlier', 'detector_name': 'OutlierVAE', 'dataset': ['adult', 'cifar10', 'kddcup']},
    {'detector_type': 'outlier', 'detector_name': 'OutlierVAEGMM', 'dataset': ['kddcup']},
    # Adversarial detectors
    {'detector_type': 'adversarial', 'detector_name': 'model_distillation', 'dataset': ['cifar10'], 'model': ['resnet32']},
    # Drift detectors (not supported by `fetch_detector`...)
    {'detector_type': 'drift', 'detector_name': 'ks', 'dataset': ['cifar10', 'imdb'], 'version': ['0.6.2']},
    {'detector_type': 'drift', 'detector_name': 'mmd', 'dataset': ['cifar10'], 'version': ['0.8.1']},
    {'detector_type': 'drift', 'detector_name': 'tabular', 'dataset': ['income'], 'version': ['0.7.0', '0.8.1']},
]


def dict_product(dicts):
    return (dict(zip(dicts, x)) for x in itertools.product(*dicts.values()))


trials = []
for test in TESTS:
    for k, v in test.items():
        if not isinstance(v, list):
            test[k] = [v]
    trials += list(dict_product(test))
n_tests = len(trials)


@pytest.fixture
def unpack_trials(request):
    return trials[request.param]


@pytest.mark.parametrize("unpack_trials", list(range(n_tests)), indirect=True)
def test_fetch_detector(unpack_trials, tmp_path):
    kwargs = unpack_trials
    print(kwargs)
    if kwargs['detector_type'] in ('outlier', 'adversarial'):
        det = fetch_detector(tmp_path, **kwargs)

    else:
        # create url of detector
        version = kwargs.get('version', '')
        version = version.replace('.', '_')
        url = 'gs://seldon-models/alibi-detect/' 
        url += 'cd/' + kwargs['detector_name'] + '/' + kwargs['dataset'] + '-' + version + '/'

        # Download bucket directory
        cloudpath = CloudPath(url)
        cloudpath.copytree(tmp_path)

        # Load detector
        det = load_detector(tmp_path)

    # Check loaded detector
    dataset = kwargs['dataset']
    detector_type = kwargs['detector_type']
    det_backend = det._detector if hasattr(det, '_detector') else det
    det_backend = det_backend._detector if hasattr(det_backend, '_detector') else det_backend

    X = None
    if dataset == 'imdb':
        if detector_type == 'drift':
            assert isinstance(det_backend.preprocess_fn.keywords['model'], alibi_detect.cd.tensorflow.preprocess.UAE)
            assert isinstance(det_backend.preprocess_fn.keywords['tokenizer'], transformers.PreTrainedTokenizerBase)
        X = ["This is one of the dumbest films, I've ever seen. It rips off nearly ever type of thriller and "
             "manages to make a mess of them all.<br /><br />There's not a single good line or character in the "
             "whole mess. If there was a plot, it was an afterthought and as far as acting goes, there's nothing "
             "good to say so Ill say nothing. I honestly cant understand how this type of nonsense gets produced "
             "and actually released, does somebody somewhere not at some stage think, 'Oh my god this really is a "
             "load of shite' and call it a day. Its crap like this that has people downloading illegally, the trailer "
             "looks like a completely different film, at least if you have download it, you haven't wasted your time "
             "or money Don't waste your time, this is painful."]

    elif dataset == 'cifar10':
        if detector_type == 'drift':
            assert isinstance(det_backend.preprocess_fn.keywords['model'], tf.keras.Model)
        X = np.random.uniform(size=(5, 32, 32, 3))  # dummy batch of (32,32,3) images

    if dataset == 'kddcup':
        kddcup = fetch_kdd(percent10=True)
        normal_batch = create_outlier_batch(kddcup.data, kddcup.target, n_samples=5, perc_outlier=0)
        X, y = normal_batch.data.astype('float'), normal_batch.target

    if X is not None:
        if hasattr(det, 'infer_threshold'):
            det.infer_threshold(X, threshold_perc=95)  # inferring on test data just for basic test...
        _ = det.predict(X)
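
For reference, assuming the script is saved as e.g. test_legacy_fetch.py (a name chosen here purely for illustration), it can be run with python -m pytest test_legacy_fetch.py -v; dict_product expands each entry of TESTS into one parametrized test case per combination of its list-valued fields.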

jklaise (Contributor) left a comment:

LGTM

mauicv (Collaborator) left a comment:

LGTM


@ascillitoe ascillitoe changed the title Set x_ref_preprocessed=True during legacy loading Set x_ref_preprocessed=True during legacy loading (v0.10.5 patch) Jan 25, 2023
@ascillitoe ascillitoe merged commit 0c3395f into SeldonIO:patch/v0.10.5 Jan 25, 2023
@ascillitoe ascillitoe deleted the fix/legacy_load_preproc branch January 25, 2023 17:19