Skip to content

Commit

Permalink
fix kmeans++ initialization, rename class to Kmeans (#81)
Browse files Browse the repository at this point in the history
* fix kmeans++

* changelog and bump version

* finalize and update tutorial

* changelog

* update large scale tutorial
  • Loading branch information
adamgayoso authored Feb 16, 2023
1 parent 4a04662 commit 84d3a77
Show file tree
Hide file tree
Showing 10 changed files with 393 additions and 186 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## 0.2.1 (2022-02-16)
## 0.3.0 (2023-02-16)

- Rename `KMeansJax` to `KMeans` and fix k-means++ initialization; use k-means as the default clustering in the benchmarker instead of Leiden ([#81][])
- Warn about joblib, add progress bar postfix str ([#80][])

[#81]: https://github.com/YosefLab/scib-metrics/pull/81
[#80]: https://github.com/YosefLab/scib-metrics/pull/80

## 0.2.0 (2022-02-02)
Expand Down
2 changes: 1 addition & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ scib_metrics.ilisi_knn(...)
utils.cdist
utils.pdist_squareform
utils.silhouette_samples
utils.KMeansJax
utils.KMeans
utils.pca
utils.principal_component_regression
utils.one_hot
Expand Down
116 changes: 57 additions & 59 deletions docs/notebooks/large_scale.ipynb

Large diffs are not rendered by default.

366 changes: 287 additions & 79 deletions docs/notebooks/lung_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = ["hatchling"]

[project]
name = "scib-metrics"
version = "0.2.1"
version = "0.3.0"
description = "Accelerated and Python-only scIB metrics"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/_nmi_ari.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score
from sklearn.utils import check_array

from .utils import KMeansJax, check_square
from .utils import KMeans, check_square

logger = logging.getLogger(__name__)


def _compute_clustering_kmeans(X: np.ndarray, n_clusters: int) -> np.ndarray:
kmeans = KMeansJax(n_clusters)
kmeans = KMeans(n_clusters)
kmeans.fit(X)
return kmeans.labels_

Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/benchmark/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class BioConservation:
"""

isolated_labels: MetricType = True
nmi_ari_cluster_labels_leiden: MetricType = True
nmi_ari_cluster_labels_kmeans: MetricType = False
nmi_ari_cluster_labels_leiden: MetricType = False
nmi_ari_cluster_labels_kmeans: MetricType = True
silhouette_label: MetricType = True
clisi_knn: MetricType = True

Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._diffusion_nn import diffusion_nn
from ._dist import cdist, pdist_squareform
from ._kmeans import KMeansJax
from ._kmeans import KMeans
from ._lisi import compute_simpson_index
from ._pca import pca
from ._pcr import principal_component_regression
Expand All @@ -12,7 +12,7 @@
"cdist",
"pdist_squareform",
"get_ndarray",
"KMeansJax",
"KMeans",
"pca",
"principal_component_regression",
"one_hot",
Expand Down
75 changes: 37 additions & 38 deletions src/scib_metrics/utils/_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from sklearn.utils import check_array

from .._types import IntOrKey
from ._dist import cdist, pdist_squareform
from ._dist import cdist
from ._utils import get_ndarray, validate_seed


def _initialize_random(X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key: jax.random.KeyArray) -> jnp.ndarray:
def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
"""Initialize cluster centroids randomly."""
n_obs = X.shape[0]
indices = jax.random.choice(key, n_obs, (n_clusters,), replace=False)
Expand All @@ -20,38 +20,39 @@ def _initialize_random(X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key


@partial(jax.jit, static_argnums=1)
def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
    """Initialize cluster centroids with the (greedy) k-means++ algorithm.

    The first centroid is drawn uniformly at random from the observations.
    Each following centroid is picked among ``n_local_trials`` candidates that
    are sampled with probability proportional to their squared distance to the
    closest centroid chosen so far; the candidate that minimizes the summed
    squared distance (the "potential") is kept.

    Parameters
    ----------
    X
        Data matrix with observations as rows (``X.shape[0]`` observations).
    n_clusters
        Number of centroids to return. Static under ``jax.jit`` because it
        determines the scan length and the number of local trials.
    key
        PRNG key used for all sampling.

    Returns
    -------
    Array with ``n_clusters`` rows, one centroid (a row of ``X``) per row.
    The uniformly sampled first centroid is the first row.
    """
    n_obs = X.shape[0]
    key, subkey = jax.random.split(key)
    # Sample the first centroid uniformly at random.
    initial_centroid_idx = jax.random.choice(subkey, n_obs, (1,), replace=False)
    initial_centroid = X[initial_centroid_idx].ravel()
    # NOTE(review): assumes `cdist` returns non-squared Euclidean distances,
    # hence the explicit squaring -- confirm against `._dist.cdist`.
    dist_sq = jnp.square(cdist(initial_centroid[jnp.newaxis, :], X)).ravel()
    initial_state = {"min_dist_sq": dist_sq, "centroid": initial_centroid, "key": key}
    # Same number of local trials as scikit-learn's greedy k-means++.
    n_local_trials = 2 + int(np.log(n_clusters))

    def _step(state, _):
        """Select one additional centroid and update the running minimum distances."""
        prob = state["min_dist_sq"] / jnp.sum(state["min_dist_sq"])
        # Observations already chosen as centers have zero distance to
        # themselves, hence zero probability, and are not chosen again.
        next_key, subkey = jax.random.split(state["key"])
        candidate_idx = jax.random.choice(subkey, n_obs, (n_local_trials,), replace=False, p=prob)
        candidates = X[candidate_idx]
        # Shape: candidates by observations.
        dist_sq_candidates = jnp.square(cdist(candidates, X))
        dist_sq_candidates = jnp.minimum(state["min_dist_sq"][jnp.newaxis, :], dist_sq_candidates)
        candidates_pot = dist_sq_candidates.sum(axis=1)

        # Keep the candidate that yields the lowest potential.
        best = jnp.argmin(candidates_pot)
        centroid = X[candidate_idx[best]].ravel()
        # Build a fresh carry rather than mutating the incoming one.
        new_state = {
            "min_dist_sq": dist_sq_candidates[best].ravel(),
            "centroid": centroid,
            "key": next_key,
        }
        return new_state, centroid

    # The scan only stacks the centroids produced by its `n_clusters - 1`
    # steps, so the first centroid must be prepended explicitly; otherwise only
    # `n_clusters - 1` centroids would be returned.
    _, next_centroids = jax.lax.scan(_step, initial_state, jnp.arange(n_clusters - 1))
    return jnp.concatenate([initial_centroid[jnp.newaxis, :], next_centroids], axis=0)


@jax.jit
Expand All @@ -62,7 +63,7 @@ def _get_dist_labels(X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
return dist, labels


class KMeansJax:
class KMeans:
"""Jax implementation of :class:`sklearn.cluster.KMeans`.
This implementation is limited to Euclidean distance.
Expand Down Expand Up @@ -91,7 +92,7 @@ class KMeansJax:
def __init__(
self,
n_clusters: int = 8,
init: Literal["k-means++", "random"] = "random",
init: Literal["k-means++", "random"] = "k-means++",
n_init: int = 10,
max_iter: int = 300,
tol: float = 1e-4,
Expand Down Expand Up @@ -122,7 +123,6 @@ def fit(self, X: np.ndarray):
return self

def _fit(self, X: np.ndarray):
self._pdists = pdist_squareform(X)
all_centroids, all_inertias = jax.lax.map(
lambda key: self._kmeans_full_run(X, key), jax.random.split(self.seed, self.n_init)
)
Expand All @@ -131,7 +131,6 @@ def _fit(self, X: np.ndarray):
self.inertia_ = get_ndarray(all_inertias[i])
_, labels = _get_dist_labels(X, self.cluster_centroids_)
self.labels_ = get_ndarray(labels)
del self._pdists

@partial(jax.jit, static_argnums=(0,))
def _kmeans_full_run(self, X: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
Expand Down Expand Up @@ -169,7 +168,7 @@ def _kmeans_convergence(state):
cond2 = n_iter > self.max_iter
return jnp.logical_or(cond1, cond2)[0]

centroids = self._initialize(X, self.n_clusters, self._pdists, key)
centroids = self._initialize(X, self.n_clusters, key)
# centroids, new_inertia, old_inertia, n_iter
state = (centroids, jnp.inf, jnp.inf, jnp.array([0.0]))
state = _kmeans_step(state)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_isolated_labels():

def test_kmeans():
X, _ = dummy_x_labels()
kmeans = scib_metrics.utils.KMeansJax(2)
kmeans = scib_metrics.utils.KMeans(2)
kmeans.fit(X)
assert kmeans.labels_.shape == (X.shape[0],)

Expand Down

0 comments on commit 84d3a77

Please sign in to comment.