Skip to content

Commit

Permalink
Adapt structure
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Apr 23, 2024
1 parent 96e3717 commit 392be82
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 31 deletions.
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "fknni"
version = "0.0.1"
description = "Fast implementations of KNN imputation."
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Lukas Heumos"},
Expand All @@ -19,6 +19,7 @@ urls.Documentation = "https://fknni.readthedocs.io/"
urls.Source = "https://github.com/zethson/fknni"
urls.Home-page = "https://github.com/zethson/fknni"
dependencies = [
"scikit-learn",
"faiss-cpu",
]

Expand All @@ -39,11 +40,11 @@ doc = [
"ipykernel",
"ipython",
"sphinx-copybutton",
"pandas",
]
test = [
"pytest",
"coverage",
"pandas",
]

[tool.coverage.run]
Expand Down Expand Up @@ -96,7 +97,7 @@ ignore = [
"D107",
# Errors from function calls in argument defaults. These are fine when the result is immutable.
"B008",
# __magic__ methods are are often self-explanatory, allow missing docstrings
# __magic__ methods are often self-explanatory, allow missing docstrings
"D105",
# first line should end with a period [Bug: doesn't work with single-line docstrings]
"D400",
Expand All @@ -110,7 +111,7 @@ ignore = [
]

[tool.ruff.lint.pydocstyle]
convention = "numpy"
convention = "google"

[tool.ruff.lint.per-file-ignores]
"docs/*" = ["I"]
Expand Down
2 changes: 1 addition & 1 deletion src/fknni/faiss/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .faiss import basic_tool
from .faiss import FaissImputer
115 changes: 101 additions & 14 deletions src/fknni/faiss/faiss.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,101 @@
def basic_tool() -> int:
    """Placeholder tool that prints a reminder message.

    The function takes no arguments; it only prints a notice that a real tool
    still has to be implemented.

    Returns:
        Always ``0``.
    """
    print("Implement a tool to run on the AnnData object.")
    return 0
from typing import Literal

import faiss
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted


class FaissImputer(BaseEstimator, TransformerMixin):
    """Imputer for completing missing values using Faiss.

    Rows of the training data that contain no missing values are added to a Faiss
    index. At transform time every incomplete row is temporarily completed with
    per-column placeholder statistics, its nearest indexed rows are looked up, and
    each missing entry is replaced by the mean or median of the neighbors' values
    in the corresponding column.
    """

    def __init__(
        self,
        n_neighbors: int = 3,
        metric: Literal["l2", "ip"] = "l2",
        strategy: Literal["mean", "median"] = "mean",
        index_factory: str = "Flat",
    ):
        """Initializes FaissImputer with specified parameters.

        Args:
            n_neighbors: Number of neighbors to use for imputation.
            metric: Distance metric to use for neighbor search.
            strategy: Method to compute imputed values.
            index_factory: Description of the Faiss index type to build.
        """
        super().__init__()
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.strategy = strategy
        self.index_factory = index_factory

    def fit(self, X: np.ndarray | pd.DataFrame, y: np.ndarray | None = None) -> "FaissImputer":
        """Fits the FaissImputer to the provided data.

        Args:
            X: Input data with potential missing values. Expected to be either a NumPy array or a pandas DataFrame.
            y: Ignored, present for compatibility with sklearn's TransformerMixin.
                Accepted positionally or by keyword so that ``fit_transform(X, y)`` and pipelines work.

        Raises:
            ValueError: If any parameters are set to an invalid value, or if ``X`` has no row
                without missing values (nothing could be indexed).

        Returns:
            self: Instance with fitted data.
        """
        # NOTE(review): newer scikit-learn releases rename `force_all_finite` to
        # `ensure_all_finite`; revisit once the minimum supported version allows it.
        X = check_array(X, dtype=np.float32, force_all_finite="allow-nan")

        # NumPy integers (e.g. coming from a parameter grid) are as valid as builtin ints.
        if not isinstance(self.n_neighbors, (int, np.integer)) or self.n_neighbors <= 0:
            raise ValueError("n_neighbors must be a positive integer")
        if self.metric not in {"l2", "ip"}:
            raise ValueError("metric must be either 'l2' or 'ip'")
        if self.strategy not in {"mean", "median"}:
            raise ValueError("strategy must be either 'mean' or 'median'")

        # Only rows without any NaN can be indexed and later serve as donors.
        complete_rows = ~np.isnan(X).any(axis=1)
        X_non_missing = np.ascontiguousarray(X[complete_rows])
        if X_non_missing.shape[0] == 0:
            raise ValueError("X must contain at least one row without missing values")

        index = faiss.index_factory(
            X_non_missing.shape[1],
            self.index_factory,
            faiss.METRIC_L2 if self.metric == "l2" else faiss.METRIC_INNER_PRODUCT,
        )
        index.train(X_non_missing)
        index.add(X_non_missing)
        self.index_ = index
        # Ids returned by the index are positions in this array, so it must be kept
        # around to read the donor values during `transform`.
        self.X_full_ = X_non_missing

        return self

    def transform(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
        """Imputes missing values in the data using the fitted Faiss index.

        Args:
            X: Data with missing values to impute. Expected to be either a NumPy array or a pandas DataFrame.

        Raises:
            ValueError: If the number of columns of ``X`` differs from the fitted data.

        Returns:
            X_imputed: Data with imputed values as a NumPy array.
        """
        check_is_fitted(self, ["index_", "X_full_"])
        X = check_array(X, dtype=np.float32, force_all_finite="allow-nan")
        if X.shape[1] != self.X_full_.shape[1]:
            raise ValueError(
                f"X has {X.shape[1]} features, but FaissImputer was fitted with {self.X_full_.shape[1]} features"
            )

        X_imputed = np.array(X, copy=True)
        missing_mask = np.isnan(X_imputed)
        incomplete_rows = np.flatnonzero(missing_mask.any(axis=1))
        if incomplete_rows.size == 0:
            return X_imputed

        use_mean = self.strategy == "mean"
        reduce = np.mean if use_mean else np.median
        nan_reduce = np.nanmean if use_mean else np.nanmedian

        # Placeholders complete the query vectors. They come from the observed values of
        # X; columns that are entirely missing in X fall back to the fitted statistic so
        # that a query never contains NaN.
        placeholder_values = np.asarray(reduce(self.X_full_, axis=0), dtype=np.float32)
        observed_cols = ~missing_mask.all(axis=0)
        placeholder_values[observed_cols] = nan_reduce(X_imputed[:, observed_cols], axis=0)

        # Queries are independent of each other, so all incomplete rows are searched at once.
        queries = np.ascontiguousarray(
            np.where(missing_mask[incomplete_rows], placeholder_values, X_imputed[incomplete_rows]),
            dtype=np.float32,
        )
        _, neighbor_indices = self.index_.search(queries, int(self.n_neighbors))

        for row_idx, neighbors in zip(incomplete_rows, neighbor_indices):
            missing_cols = np.flatnonzero(missing_mask[row_idx])
            # Faiss pads the result with -1 when fewer than k neighbors are available.
            donors = neighbors[neighbors >= 0]
            if donors.size == 0:
                X_imputed[row_idx, missing_cols] = placeholder_values[missing_cols]
                continue
            # Neighbor ids refer to the rows indexed during `fit`, not to rows of X.
            donor_values = self.X_full_[donors][:, missing_cols]
            X_imputed[row_idx, missing_cols] = reduce(donor_values, axis=0)

        return X_imputed
12 changes: 0 additions & 12 deletions tests/test_basic.py

This file was deleted.

37 changes: 37 additions & 0 deletions tests/test_faiss_imputation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
import pandas as pd
import pytest

from fknni.faiss.faiss import FaissImputer


@pytest.fixture
def simple_test_df():
    """Provide a small random dataset together with a copy that has missing values.

    Returns:
        A tuple ``(data, data_missing)`` of NumPy arrays with shape ``(10, 5)``.
        ``data`` holds the complete integer values; ``data_missing`` is a float
        copy in which five randomly chosen entries are NaN.
    """
    rng = np.random.default_rng(0)
    data = pd.DataFrame(rng.integers(0, 100, size=(10, 5)), columns=list("ABCDE"))
    # Integer columns cannot hold NaN; cast up front instead of relying on pandas
    # silently upcasting on assignment.
    data_missing = data.astype(float)
    indices = [(i, j) for i in range(data.shape[0]) for j in range(data.shape[1])]
    rng.shuffle(indices)
    for i, j in indices[:5]:  # Making 5 entries NaN
        data_missing.iat[i, j] = np.nan

    return data.to_numpy(), data_missing.to_numpy()


def test_median_imputation(simple_test_df):
    """Median-strategy imputation must not leave any NaN in the output."""
    _, incomplete = simple_test_df
    fitted = FaissImputer(n_neighbors=5, strategy="median").fit(incomplete)

    result = fitted.transform(incomplete)

    has_nan = np.isnan(result).any()
    assert not has_nan, "NaNs remain after median imputation"


def test_imputer_with_no_missing_values(simple_test_df):
    """Fitting and transforming complete data must return the values unchanged."""
    complete, _ = simple_test_df
    model = FaissImputer(strategy="median", n_neighbors=5)

    result = model.fit(complete).transform(complete)

    np.testing.assert_array_equal(
        complete,
        result,
        err_msg="Imputer altered data without missing values",
    )

0 comments on commit 392be82

Please sign in to comment.