-
Notifications
You must be signed in to change notification settings - Fork 834
/
Copy pathcd_model.py
85 lines (76 loc) · 2.36 KB
/
cd_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
from typing import List, Dict, Optional
import logging
import numpy as np
from .numpy_encoder import NumpyEncoder
from alibi_detect.utils.saving import load_detector, Data
from adserver.base import AlibiDetectModel
class AlibiDetectConceptDriftModel(
    AlibiDetectModel
):  # pylint:disable=c-extension-no-member
    """
    Concept drift model that buffers incoming events into a batch.

    Each call to ``process_event`` appends its inputs to an internal batch.
    Once the batch holds at least ``drift_batch_size`` rows, the wrapped
    detector is run on the whole batch and the batch is reset.
    """

    def __init__(
        self,
        name: str,
        storage_uri: str,
        model: Optional[Data] = None,
        drift_batch_size: int = 1000,
    ):
        """
        Outlier Detection / Concept Drift Model

        Parameters
        ----------
        name
             The name of the model
        storage_uri
             The URI location of the model
        model
             Alibi detect model
        drift_batch_size
             The batch size to fill before checking for drift
        """
        super().__init__(name, storage_uri, model)
        self.drift_batch_size = drift_batch_size
        # Rows accumulated since the last drift check; None means "empty".
        self.batch: Optional[np.ndarray] = None
        self.model: Optional[Data] = model

    def process_event(self, inputs: List, headers: Dict) -> Optional[Dict]:
        """
        Accumulate the event and, once the batch is full, run drift detection.

        Parameters
        ----------
        inputs
             Input data; converted to a NumPy array and stacked row-wise
             onto the current batch
        headers
             Header options (only logged here)

        Returns
        -------
             The detector's predictions after a JSON round-trip through
             ``NumpyEncoder`` when the batch has reached
             ``drift_batch_size`` rows, otherwise ``None``.

        Raises
        ------
        Exception
             If ``inputs`` cannot be converted to a NumPy array.
        """
        logging.info("PROCESSING EVENT.")
        logging.info(str(headers))
        logging.info("----")

        try:
            X = np.array(inputs)
        except Exception as e:
            # Keep the original exception type/message for callers, but chain
            # the cause so the underlying NumPy error is not lost.
            raise Exception(
                "Failed to initialize NumPy array from inputs: %s, %s" % (e, inputs)
            ) from e

        if self.batch is None:
            self.batch = X
        else:
            self.batch = np.vstack((self.batch, X))

        if self.batch.shape[0] >= self.drift_batch_size:
            logging.info(
                "Running drift detection. Batch size is %d. Needed %d",
                self.batch.shape[0],
                self.drift_batch_size,
            )
            # Run the detector on the full accumulated batch, not just the
            # inputs of the latest event; otherwise the buffering is pointless.
            cd_preds = self.model.predict(self.batch)
            self.batch = None
            # Round-trip through JSON so the result contains only plain
            # JSON-compatible Python values.
            return json.loads(json.dumps(cd_preds, cls=NumpyEncoder))

        logging.info(
            "Not running drift detection. Batch size is %d. Need %d",
            self.batch.shape[0],
            self.drift_batch_size,
        )
        return None