Skip to content

Commit

Permalink
Add the faiss-based KNN model and custom metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
bojan-karlas committed Sep 28, 2024
1 parent f63bb45 commit f91b050
Showing 1 changed file with 99 additions and 7 deletions.
106 changes: 99 additions & 7 deletions experiments/datascope/experiments/pipelines/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import faiss
import numpy as np
import tempfile
import torch
Expand Down Expand Up @@ -91,6 +92,36 @@ def predict_proba(self, X: Union[NDArray, DataFrame]) -> NDArray:
return self.model.predict_proba(X)


class FaissKNN(BaseEstimator, ClassifierMixin):
    """K-nearest-neighbors classifier backed by an exact faiss L2 index.

    Labels are encoded with a ``LabelEncoder`` at fit time. ``predict``
    returns labels in the original label space and ``predict_proba`` returns
    one column per entry of ``classes_`` (the encoder's sorted class order).

    Args:
        n_neighbors: Number of nearest training samples that vote on each
            query. When fewer samples are indexed, all of them vote.
    """

    def __init__(self, n_neighbors=5):
        self.n_neighbors = n_neighbors
        self.index = None
        self.y = None
        self.label_encoder = LabelEncoder()

    @staticmethod
    def _as_float32(X):
        """Return ``X`` as a C-contiguous float32 ndarray, the layout faiss expects."""
        return np.ascontiguousarray(X, dtype=np.float32)

    def fit(self, X, y):
        """Index the training samples and remember their encoded labels.

        Args:
            X: Two-dimensional array-like of training features.
            y: Array-like of training labels, one per row of ``X``.

        Returns:
            This estimator, to allow call chaining.
        """
        self.X = self._as_float32(X)
        self.y = self.label_encoder.fit_transform(y)
        # Expose the class labels under the attribute name scikit-learn tooling looks for.
        self.classes_ = self.label_encoder.classes_
        self.index = faiss.IndexFlatL2(self.X.shape[1])  # exact L2 distance index
        self.index.add(self.X)
        return self

    def _vote_counts(self, X):
        """Count, per query row, how many of its nearest neighbors fall in each class.

        Returns:
            A tuple ``(counts, k)`` where ``counts`` has shape
            ``(n_queries, n_classes)`` and ``k`` is the number of neighbors
            that actually voted.

        Raises:
            ValueError: If the index holds no training samples.
        """
        queries = self._as_float32(X)
        # faiss pads the result with index -1 when asked for more neighbors
        # than it stores; that would silently select the last training label,
        # so never request more neighbors than are indexed.
        k = int(min(self.n_neighbors, self.index.ntotal))
        if k < 1:
            raise ValueError("FaissKNN needs at least one training sample and n_neighbors >= 1.")
        _, indices = self.index.search(queries, k)
        votes = self.y[indices]  # shape (n_queries, k), encoded labels
        n_queries = votes.shape[0]
        n_classes = len(self.label_encoder.classes_)
        # Shift every row into its own range of bins so that a single
        # bincount call tallies all rows at once.
        offsets = np.arange(n_queries)[:, np.newaxis] * n_classes
        counts = np.bincount((votes + offsets).ravel(), minlength=n_queries * n_classes)
        return counts.reshape(n_queries, n_classes), k

    def predict(self, X):
        """Predict the majority label among the nearest neighbors of each row.

        Ties are resolved in favor of the class that comes first in
        ``classes_``.

        Returns:
            Array of predicted labels, expressed in the original label space
            passed to ``fit`` (not as encoder indices).
        """
        counts, _ = self._vote_counts(X)
        return self.label_encoder.inverse_transform(counts.argmax(axis=1))

    def predict_proba(self, X):
        """Return the fraction of neighbor votes received by each class.

        Returns:
            Array of shape ``(n_queries, n_classes)`` whose rows sum to 1 and
            whose columns follow the order of ``classes_``.
        """
        counts, k = self._vote_counts(X)
        return counts / k


class EvalLoggerCallback(TrainerCallback):
def __init__(
self,
Expand Down Expand Up @@ -457,44 +488,105 @@ def construct(self: "RandomForestModel", dataset: Dataset) -> BaseEstimator:


class KNearestNeighborsModel(BaseModel, id="knn", longname="K-Nearest Neighbors"):
    """Model definition that builds a scikit-learn ``KNeighborsClassifier``.

    Args:
        num_neighbors: Number of neighbors handed to the classifier.
        metric: Name of the distance metric handed to the classifier.
        **kwargs: Accepted but not used by this initializer.
    """

    def __init__(self, num_neighbors: int = 1, metric: str = "minkowski", **kwargs) -> None:
        self._num_neighbors = num_neighbors
        self._metric = metric

    @attribute
    def num_neighbors(self) -> int:
        """Number of neighbors to use."""
        return self._num_neighbors

    @attribute
    def metric(self) -> str:
        """The distance metric to use."""
        return self._metric

    def construct(self: "KNearestNeighborsModel", dataset: Dataset) -> BaseEstimator:
        # Both settings are forwarded; ``dataset`` is not consulted here.
        return KNeighborsClassifier(n_neighbors=self.num_neighbors, metric=self.metric)


class KNearestNeighborsModelK1(KNearestNeighborsModel, id="knn-1", longname="K-Nearest Neighbors (K=1)"):
    """K-nearest-neighbors model with the neighbor count pinned to 1."""

    def __init__(self, metric: str = "minkowski", **kwargs) -> None:
        # Only the metric is configurable; extra keyword arguments are
        # accepted but not forwarded to the parent initializer.
        super().__init__(metric=metric, num_neighbors=1)


class KNearestNeighborsModelK3(KNearestNeighborsModel, id="knn-3", longname="K-Nearest Neighbors (K=3)"):
    """K-nearest-neighbors model with the neighbor count pinned to 3."""

    def __init__(self, metric: str = "minkowski", **kwargs) -> None:
        # Only the metric is configurable; extra keyword arguments are
        # accepted but not forwarded to the parent initializer.
        super().__init__(metric=metric, num_neighbors=3)


class KNearestNeighborsModelK5(KNearestNeighborsModel, id="knn-5", longname="K-Nearest Neighbors (K=5)"):
    """K-nearest-neighbors model with the neighbor count pinned to 5."""

    def __init__(self, metric: str = "minkowski", **kwargs) -> None:
        # Only the metric is configurable; extra keyword arguments are
        # accepted but not forwarded to the parent initializer.
        super().__init__(metric=metric, num_neighbors=5)


class KNearestNeighborsModelK10(KNearestNeighborsModel, id="knn-10", longname="K-Nearest Neighbors (K=10)"):
    """K-nearest-neighbors model with the neighbor count pinned to 10."""

    def __init__(self, metric: str = "minkowski", **kwargs) -> None:
        # Only the metric is configurable; extra keyword arguments are
        # accepted but not forwarded to the parent initializer.
        super().__init__(metric=metric, num_neighbors=10)


class KNearestNeighborsModelK50(KNearestNeighborsModel, id="knn-50", longname="K-Nearest Neighbors (K=50)"):
    """K-nearest-neighbors model with the neighbor count pinned to 50."""

    def __init__(self, metric: str = "minkowski", **kwargs) -> None:
        # Only the metric is configurable; extra keyword arguments are
        # accepted but not forwarded to the parent initializer.
        super().__init__(metric=metric, num_neighbors=50)


class KNearestNeighborsModelK100(KNearestNeighborsModel, id="knn-100", longname="K-Nearest Neighbors (K=100)"):
    """K-nearest-neighbors model with the neighbor count pinned to 100."""

    def __init__(self, metric: str = "minkowski", **kwargs) -> None:
        # Only the metric is configurable; extra keyword arguments are
        # accepted but not forwarded to the parent initializer.
        super().__init__(metric=metric, num_neighbors=100)


class FastKNearestNeighborsModel(BaseModel, id="fast-knn", longname="Fast K-Nearest Neighbors"):
    """Model definition that builds the faiss-backed ``FaissKNN`` classifier.

    Args:
        num_neighbors: Number of neighbors handed to the classifier.
        **kwargs: Accepted but not used by this initializer.
    """

    def __init__(self, num_neighbors: int = 1, **kwargs) -> None:
        self._num_neighbors = num_neighbors

    @attribute
    def num_neighbors(self) -> int:
        """Number of neighbors to use."""
        return self._num_neighbors

    def construct(self: "FastKNearestNeighborsModel", dataset: Dataset) -> BaseEstimator:
        # ``dataset`` is not consulted; only the neighbor count is forwarded.
        classifier = FaissKNN(n_neighbors=self.num_neighbors)
        return classifier


class FastKNearestNeighborsModelK1(
    FastKNearestNeighborsModel, id="fast-knn-1", longname="Fast K-Nearest Neighbors (K=1)"
):
    """Faiss-backed K-nearest-neighbors model with the neighbor count pinned to 1."""

    def __init__(self, **kwargs) -> None:
        # Keyword arguments are accepted but not forwarded to the parent initializer.
        neighbor_count = 1
        super().__init__(num_neighbors=neighbor_count)


class FastKNearestNeighborsModelK3(
    FastKNearestNeighborsModel, id="fast-knn-3", longname="Fast K-Nearest Neighbors (K=3)"
):
    """Faiss-backed K-nearest-neighbors model with the neighbor count pinned to 3."""

    def __init__(self, **kwargs) -> None:
        # Keyword arguments are accepted but not forwarded to the parent initializer.
        super().__init__(num_neighbors=3)


class FastKNearestNeighborsModelK5(
    FastKNearestNeighborsModel, id="fast-knn-5", longname="Fast K-Nearest Neighbors (K=5)"
):
    """Faiss-backed K-nearest-neighbors model with the neighbor count pinned to 5."""

    def __init__(self, **kwargs) -> None:
        # Keyword arguments are accepted but not forwarded to the parent initializer.
        super().__init__(num_neighbors=5)


class FastKNearestNeighborsModelK10(
    FastKNearestNeighborsModel, id="fast-knn-10", longname="Fast K-Nearest Neighbors (K=10)"
):
    """Faiss-backed K-nearest-neighbors model with the neighbor count pinned to 10."""

    def __init__(self, **kwargs) -> None:
        # Keyword arguments are accepted but not forwarded to the parent initializer.
        super().__init__(num_neighbors=10)


class FastKNearestNeighborsModelK50(
    FastKNearestNeighborsModel, id="fast-knn-50", longname="Fast K-Nearest Neighbors (K=50)"
):
    """Faiss-backed K-nearest-neighbors model with the neighbor count pinned to 50."""

    def __init__(self, **kwargs) -> None:
        # Keyword arguments are accepted but not forwarded to the parent initializer.
        super().__init__(num_neighbors=50)


class FastKNearestNeighborsModelK100(
    FastKNearestNeighborsModel, id="fast-knn-100", longname="Fast K-Nearest Neighbors (K=100)"
):
    """Faiss-backed K-nearest-neighbors model with the neighbor count pinned to 100."""

    # The base class must be FastKNearestNeighborsModel: inheriting from
    # KNearestNeighborsModel would make ``construct`` build the scikit-learn
    # classifier instead of FaissKNN, unlike every other ``fast-knn-*`` preset.
    def __init__(self, **kwargs) -> None:
        # Keyword arguments are accepted but not forwarded to the parent initializer.
        super().__init__(num_neighbors=100)

Expand Down

0 comments on commit f91b050

Please sign in to comment.