Skip to content

Commit

Permalink
Add missing_values parameter
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Apr 25, 2024
1 parent 4505b7e commit 6c863a4
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/fknni/faiss/faiss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Union

import faiss
import numpy as np
Expand All @@ -12,14 +12,17 @@ class FaissImputer(BaseEstimator, TransformerMixin):

def __init__(
self,
missing_values: Union[int, float, str, None] = np.nan,
n_neighbors: int = 5,
*,
metric: Literal["l2", "ip"] = "l2",
strategy: Literal["mean", "median", "weighted"] = "mean",
index_factory: str = "Flat",
):
"""Initializes FaissImputer with specified parameters that are used for the imputation.
Args:
missing_values: The missing value to impute. Defaults to np.nan.
n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
metric: Distance metric to use for neighbor search. Defaults to 'l2'.
strategy: Method to compute imputed values among neighbors.
Expand All @@ -28,6 +31,7 @@ def __init__(
index_factory: Description of the Faiss index type to build. Defaults to 'Flat'.
"""
super().__init__()
self.missing_values = missing_values
self.n_neighbors = n_neighbors
self.metric = metric
self.strategy = strategy
Expand All @@ -44,6 +48,11 @@ def fit(self, X: np.ndarray | pd.DataFrame, *, y: np.ndarray | None = None) -> "
ValueError: If any parameters are set to an invalid value.
"""
X = np.asarray(X, dtype=np.float32)
if isinstance(X, pd.DataFrame):
X = X.replace(self.missing_values, np.nan).values
else:
X = np.where(X == self.missing_values, np.nan, X)

if np.isnan(X).all(axis=0).any():
raise ValueError("Features with all values missing cannot be handled.")

Expand Down

0 comments on commit 6c863a4

Please sign in to comment.