-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add an almost working k-means cluster sample viz
* take back everything I said about other people's stream-of-consciousness Streamlit code; the framework makes that mode hard to avoid!
- Loading branch information
Showing
1 changed file
with
57 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,73 @@ | ||
from sklearn.cluster import KMeans | ||
import streamlit as st | ||
from typing import Optional | ||
from ..visualisation_app import image_embeddings | ||
from cyto_ml.visualisation.visualisation_app import ( | ||
image_embeddings, | ||
image_ids, | ||
cached_image, | ||
) | ||
|
||
|
||
@st.cache_resource
def _fit_kmeans(n_clusters: int) -> KMeans:
    """
    Fit a K-means model with `n_clusters` clusters on the "plankton" embeddings.

    The cluster count is an explicit argument so that it is part of the
    Streamlit cache key: each distinct count gets its own fitted model rather
    than every caller receiving whichever model happened to be fitted first.
    """
    # Debug trace: only printed when the model is actually (re)fitted,
    # i.e. on a cache miss.
    print("model")
    X = image_embeddings("plankton")
    # Initialize and fit KMeans; the fixed random_state keeps the cluster
    # assignments reproducible between fits.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(X)
    return kmeans


def kmeans_cluster() -> KMeans:
    """
    K-means cluster the embeddings.

    The number of clusters is read from ``st.session_state["n_clusters"]``,
    which must already be set. The session-state read happens outside the
    cached helper so that changing the value yields a model fitted with the
    new count instead of a stale cached one.

    Returns:
        The fitted KMeans model.

    Raises:
        KeyError: if "n_clusters" has not been stored in the session state.
    """
    return _fit_kmeans(st.session_state["n_clusters"])
|
||
|
||
@st.cache_data
def _labels_for(n_clusters: int) -> dict:
    """
    Group the "plankton" image ids by their K-means cluster label.

    `n_clusters` is not used in the body; it exists so the cached result is
    keyed on the cluster count and is recomputed when that count changes.
    """
    km = kmeans_cluster()
    # One (initially empty) bucket per label the model assigned; plain ints
    # as keys so lookups with an ordinary Python int are unambiguous.
    clusters = {int(label): [] for label in set(km.labels_)}

    # labels_[i] is looked up by position for the i-th image id.
    for index, image_id in enumerate(image_ids("plankton")):
        clusters[int(km.labels_[index])].append(image_id)
    return clusters


def image_labels() -> dict:
    """
    Map each cluster label to the list of image ids assigned to it.

    TODO good form to move all this into cyto_ml, call from there

    The cluster count is read from ``st.session_state["n_clusters"]`` here,
    outside the cached helper, so the grouping follows the current setting.

    Returns:
        dict of {cluster label (int): [image id, ...]}.

    Raises:
        KeyError: if "n_clusters" has not been stored in the session state.
    """
    return _labels_for(st.session_state["n_clusters"])
|
||
|
||
def show_cluster():
    """
    Draw a sample of images from the currently selected cluster.

    The cluster is taken from ``st.session_state["cluster"]``. At most
    8 x 8 = 64 images are shown, laid out in rows of 8 columns; clusters with
    fewer images simply produce a shorter grid.
    """
    grid_size = 8

    # TODO n_clusters configurable with selector
    fitted = image_labels()
    # .get() so that a label with no members renders nothing instead of raising
    members = fitted.get(st.session_state["cluster"], [])

    # Take the sample from the end of the list, last id first. This is the
    # order the earlier pop()-based loop displayed, but it neither mutates the
    # list nor fails when the cluster holds fewer than 64 images.
    sample = members[-(grid_size * grid_size):][::-1]

    # One row of columns per chunk of `grid_size` images, created only as needed
    for start in range(0, len(sample), grid_size):
        row = st.columns(grid_size)
        for column, image_id in zip(row, sample[start:start + grid_size]):
            column.image(cached_image(image_id), width=60)
|
||
|
||
# TODO some visualisation, actual content, etc | ||
def main() -> None:
    """
    Page entry point: seed session defaults, draw the cluster selector,
    then render the image grid for the selected cluster.
    """
    # First run of a session: choose starting values for both settings
    if "cluster" not in st.session_state:
        st.session_state["cluster"] = 1
    if "n_clusters" not in st.session_state:
        st.session_state["n_clusters"] = 5

    # The widget stores its value under key="cluster" and a change triggers a
    # rerun of this script, so no on_change callback is needed. Registering
    # show_cluster as a callback as well as calling it below drew the grid
    # twice on every selection change.
    st.selectbox(
        "cluster",
        list(range(st.session_state["n_clusters"])),
        key="cluster",
    )

    show_cluster()
|
||
|
||
# Run the page when executed as a script (e.g. via `streamlit run`)
if __name__ == "__main__":
    main()