diff --git a/applications/contrastive_phenotyping/demo_fit.py b/applications/contrastive_phenotyping/demo_fit.py new file mode 100644 index 00000000..aba11397 --- /dev/null +++ b/applications/contrastive_phenotyping/demo_fit.py @@ -0,0 +1,45 @@ +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.callbacks import DeviceStatsMonitor + + +from viscy.data.triplet import TripletDataModule +from viscy.light.engine import ContrastiveModule + + +def main(): + dm = TripletDataModule( + data_path="/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr", + tracks_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr", + source_channel=["Phase3D", "RFP"], + z_range=(20, 35), + batch_size=16, + num_workers=10, + initial_yx_patch_size=(384, 384), + final_yx_patch_size=(224, 224), + ) + model = ContrastiveModule( + backbone="resnet50", + in_channels=2, + log_batches_per_epoch=2, + log_samples_per_batch=3, + ) + trainer = Trainer( + max_epochs=5, + limit_train_batches=10, + limit_val_batches=5, + logger=TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/test_tb", + log_graph=True, + default_hp_metric=True, + ), + log_every_n_steps=1, + callbacks=[ModelCheckpoint()], + profiler="simple", # other options: "advanced" uses cprofiler, "pytorch" uses pytorch profiler. + ) + trainer.fit(model, dm) + + +if __name__ == "__main__": + main() diff --git a/applications/contrastive_phenotyping/evaluation/GMM_clustering.py b/applications/contrastive_phenotyping/evaluation/GMM_clustering.py new file mode 100644 index 00000000..a64bc8a8 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/GMM_clustering.py @@ -0,0 +1,355 @@ +# %% import statements +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import numpy as np +from sklearn.mixture import GaussianMixture +from sklearn.metrics import confusion_matrix +from sklearn.decomposition import PCA +from viscy.representation.evalutation.clustering import GMMClustering +from viscy.representation.evalutation.dimensionality_reduction import compute_pca +from viscy.representation.evalutation.dimensionality_reduction import compute_umap +# %% Paths and parameters. + +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) + + +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) + + +features_path_any_time = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" +) + +features_path_june = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr") + +# load annotation +def load_annotation(da, path, name, categories: dict | None = None): + """ + Load annotations from a CSV file and map them to the dataset. + Parameters + ---------- + da : xarray.DataArray + The dataset array containing 'fov_name' and 'id' coordinates. + path : str + Path to the CSV file containing annotations. + name : str + The column name in the CSV file to be used as annotations. + categories : dict, optional + A dictionary to rename categories in the annotation column. Default is None. + Returns + ------- + pd.Series + A pandas Series containing the selected annotations mapped to the dataset. + """ + # Read the annotation CSV file + annotation = pd.read_csv(path) + # Add a leading slash to 'fov name' column and set it as 'fov_name' + annotation["fov_name"] = "/" + annotation["fov_name"] + # Set the index of the annotation DataFrame to ['fov_name', 'id'] + annotation = annotation.set_index(["fov_name", "id"]) + # Create a MultiIndex from the dataset array's 'fov_name' and 'id' values + mi = pd.MultiIndex.from_arrays( + [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + ) + # Select the annotations corresponding to the MultiIndex + selected = annotation.loc[mi][name] + # If categories are provided, rename the categories in the selected annotations + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + +# %% visualize distribution of embeddings +embedding_dataset = read_embedding_dataset(features_path_30_min) +features_data = embedding_dataset["features"] +n_samples, n_features = features_data.shape + +random_dimensions = np.random.choice(n_features, 5, replace=False) + +plt.figure(figsize=(15, 10)) +for i, dim in enumerate(random_dimensions, 1): + plt.subplot(2, 3, i) + sns.histplot(features_data[:, dim], kde=True) + plt.title(f"Dimension {dim} Distribution") + +plt.tight_layout() +plt.show() + +# %% initialize GMM clustering and ground truth labels + +embedding_dataset = read_embedding_dataset(features_path_30_min) +features_data = embedding_dataset['features'] + +cluster_evaluator = GMMClustering(features_data) + +# %% Find best n_clusters, can skip this if already known + +aic_scores, bic_scores = cluster_evaluator.find_best_n_clusters() + +plt.figure(figsize=(8, 6)) +plt.plot(cluster_evaluator.n_clusters_range, aic_scores, label="AIC", marker="o") +plt.plot(cluster_evaluator.n_clusters_range, bic_scores, label="BIC", marker="o") +plt.xlabel("Number of clusters") +plt.ylabel("AIC / BIC Score") +plt.title("AIC and BIC Scores for Different Numbers of Clusters") +plt.legend() +plt.show() + +# %% +# Choose the best model (with the lowest BIC score) +# set n_clusters to the best number of clusters +best_gmm = cluster_evaluator.fit_best_model(criterion='bic', n_clusters=2) +cluster_labels = cluster_evaluator.predict_clusters() + +# %% ground truth labels (if available!) +# need to update path to this +ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" +) + +infection = load_annotation( + features_data, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) + +# %% the confusion matrix with ground truth states + +ground_truth_labels_numeric = infection.cat.codes + +cm = confusion_matrix(ground_truth_labels_numeric, cluster_labels) + +cm_df = pd.DataFrame( + cm, + index=["Background", "Uninfected", "Infected"], + columns=["Cluster 0", "Cluster 1", "Cluster 2"], +) + +plt.figure(figsize=(8, 6)) +sns.heatmap(cm_df, annot=True, fmt="g", cmap="Blues") + +plt.title("Confusion Matrix: Clusters vs Ground Truth") +plt.ylabel("Ground Truth Labels") +plt.xlabel("Cluster Labels") +plt.show() + +# %% +# Reduce dimensions to 2 for vis +_, _, pca_df = compute_pca(embedding_dataset, n_components=2) + +pca1 = pca_df["PCA1"] +pca2 = pca_df["PCA2"] + +color_map = {"background": "gray", "uninfected": "blue", "infected": "red"} +colors = infection.map(color_map) + +plt.figure(figsize=(10, 8)) + +# Plot Cluster 0 with circle markers ('o') +plt.scatter( + pca1[cluster_labels == 0], + pca2[cluster_labels == 0], + c=colors[cluster_labels == 0], + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 0 (circle)", + marker="o", +) + +# Plot Cluster 1 with X markers ('x') +plt.scatter( + pca1[cluster_labels == 1], + pca2[cluster_labels == 1], + c=colors[cluster_labels == 1], + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 1 (X)", + marker="x", +) + +plt.xlabel("PCA 1") +plt.ylabel("PCA 2") +plt.title("Ground Truth Colors with GMM Cluster Marker Types") + +handles = [ + plt.Line2D( + [0], + [0], + marker="o", + color="w", + label=label, + markerfacecolor=color_map[label], + markersize=10, + markeredgecolor="black", + ) + for label in color_map.keys() +] +plt.legend(handles=handles, title="Ground Truth") + +plt.show() + +# %% Visualize GMM Clusters in PCA space (without ground truth) +_, _, pca_df = compute_pca(embedding_dataset, n_components=2) + +pca1 = pca_df["PCA1"] +pca2 = pca_df["PCA2"] + +plt.figure(figsize=(10, 8)) + +# Plot Cluster 0 with circle markers ('o') +plt.scatter( + pca1[cluster_labels == 0], + pca2[cluster_labels == 0], + c="green", + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 0 (GMM)", + marker="o", +) + +# Plot Cluster 1 with X markers ('x') +plt.scatter( + pca1[cluster_labels == 1], + pca2[cluster_labels == 1], + c="orange", + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 1 (GMM)", + marker="x", +) + +plt.xlabel("PCA 1") +plt.ylabel("PCA 2") +plt.title("GMM Clusters") + +plt.legend() +plt.show() + + +# %% Visualize UMAP embeddings colored by GMM cluster weights (without ground truth) +umap_features, umap_projection, umap_df = compute_umap(embedding_dataset) + +gmm_weights = best_gmm.weights_ + +plt.figure(figsize=(10, 8)) +plt.scatter( + umap_df["UMAP1"], + umap_df["UMAP2"], + c=gmm_weights[cluster_labels], + cmap="viridis", + s=50, + alpha=0.8, + edgecolor="k", +) +plt.colorbar(label="GMM Cluster Weights") +plt.title("UMAP Embeddings Colored by GMM Cluster Weights") +plt.xlabel("UMAP 1") +plt.ylabel("UMAP 2") +plt.show() + + +# %% Visualize UMAP embeddings colored by cluster labels (without ground truth) +umap_features, umap_projection, umap_df = compute_umap(embedding_dataset) + +plt.figure(figsize=(10, 8)) + +plt.scatter( + umap_df["UMAP1"][cluster_labels == 0], + umap_df["UMAP2"][cluster_labels == 0], + c="green", + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 0 (GMM)", + marker="o", +) + +plt.scatter( + umap_df["UMAP1"][cluster_labels == 1], + umap_df["UMAP2"][cluster_labels == 1], + c="orange", + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 1 (GMM)", + marker="o", +) + +plt.xlabel("UMAP 1") +plt.ylabel("UMAP 2") +plt.title("GMM Clusters in UMAP Space") + +plt.legend() +plt.show() + +# %% UMAP vis (w/ ground truth colors and GMM cluster markers) +umap_features, umap_projection, umap_df = compute_umap( + embedding_dataset, normalize_features=True +) + +umap1 = umap_df["UMAP1"] +umap2 = umap_df["UMAP2"] + +color_map = {"background": "gray", "uninfected": "blue", "infected": "red"} +colors = infection.map(color_map) + +plt.figure(figsize=(10, 8)) + +# Plot Cluster 0 with circle markers ('o') +plt.scatter( + umap1[cluster_labels == 0], + umap2[cluster_labels == 0], + c=colors[cluster_labels == 0], + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 0 (circle)", + marker="o", +) + +# Plot Cluster 1 with X markers ('x') +plt.scatter( + umap1[cluster_labels == 1], + umap2[cluster_labels == 1], + c=colors[cluster_labels == 1], + edgecolor="black", + s=50, + alpha=0.7, + label="Cluster 1 (X)", + marker="x", +) + +plt.xlabel("UMAP 1") +plt.ylabel("UMAP 2") +plt.title("Ground Truth Colors with GMM Cluster Marker Types in UMAP Space") + +handles = [ + plt.Line2D( + [0], + [0], + marker="o", + color="w", + label=label, + markerfacecolor=color_map[label], + markersize=10, + markeredgecolor="black", + ) + for label in color_map.keys() +] +plt.legend(handles=handles, title="Ground Truth") + +plt.show() + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py b/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py index f0d404f4..79d82131 100644 --- a/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py +++ b/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py @@ -20,6 +20,7 @@ from viscy.representation.evalutation.feature import ( FeatureExtractor as FE, ) +from viscy.representation.evaluation import dataset_of_tracks # %% features_path = Path( @@ -333,6 +334,7 @@ # %% find the cell patches with the highest and lowest value in each feature + def save_patches(fov_name, track_id): data_path = Path( "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py b/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py index 823c09ce..fe0cce47 100644 --- a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py +++ b/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py @@ -20,6 +20,7 @@ from viscy.representation.evalutation.feature import ( FeatureExtractor as FE, ) +from viscy.representation.evaluation import dataset_of_tracks # %% features_path = Path( diff --git a/applications/contrastive_phenotyping/evaluation/analyze_embeddings.py b/applications/contrastive_phenotyping/evaluation/analyze_embeddings.py index c103baa6..10537c14 100644 --- a/applications/contrastive_phenotyping/evaluation/analyze_embeddings.py +++ b/applications/contrastive_phenotyping/evaluation/analyze_embeddings.py @@ -6,12 +6,15 @@ import seaborn as sns from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evalutation import load_annotation -from viscy.representation.evalutation.dimensionality_reduction import ( - compute_pca, - compute_umap, -) - +from viscy.representation.evaluation import load_annotation, compute_pca, compute_umap + +# %% Jupyter magic command for autoreloading modules +# ruff: noqa +# fmt: off +# %load_ext autoreload +# %autoreload 2 +# fmt: on +# ruff: noqa # %% Paths and parameters path_embedding = Path( diff --git a/applications/contrastive_phenotyping/evaluation/displacement.py b/applications/contrastive_phenotyping/evaluation/displacement.py index cc7e9ca0..a14ce11d 100644 --- a/applications/contrastive_phenotyping/evaluation/displacement.py +++ b/applications/contrastive_phenotyping/evaluation/displacement.py @@ -5,75 +5,113 @@ import numpy as np from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evalutation.distance import ( - calculate_normalized_euclidean_distance_cell, - compute_displacement_mean_std_full, +from viscy.representation.evaluation import ( + calculate_normalized_euclidean_distance_cell, + compute_displacement_mean_std_full, ) -# %% paths +# %% paths features_path_30_min = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" ) -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) -features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr") +features_path_any_time = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" +) data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" ) tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" ) # %% Load embedding datasets for all three sampling -fov_name = '/B/4/6' +fov_name = "/B/4/6" track_id = 52 embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) -#%% +# %% # Calculate displacement for each sampling -time_points_30_min, cosine_similarities_30_min = calculate_normalized_euclidean_distance_cell(embedding_dataset_30_min, fov_name, track_id) -time_points_no_track, cosine_similarities_no_track = calculate_normalized_euclidean_distance_cell(embedding_dataset_no_track, fov_name, track_id) -time_points_any_time, cosine_similarities_any_time = calculate_normalized_euclidean_distance_cell(embedding_dataset_any_time, fov_name, track_id) +time_points_30_min, cosine_similarities_30_min = ( + calculate_normalized_euclidean_distance_cell( + embedding_dataset_30_min, fov_name, track_id + ) +) +time_points_no_track, cosine_similarities_no_track = ( + calculate_normalized_euclidean_distance_cell( + embedding_dataset_no_track, fov_name, track_id + ) +) +time_points_any_time, cosine_similarities_any_time = ( + calculate_normalized_euclidean_distance_cell( + embedding_dataset_any_time, fov_name, track_id + ) +) # %% Plot displacement over time for all three conditions plt.figure(figsize=(10, 6)) -plt.plot(time_points_no_track, cosine_similarities_no_track, marker='o', label='classical contrastive (no tracking)') -plt.plot(time_points_any_time, cosine_similarities_any_time, marker='o', label='cell aware') -plt.plot(time_points_30_min, cosine_similarities_30_min, marker='o', label='cell & time aware (interval 30 min)') +plt.plot( + time_points_no_track, + cosine_similarities_no_track, + marker="o", + label="classical contrastive (no tracking)", +) +plt.plot( + time_points_any_time, cosine_similarities_any_time, marker="o", label="cell aware" +) +plt.plot( + time_points_30_min, + cosine_similarities_30_min, + marker="o", + label="cell & time aware (interval 30 min)", +) plt.xlabel("Time Delay (t)", fontsize=10) plt.ylabel("Normalized Euclidean Distance with First Time Point", fontsize=10) -plt.title("Normalized Euclidean Distance (Features) Over Time for Infected Cell", fontsize=12) +plt.title( + "Normalized Euclidean Distance (Features) Over Time for Infected Cell", fontsize=12 +) plt.grid(True) plt.legend(fontsize=10) -#plt.savefig('4_euc_dist_full.svg', format='svg') +# plt.savefig('4_euc_dist_full.svg', format='svg') plt.show() # %% Paths to datasets -features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") -feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) +feature_path_no_track = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" +) embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) # %% -max_tau = 10 +max_tau = 10 -mean_displacement_30_min_euc, std_displacement_30_min_euc = compute_displacement_mean_std_full(embedding_dataset_30_min, max_tau) -mean_displacement_no_track_euc, std_displacement_no_track_euc = compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau) +mean_displacement_30_min_euc, std_displacement_30_min_euc = ( + compute_displacement_mean_std_full(embedding_dataset_30_min, max_tau) +) +mean_displacement_no_track_euc, std_displacement_no_track_euc = ( + compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau) +) # %% Plot 2: Cosine Displacements plt.figure(figsize=(10, 6)) @@ -83,24 +121,44 @@ mean_values_30_min_euc = list(mean_displacement_30_min_euc.values()) std_values_30_min_euc = list(std_displacement_30_min_euc.values()) -plt.plot(taus, mean_values_30_min_euc, marker='o', label='Cell & Time Aware (30 min interval)', color='green') -plt.fill_between(taus, - np.array(mean_values_30_min_euc) - np.array(std_values_30_min_euc), - np.array(mean_values_30_min_euc) + np.array(std_values_30_min_euc), - color='green', alpha=0.3, label='Std Dev (30 min interval)') +plt.plot( + taus, + mean_values_30_min_euc, + marker="o", + label="Cell & Time Aware (30 min interval)", + color="green", +) +plt.fill_between( + taus, + np.array(mean_values_30_min_euc) - np.array(std_values_30_min_euc), + np.array(mean_values_30_min_euc) + np.array(std_values_30_min_euc), + color="green", + alpha=0.3, + label="Std Dev (30 min interval)", +) mean_values_no_track_euc = list(mean_displacement_no_track_euc.values()) std_values_no_track_euc = list(std_displacement_no_track_euc.values()) -plt.plot(taus, mean_values_no_track_euc, marker='o', label='Classical Contrastive (No Tracking)', color='blue') -plt.fill_between(taus, - np.array(mean_values_no_track_euc) - np.array(std_values_no_track_euc), - np.array(mean_values_no_track_euc) + np.array(std_values_no_track_euc), - color='blue', alpha=0.3, label='Std Dev (No Tracking)') +plt.plot( + taus, + mean_values_no_track_euc, + marker="o", + label="Classical Contrastive (No Tracking)", + color="blue", +) +plt.fill_between( + taus, + np.array(mean_values_no_track_euc) - np.array(std_values_no_track_euc), + np.array(mean_values_no_track_euc) + np.array(std_values_no_track_euc), + color="blue", + alpha=0.3, + label="Std Dev (No Tracking)", +) -plt.xlabel('Time Shift (τ)') -plt.ylabel('Euclidean Distance') -plt.title('Embedding Displacement Over Time (Features)') +plt.xlabel("Time Shift (τ)") +plt.ylabel("Euclidean Distance") +plt.title("Embedding Displacement Over Time (Features)") plt.grid(True) plt.legend() diff --git a/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py b/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py index ee0f405a..7456b284 100644 --- a/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py +++ b/applications/contrastive_phenotyping/evaluation/log_regresssion_training.py @@ -1,23 +1,22 @@ - # %% from pathlib import Path import pandas as pd from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evalutation import load_annotation +from viscy.representation.evaluation import load_annotation # %% Paths and parameters. features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" ) data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" ) tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" ) @@ -33,15 +32,15 @@ # %% OVERLAY INFECTION ANNOTATION ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" ) infection = load_annotation( - features, - ann_root / "extracted_inf_state.csv", - "infection_state", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, + features, + ann_root / "extracted_inf_state.csv", + "infection_state", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, ) # %% plot the umap @@ -74,10 +73,18 @@ # %% manually split the dataset into training and testing set by well name # dataframe for training set, fov names starts with "/B/4/6" or "/B/4/7" or "/A/3/" -data_train_val = data[data["fov_name"].str.contains("/B/4/6") | data["fov_name"].str.contains("/B/4/7") | data["fov_name"].str.contains("/A/3/")] +data_train_val = data[ + data["fov_name"].str.contains("/B/4/6") + | data["fov_name"].str.contains("/B/4/7") + | data["fov_name"].str.contains("/A/3/") +] # dataframe for testing set, fov names starts with "/B/4/8" or "/B/4/9" or "/A/4/" -data_test = data[data["fov_name"].str.contains("/B/4/8") | data["fov_name"].str.contains("/B/4/9") | data["fov_name"].str.contains("/B/3/")] +data_test = data[ + data["fov_name"].str.contains("/B/4/8") + | data["fov_name"].str.contains("/B/4/9") + | data["fov_name"].str.contains("/B/3/") +] # %% train a linear classifier to predict infection state from PCA components diff --git a/applications/contrastive_phenotyping/evaluation/plot_embeddings.py b/applications/contrastive_phenotyping/evaluation/plot_embeddings.py index 5d37ea8f..ce99bc41 100644 --- a/applications/contrastive_phenotyping/evaluation/plot_embeddings.py +++ b/applications/contrastive_phenotyping/evaluation/plot_embeddings.py @@ -1,7 +1,9 @@ # %% +import os from pathlib import Path import matplotlib.pyplot as plt +import napari import numpy as np import pandas as pd import plotly.express as px @@ -131,9 +133,6 @@ plt.show() # %% display the track in napari -# import os - -# import napari # os.environ["DISPLAY"] = ":1" # viewer = napari.Viewer() diff --git a/applications/contrastive_phenotyping/figures/cell_division.py b/applications/contrastive_phenotyping/figures/cell_division.py index 096e7dba..2844ff58 100644 --- a/applications/contrastive_phenotyping/figures/cell_division.py +++ b/applications/contrastive_phenotyping/figures/cell_division.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import pandas as pd import seaborn as sns +from matplotlib.patches import FancyArrowPatch from sklearn.preprocessing import StandardScaler from umap import UMAP @@ -108,8 +109,6 @@ def load_annotation(da, path, name, categories: dict | None = None): # %% plot the trajectory quiver of one cell on top of the UMAP -from matplotlib.patches import FancyArrowPatch - cell_parent = features[ (features["fov_name"].str.contains("A/3/7")) & (features["track_id"].isin([13])) ] diff --git a/applications/contrastive_phenotyping/figures/classify_june.py b/applications/contrastive_phenotyping/figures/classify_june.py index ca51f2b1..21373fef 100644 --- a/applications/contrastive_phenotyping/figures/classify_june.py +++ b/applications/contrastive_phenotyping/figures/classify_june.py @@ -14,7 +14,10 @@ from viscy.representation.embedding_writer import read_embedding_dataset # %% Defining Paths for June Dataset -june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr") +june_features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr" +) + # %% Function to Load Annotations def load_annotation(da, path, name, categories: dict | None = None): @@ -29,29 +32,33 @@ def load_annotation(da, path, name, categories: dict | None = None): selected = selected.astype("category").cat.rename_categories(categories) return selected + # %% Function to Compute PCA def compute_pca(embedding_dataset, n_components=6): features = embedding_dataset["features"] scaled_features = StandardScaler().fit_transform(features.values) - + # Compute PCA with specified number of components pca = PCA(n_components=n_components, random_state=42) pca_embedding = pca.fit_transform(scaled_features) - + # Prepare DataFrame with id and PCA coordinates - pca_df = pd.DataFrame({ - "id": embedding_dataset["id"].values, - "fov_name": embedding_dataset["fov_name"].values, - "PCA1": pca_embedding[:, 0], - "PCA2": pca_embedding[:, 1], - "PCA3": pca_embedding[:, 2], - "PCA4": pca_embedding[:, 3], - "PCA5": pca_embedding[:, 4], - "PCA6": pca_embedding[:, 5] - }) - + pca_df = pd.DataFrame( + { + "id": embedding_dataset["id"].values, + "fov_name": embedding_dataset["fov_name"].values, + "PCA1": pca_embedding[:, 0], + "PCA2": pca_embedding[:, 1], + "PCA3": pca_embedding[:, 2], + "PCA4": pca_embedding[:, 3], + "PCA5": pca_embedding[:, 4], + "PCA6": pca_embedding[:, 5], + } + ) + return pca_df + # %% Load and Process June Dataset june_embedding_dataset = read_embedding_dataset(june_features_path) print(june_embedding_dataset) @@ -61,15 +68,21 @@ def compute_pca(embedding_dataset, n_components=6): print("Shape of pca_df before merge:", pca_df.shape) # Load the ground truth infection labels -june_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking") -june_infection = load_annotation(june_embedding_dataset, june_ann_root / "tracking_v1_infection.csv", "infection class", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}) +june_ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking" +) +june_infection = load_annotation( + june_embedding_dataset, + june_ann_root / "tracking_v1_infection.csv", + "infection class", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) # Print shape of june_infection print("Shape of june_infection:", june_infection.shape) # Merge PCA results with ground truth labels on both 'fov_name' and 'id' -pca_df = pd.merge(pca_df, june_infection.reset_index(), on=['fov_name', 'id']) +pca_df = pd.merge(pca_df, june_infection.reset_index(), on=["fov_name", "id"]) # Print shape after merge print("Shape of pca_df after merge:", pca_df.shape) @@ -83,7 +96,9 @@ def compute_pca(embedding_dataset, n_components=6): X_resampled, y_resampled = smote.fit_resample(X, y) # Print shape after SMOTE -print(f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}") +print( + f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}" +) # %% Train Logistic Regression Classifier with Progress Bar model = LogisticRegression(max_iter=1000, random_state=42) @@ -104,18 +119,22 @@ def compute_pca(embedding_dataset, n_components=6): # %% Plotting the Results plt.figure(figsize=(10, 8)) -sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8) +sns.scatterplot( + x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8 +) plt.title("PCA with Ground Truth Labels") -plt.savefig("june_pca_ground_truth_labels.png", format='png', dpi=300) +plt.savefig("june_pca_ground_truth_labels.png", format="png", dpi=300) plt.show() plt.figure(figsize=(10, 8)) -sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8) +sns.scatterplot( + x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8 +) plt.title("PCA with Logistic Regression Predicted Labels") -plt.savefig("june_pca_predicted_labels.png", format='png', dpi=300) +plt.savefig("june_pca_predicted_labels.png", format="png", dpi=300) plt.show() # %% Save Predicted Labels to CSV save_path_csv = "june_logistic_regression_predicted_labels_feb_pca.csv" -pca_df[['id', 'fov_name', 'Predicted_Label']].to_csv(save_path_csv, index=False) +pca_df[["id", "fov_name", "Predicted_Label"]].to_csv(save_path_csv, index=False) print(f"Predicted labels saved to {save_path_csv}") diff --git a/applications/contrastive_phenotyping/figures/figure_4a_1.py b/applications/contrastive_phenotyping/figures/figure_4a_1.py index a670db0d..c1c2befc 100644 --- a/applications/contrastive_phenotyping/figures/figure_4a_1.py +++ b/applications/contrastive_phenotyping/figures/figure_4a_1.py @@ -10,9 +10,16 @@ from viscy.representation.embedding_writer import read_embedding_dataset # %% Defining Paths for February and June Datasets -feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr") -feb_data_path = Path("/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr") -feb_tracks_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr") +feb_features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr" +) +feb_data_path = Path( + "/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr" +) +feb_tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" +) + # %% Function to Load and Process the Embedding Dataset def compute_umap(embedding_dataset): @@ -20,7 +27,7 @@ def compute_umap(embedding_dataset): scaled_features = StandardScaler().fit_transform(features.values) umap = UMAP() embedding = umap.fit_transform(scaled_features) - + features = ( features.assign_coords(UMAP1=("sample", embedding[:, 0])) .assign_coords(UMAP2=("sample", embedding[:, 1])) @@ -28,6 +35,7 @@ def compute_umap(embedding_dataset): ) return features + # %% Function to Load Annotations def load_annotation(da, path, name, categories: dict | None = None): annotation = pd.read_csv(path) @@ -41,19 +49,30 @@ def load_annotation(da, path, name, categories: dict | None = None): selected = selected.astype("category").cat.rename_categories(categories) return selected + # %% Function to Plot UMAP with Infection Annotations def plot_umap_infection(features, infection, title): plt.figure(figsize=(10, 8)) - sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) + sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8 + ) plt.title(f"UMAP Plot - {title}") plt.show() + # %% Load and Process February Dataset feb_embedding_dataset = read_embedding_dataset(feb_features_path) feb_features = compute_umap(feb_embedding_dataset) -feb_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track") -feb_infection = load_annotation(feb_features, feb_ann_root / "tracking_v1_infection.csv", "infection class", {0.0: "background", 1.0: "uninfected", 2.0: "infected"}) +feb_ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track" +) +feb_infection = load_annotation( + feb_features, + feb_ann_root / "tracking_v1_infection.csv", + "infection class", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) # %% Plot UMAP with Infection Status for February Dataset plot_umap_infection(feb_features, feb_infection, "February Dataset") @@ -66,21 +85,47 @@ def plot_umap_infection(features, infection, title): # %% Identify cells by infection type using fov_name -mock_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/3') | feb_features['fov_name'].str.contains('/B/3')) -zika_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/4')) -dengue_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/B/4')) +mock_cells = feb_features.sel( + sample=feb_features["fov_name"].str.contains("/A/3") + | feb_features["fov_name"].str.contains("/B/3") +) +zika_cells = feb_features.sel(sample=feb_features["fov_name"].str.contains("/A/4")) +dengue_cells = feb_features.sel(sample=feb_features["fov_name"].str.contains("/B/4")) # %% Plot UMAP with Infection Status plt.figure(figsize=(10, 8)) -sns.scatterplot(x=feb_features["UMAP1"], y=feb_features["UMAP2"], hue=feb_infection, s=7, alpha=0.8) +sns.scatterplot( + x=feb_features["UMAP1"], y=feb_features["UMAP2"], hue=feb_infection, s=7, alpha=0.8 +) # Overlay with circled cells -plt.scatter(mock_cells["UMAP1"], mock_cells["UMAP2"], facecolors='none', edgecolors='blue', s=20, label='Mock Cells') -plt.scatter(zika_cells["UMAP1"], zika_cells["UMAP2"], facecolors='none', edgecolors='green', s=20, label='Zika MOI 5') -plt.scatter(dengue_cells["UMAP1"], dengue_cells["UMAP2"], facecolors='none', edgecolors='red', s=20, label='Dengue MOI 5') +plt.scatter( + mock_cells["UMAP1"], + mock_cells["UMAP2"], + facecolors="none", + edgecolors="blue", + s=20, + label="Mock Cells", +) +plt.scatter( + zika_cells["UMAP1"], + zika_cells["UMAP2"], + facecolors="none", + edgecolors="green", + s=20, + label="Zika MOI 5", +) +plt.scatter( + dengue_cells["UMAP1"], + dengue_cells["UMAP2"], + facecolors="none", + edgecolors="red", + s=20, + label="Dengue MOI 5", +) # Add legend and show plot -plt.legend(loc='best') +plt.legend(loc="best") plt.title("UMAP Plot - February Dataset with Mock, Zika, and Dengue Highlighted") plt.show() @@ -89,27 +134,48 @@ def plot_umap_infection(features, infection, title): fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) # Mock Cells Heatmap -sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0]) -axs[0].set_title('Mock Cells') +sns.histplot( + x=mock_cells["UMAP1"], + y=mock_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Blues", + ax=axs[0], +) +axs[0].set_title("Mock Cells") axs[0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) # Zika Cells Heatmap -sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1]) -axs[1].set_title('Zika MOI 5') +sns.histplot( + x=zika_cells["UMAP1"], + y=zika_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Greens", + ax=axs[1], +) +axs[1].set_title("Zika MOI 5") axs[1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) # Dengue Cells Heatmap -sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[2]) -axs[2].set_title('Dengue MOI 5') +sns.histplot( + x=dengue_cells["UMAP1"], + y=dengue_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Reds", + ax=axs[2], +) +axs[2].set_title("Dengue MOI 5") axs[2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) # Set labels and adjust layout for ax in axs: - ax.set_xlabel('UMAP1') - ax.set_ylabel('UMAP2') + ax.set_xlabel("UMAP1") + ax.set_ylabel("UMAP2") plt.tight_layout() plt.show() @@ -122,32 +188,67 @@ def plot_umap_infection(features, infection, title): fig, axs = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True) # Mock Cells Heatmap -sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0, 0]) -axs[0, 0].set_title('Mock Cells') +sns.histplot( + x=mock_cells["UMAP1"], + y=mock_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Blues", + ax=axs[0, 0], +) +axs[0, 0].set_title("Mock Cells") axs[0, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[0, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) # Zika Cells Heatmap -sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[0, 1]) -axs[0, 1].set_title('Zika MOI 5') +sns.histplot( + x=zika_cells["UMAP1"], + y=zika_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Greens", + ax=axs[0, 1], +) +axs[0, 1].set_title("Zika MOI 5") axs[0, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[0, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) # Dengue Cells Heatmap -sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[0, 2]) -axs[0, 2].set_title('Dengue MOI 5') +sns.histplot( + x=dengue_cells["UMAP1"], + y=dengue_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Reds", + ax=axs[0, 2], +) +axs[0, 2].set_title("Dengue MOI 5") axs[0, 2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[0, 2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) # Infected Cells Heatmap -sns.histplot(x=infected_cells["UMAP1"], y=infected_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[1, 0]) -axs[1, 0].set_title('Infected Cells') +sns.histplot( + x=infected_cells["UMAP1"], + y=infected_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Reds", + ax=axs[1, 0], +) +axs[1, 0].set_title("Infected Cells") axs[1, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[1, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) # Uninfected Cells Heatmap -sns.histplot(x=uninfected_cells["UMAP1"], y=uninfected_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1, 1]) -axs[1, 1].set_title('Uninfected Cells') +sns.histplot( + x=uninfected_cells["UMAP1"], + y=uninfected_cells["UMAP2"], + bins=50, + pmax=1, + cmap="Greens", + ax=axs[1, 1], +) +axs[1, 1].set_title("Uninfected Cells") axs[1, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) axs[1, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) @@ -156,12 +257,11 @@ def plot_umap_infection(features, infection, title): # Set labels and adjust layout for ax in axs.flat: - ax.set_xlabel('UMAP1') - ax.set_ylabel('UMAP2') + ax.set_xlabel("UMAP1") + ax.set_ylabel("UMAP2") plt.tight_layout() plt.show() - # %% diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py b/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py index d3052018..bf16fdff 100644 --- a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py +++ b/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py @@ -12,52 +12,72 @@ def load_gmm_annotation(gmm_csv_path): gmm_df = pd.read_csv(gmm_csv_path) return gmm_df + # %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on GMM Labels def count_infected_cell_states_over_time(embedding_dataset, gmm_df): # Convert the embedding dataset to a DataFrame - df = pd.DataFrame({ - "fov_name": embedding_dataset["fov_name"].values, - "track_id": embedding_dataset["track_id"].values, - "t": embedding_dataset["t"].values, - "id": embedding_dataset["id"].values - }) - + df = pd.DataFrame( + { + "fov_name": embedding_dataset["fov_name"].values, + "track_id": embedding_dataset["track_id"].values, + "t": embedding_dataset["t"].values, + "id": embedding_dataset["id"].values, + } + ) + # Merge with GMM data to add GMM labels - df = pd.merge(df, gmm_df[['id', 'fov_name', 'Predicted_Label']], on=['fov_name', 'id'], how='left') + df = pd.merge( + df, + gmm_df[["id", "fov_name", "Predicted_Label"]], + on=["fov_name", "id"], + how="left", + ) # Filter by time range (3 HPI to 30 HPI) - df = df[(df['t'] >= 3) & (df['t'] <= 27)] - + df = df[(df["t"] >= 3) & (df["t"] <= 27)] + # Determine the well type (Mock, Zika, Dengue) based on fov_name - df['well_type'] = df['fov_name'].apply(lambda x: 'Mock' if '/A/3' in x or '/B/3' in x else - ('Zika' if '/A/4' in x else 'Dengue')) - + df["well_type"] = df["fov_name"].apply( + lambda x: ( + "Mock" + if "/A/3" in x or "/B/3" in x + else ("Zika" if "/A/4" in x else "Dengue") + ) + ) + # Group by time, well type, and GMM label to count the number of infected cells - state_counts = df.groupby(['t', 'well_type', 'Predicted_Label']).size().unstack(fill_value=0) - + state_counts = ( + df.groupby(["t", "well_type", "Predicted_Label"]).size().unstack(fill_value=0) + ) + # Ensure that 'infected' column exists - if 'infected' not in state_counts.columns: - state_counts['infected'] = 0 - + if "infected" not in state_counts.columns: + state_counts["infected"] = 0 + # Calculate the percentage of infected cells - state_counts['total'] = state_counts.sum(axis=1) - state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100 - + state_counts["total"] = state_counts.sum(axis=1) + state_counts["infected"] = (state_counts["infected"] / state_counts["total"]) * 100 + return state_counts + # %% Function to Plot Percentage of Infected Cells Over Time def plot_infected_cell_states(state_counts): plt.figure(figsize=(12, 8)) # Loop through each well type - for well_type in ['Mock', 'Zika', 'Dengue']: + for well_type in ["Mock", "Zika", "Dengue"]: # Select the data for the current well type - if well_type in state_counts.index.get_level_values('well_type'): - well_data = state_counts.xs(well_type, level='well_type') - + if well_type in state_counts.index.get_level_values("well_type"): + well_data = state_counts.xs(well_type, level="well_type") + # Plot only the percentage of infected cells - if 'infected' in well_data.columns: - plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected') + if "infected" in well_data.columns: + plt.plot( + well_data.index, + well_data["infected"], + label=f"{well_type} - Infected", + ) plt.title("Percentage of Infected Cells Over Time - February") plt.xlabel("Hours Post Perturbation") @@ -66,12 +86,17 @@ def plot_infected_cell_states(state_counts): plt.grid(True) plt.show() + # %% Load and process Feb Dataset -feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr") +feb_features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr" +) feb_embedding_dataset = read_embedding_dataset(feb_features_path) # Load the GMM annotation CSV -gmm_csv_path = "june_logistic_regression_predicted_labels_feb_pca.csv" # Path to CSV file +gmm_csv_path = ( + "june_logistic_regression_predicted_labels_feb_pca.csv" # Path to CSV file +) gmm_df = load_gmm_annotation(gmm_csv_path) # %% Count Infected Cell States Over Time as Percentage using GMM labels @@ -83,5 +108,3 @@ def plot_infected_cell_states(state_counts): plot_infected_cell_states(state_counts) # %% - - diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py b/applications/contrastive_phenotyping/figures/figure_4e_2_june.py index 1605ba27..e33a0f36 100644 --- a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py +++ b/applications/contrastive_phenotyping/figures/figure_4e_2_june.py @@ -11,53 +11,72 @@ def load_annotation(csv_path): return pd.read_csv(csv_path) + # %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on Predicted Labels def count_infected_cell_states_over_time(embedding_dataset, prediction_df): # Convert the embedding dataset to a DataFrame - df = pd.DataFrame({ - "fov_name": embedding_dataset["fov_name"].values, - "track_id": embedding_dataset["track_id"].values, - "t": embedding_dataset["t"].values, - "id": embedding_dataset["id"].values - }) - + df = pd.DataFrame( + { + "fov_name": embedding_dataset["fov_name"].values, + "track_id": embedding_dataset["track_id"].values, + "t": embedding_dataset["t"].values, + "id": embedding_dataset["id"].values, + } + ) + # Merge with the prediction data to add Predicted Labels - df = pd.merge(df, prediction_df[['id', 'fov_name', 'Infection_Class']], on=['fov_name', 'id'], how='left') + df = pd.merge( + df, + prediction_df[["id", "fov_name", "Infection_Class"]], + on=["fov_name", "id"], + how="left", + ) # Filter by time range (2 HPI to 50 HPI) - df = df[(df['t'] >= 2) & (df['t'] <= 50)] - + df = df[(df["t"] >= 2) & (df["t"] <= 50)] + # Determine the well type (Mock, Dengue, Zika) based on fov_name - df['well_type'] = df['fov_name'].apply( - lambda x: 'Mock' if '/0/1' in x or '/0/2' in x or '/0/3' in x or '/0/4' in x else - ('Dengue' if '/0/5' in x or '/0/6' in x else 'Zika')) - + df["well_type"] = df["fov_name"].apply( + lambda x: ( + "Mock" + if "/0/1" in x or "/0/2" in x or "/0/3" in x or "/0/4" in x + else ("Dengue" if "/0/5" in x or "/0/6" in x else "Zika") + ) + ) + # Group by time, well type, and Predicted_Label to count the number of infected cells - state_counts = df.groupby(['t', 'well_type', 'Infection_Class']).size().unstack(fill_value=0) - + state_counts = ( + df.groupby(["t", "well_type", "Infection_Class"]).size().unstack(fill_value=0) + ) + # Ensure that 'infected' column exists - if 'infected' not in state_counts.columns: - state_counts['infected'] = 0 - + if "infected" not in state_counts.columns: + state_counts["infected"] = 0 + # Calculate the percentage of infected cells - state_counts['total'] = state_counts.sum(axis=1) - state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100 - + state_counts["total"] = state_counts.sum(axis=1) + state_counts["infected"] = (state_counts["infected"] / state_counts["total"]) * 100 + return state_counts + # %% Function to Plot Percentage of Infected Cells Over Time def plot_infected_cell_states(state_counts): plt.figure(figsize=(12, 8)) # Loop through each well type - for well_type in ['Mock', 'Dengue', 'Zika']: + for well_type in ["Mock", "Dengue", "Zika"]: # Select the data for the current well type - if well_type in state_counts.index.get_level_values('well_type'): - well_data = state_counts.xs(well_type, level='well_type') - + if well_type in state_counts.index.get_level_values("well_type"): + well_data = state_counts.xs(well_type, level="well_type") + # Plot only the percentage of infected cells - if 'infected' in well_data.columns: - plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected') + if "infected" in well_data.columns: + plt.plot( + well_data.index, + well_data["infected"], + label=f"{well_type} - Infected", + ) plt.title("Percentage of Infected Cells Over Time - June") plt.xlabel("Hours Post Perturbation") @@ -66,8 +85,11 @@ def plot_infected_cell_states(state_counts): plt.grid(True) plt.show() + # %% Load and process June Dataset -june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr") +june_features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr" +) june_embedding_dataset = read_embedding_dataset(june_features_path) # Load the predicted labels from CSV @@ -75,7 +97,9 @@ def plot_infected_cell_states(state_counts): prediction_df = load_annotation(prediction_csv_path) # %% Count Infected Cell States Over Time as Percentage using Predicted labels -state_counts = count_infected_cell_states_over_time(june_embedding_dataset, prediction_df) +state_counts = count_infected_cell_states_over_time( + june_embedding_dataset, prediction_df +) print(state_counts.head()) state_counts.info() diff --git a/applications/contrastive_phenotyping/figures/figure_cell_infection.py b/applications/contrastive_phenotyping/figures/figure_cell_infection.py index 270c78bd..30e7cd31 100644 --- a/applications/contrastive_phenotyping/figures/figure_cell_infection.py +++ b/applications/contrastive_phenotyping/figures/figure_cell_infection.py @@ -9,15 +9,16 @@ import pandas as pd import seaborn as sns from sklearn.decomposition import PCA +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import confusion_matrix from sklearn.preprocessing import StandardScaler from umap import UMAP from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evalutation import load_annotation +from viscy.representation.evaluation import load_annotation # %% Paths and parameters. - features_path = Path( "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" ) @@ -174,8 +175,6 @@ # %% train a linear classifier to predict infection state from PCA components -from sklearn.linear_model import LogisticRegression - x_train = data_train_val.drop( columns=[ "infection", @@ -214,9 +213,6 @@ # %% construct confusion matrix to compare the true and predicted infection state -import seaborn as sns -from sklearn.metrics import confusion_matrix - cm = confusion_matrix(y_test, y_pred) cm_percentage = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] * 100 sns.heatmap(cm_percentage, annot=True, fmt=".2f", cmap="viridis") diff --git a/applications/contrastive_phenotyping/figures/save_patches.py b/applications/contrastive_phenotyping/figures/save_patches.py index 882e79b1..ebba6c32 100644 --- a/applications/contrastive_phenotyping/figures/save_patches.py +++ b/applications/contrastive_phenotyping/figures/save_patches.py @@ -8,7 +8,7 @@ sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") # from viscy.data.triplet import TripletDataModule -from viscy.representation.evalutation import dataset_of_tracks +from viscy.representation.evaluation import dataset_of_tracks # %% input parameters diff --git a/applications/contrastive_phenotyping/predict.py b/applications/contrastive_phenotyping/predict.py new file mode 100644 index 00000000..d54f21f2 --- /dev/null +++ b/applications/contrastive_phenotyping/predict.py @@ -0,0 +1,171 @@ +from argparse import ArgumentParser +from pathlib import Path +import numpy as np +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import TQDMProgressBar +from lightning.pytorch.strategies import DDPStrategy +from viscy.data.triplet import TripletDataModule, TripletDataset +from viscy.light.engine import ContrastiveModule +import os +from torch.multiprocessing import Manager +from viscy.transforms import ( + NormalizeSampled, + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, +) +from monai.transforms import NormalizeIntensityd, ScaleIntensityRangePercentilesd + +# Updated normalizations +normalizations = [ + NormalizeIntensityd( + keys=["Phase3D"], + subtrahend=None, + divisor=None, + nonzero=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False + ), + ScaleIntensityRangePercentilesd( + keys=["RFP"], + lower=50, + upper=99, + b_min=0.0, + b_max=1.0, + clip=False, + relative=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False + ), +] + +def main(hparams): + # Set paths + # /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/expanded_final_track_timesteps.csv + # /hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/uninfected_cells.csv + # /hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/expanded_transitioning_cells_metadata.csv + checkpoint_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/infection_score/multi-resnet2/contrastive_model-test-epoch=21-val_loss=0.00.ckpt" + + # non-rechunked data + data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.1-register/registered.zarr" + + # updated tracking data + tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/track_labels_final.zarr" + + source_channel = ["RFP", "Phase3D"] + z_range = (26, 38) + batch_size = 1 # match the number of fovs being processed such that no data is left + # set to 15 for full, 12 for infected, and 8 for uninfected + + # infected cells - JUNE + # include_fov_names = ['/0/8/001001', '/0/8/001001', '/0/8/000001', '/0/6/002002', '/0/6/002002', '/0/6/00200'] + # include_track_ids = [31, 8, 21, 4, 2, 21] + + # # uninfected cells - JUNE + # include_fov_names = ['/0/1/000000', '/0/1/000000', '/0/1/000000', '/0/1/000000', '/0/8/000002', '/0/8/000002'] + # include_track_ids = [25, 36, 37, 48, 16, 17] + + # # dividing cells - JUNE + # include_fov_names = ['/0/1/000000', '/0/1/000000', '/0/1/000000'] + # include_track_ids = [18, 21, 50] + + # uninfected cells - FEB + # include_fov_names = ['/A/3/0', 'B/3/5', 'B/3/5', 'B/3/5', 'B/3/5', '/A/4/14', '/A/4/14'] + # include_track_ids = [15, 34, 32, 31, 26, 33, 30] + + # # infected cells - FEB + # include_fov_names = ['/A/4/13', '/A/4/14', '/B/4/4', '/B/4/5', '/B/4/6', '/B/4/6'] + # include_track_ids = [25, 19, 68, 11, 29, 35] + + # # dividing cells - FEB + # include_fov_names = ['/B/4/4', '/B/3/5'] + # include_track_ids = [71, 42] + + # Initialize the data module for prediction + data_module = TripletDataModule( + data_path=data_path, + tracks_path=tracks_path, + source_channel=source_channel, + z_range=z_range, + initial_yx_patch_size=(224, 224), + final_yx_patch_size=(224, 224), + batch_size=batch_size, + num_workers=hparams.num_workers, + normalizations=normalizations, + # predict_cells = True, + # include_fov_names=include_fov_names, + # include_track_ids=include_track_ids, + ) + + data_module.setup(stage="predict") + + print(f"Total prediction dataset size: {len(data_module.predict_dataset)}") + + # Load the model from checkpoint + backbone = "resnet50" + in_stack_depth = 12 + stem_kernel_size = (5, 3, 3) + model = ContrastiveModule.load_from_checkpoint( + str(checkpoint_path), + predict=True, + backbone=backbone, + in_channels=len(source_channel), + in_stack_depth=in_stack_depth, + stem_kernel_size=stem_kernel_size, + tracks_path = tracks_path, + ) + + model.eval() + + # Initialize the trainer + trainer = Trainer( + accelerator="gpu", + devices=1, + num_nodes=1, + strategy=DDPStrategy(find_unused_parameters=False), + callbacks=[TQDMProgressBar(refresh_rate=1)], + ) + + # Run prediction + trainer.predict(model, datamodule=data_module) + + # # Collect features and projections + # features_list = [] + # projections_list = [] + + # for batch_idx, batch in enumerate(predictions): + # features, projections = batch + # features_list.append(features.cpu().numpy()) + # projections_list.append(projections.cpu().numpy()) + # all_features = np.concatenate(features_list, axis=0) + # all_projections = np.concatenate(projections_list, axis=0) + + # # for saving visualizations embeddings + # base_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/test_visualizations" + # features_path = os.path.join(base_dir, 'B', '4', '2', 'before_projected_embeddings', 'test_epoch88_predicted_features.npy') + # projections_path = os.path.join(base_dir, 'B', '4', '2', 'projected_embeddings', 'test_epoch88_predicted_projections.npy') + + # np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/embeddings/resnet_uninf_rfp_epoch99_predicted_features.npy", all_features) + # np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/embeddings/resnet_uninf_rfp_epoch99_predicted_projections.npy", all_projections) + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--backbone", type=str, default="resnet50") + parser.add_argument("--margin", type=float, default=0.5) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--schedule", type=str, default="Constant") + parser.add_argument("--log_steps_per_epoch", type=int, default=10) + parser.add_argument("--embedding_len", type=int, default=256) + parser.add_argument("--max_epochs", type=int, default=100) + parser.add_argument("--accelerator", type=str, default="gpu") + parser.add_argument("--devices", type=int, default=1) + parser.add_argument("--num_nodes", type=int, default=1) + parser.add_argument("--log_every_n_steps", type=int, default=1) + parser.add_argument("--num_workers", type=int, default=8) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/applications/contrastive_phenotyping/training_script.py b/applications/contrastive_phenotyping/training_script.py new file mode 100644 index 00000000..a027945e --- /dev/null +++ b/applications/contrastive_phenotyping/training_script.py @@ -0,0 +1,272 @@ +# %% Imports and paths. +import logging +import os +from argparse import ArgumentParser +from pathlib import Path + +import torch +from torch.utils.data import DataLoader +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.strategies import DDPStrategy +from viscy.transforms import ( + NormalizeSampled, + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, +) +from viscy.data.triplet import TripletDataModule, TripletDataset +from viscy.light.engine import ContrastiveModule +from viscy.representation.contrastive import ContrastiveEncoder +import pandas as pd +from pathlib import Path +from monai.transforms import NormalizeIntensityd, ScaleIntensityRangePercentilesd +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.callbacks import DeviceStatsMonitor + + +# %% Paths and constants + +# @rank_zero_only +# def init_wandb(): +# wandb.init(project="contrastive_model", dir="/hpc/mydata/alishba.imran/wandb_logs/") + +# init_wandb() + +# wandb.init(project="contrastive_model", dir="/hpc/mydata/alishba.imran/wandb_logs/") + +# input_zarr = top_dir / "2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/full_patch.zarr" +# input_zarr = "/hpc/projects/virtual_staining/viral_sensor_test_dataio/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/full_patch.zarr" +top_dir = Path("/hpc/projects/intracellular_dashboard/viral-sensor/") +model_dir = top_dir / "infection_classification/models/infection_score" + +# checkpoint dir: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/infection_score/updated_multiple_channels +# timesteps_csv_path = ( +# top_dir / "2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/final_track_timesteps.csv" +# ) + +# Data parameters +# 15 for covnext backbone, 12 for resnet (z slices) +# (28, 43) for covnext backbone, (26, 38) for resnet + +# rechunked data +data_path = "/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr" + +# updated tracking data +tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" +source_channel = ["RFP", "Phase3D"] +z_range = (26, 38) +batch_size = 32 + +# normalizations = [ +# # Normalization for Phase3D using mean and std +# NormalizeSampled( +# keys=["Phase3D"], +# level="fov_statistics", +# subtrahend="mean", +# divisor="std", +# ), +# # Normalization for RFP using median and IQR +# NormalizeSampled( +# keys=["RFP"], +# level="fov_statistics", +# subtrahend="median", +# divisor="iqr", +# ), +# ] + +# Updated normalizations +normalizations = [ + NormalizeIntensityd( + keys=["Phase3D"], + subtrahend=None, + divisor=None, + nonzero=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False + ), + ScaleIntensityRangePercentilesd( + keys=["RFP"], + lower=50, + upper=99, + b_min=0.0, + b_max=1.0, + clip=False, + relative=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False + ), +] + +augmentations = [ + # Apply rotations and scaling together to both channels + RandAffined( + keys=source_channel, + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, 0.2, 0.2], + prob=0.8, + padding_mode="zeros", + shear_range=[0.0, 0.01, 0.01], + ), + # Apply contrast adjustment separately for each channel + RandAdjustContrastd(keys=["RFP"], prob=0.5, gamma=(0.7, 1.3)), # Broader range for RFP + RandAdjustContrastd(keys=["Phase3D"], prob=0.5, gamma=(0.8, 1.2)), # Moderate range for Phase + # Apply intensity scaling separately for each channel + RandScaleIntensityd(keys=["RFP"], factors=0.7, prob=0.5), # Broader scaling for RFP + RandScaleIntensityd(keys=["Phase3D"], factors=0.5, prob=0.5), # Moderate scaling for Phase + # Apply Gaussian smoothing to both channels together + RandGaussianSmoothd( + keys=source_channel, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + prob=0.5, + ), + # Apply Gaussian noise separately for each channel + RandGaussianNoised(keys=["RFP"], prob=0.5, mean=0.0, std=0.5), # Higher noise for RFP + RandGaussianNoised(keys=["Phase3D"], prob=0.5, mean=0.0, std=0.2), # Moderate noise for Phase + ] + +torch.set_float32_matmul_precision("medium") + +# contra_model = ContrastiveEncoder(backbone="resnet50") +# print(contra_model) + +# model_graph = torchview.draw_graph( +# contra_model, +# torch.randn(1, 1, 15, 200, 200), +# depth=3, +# device="cpu", +# ) +# model_graph.visual_graph + +# contrastive_module = ContrastiveModule() +# print(contrastive_module.encoder) + +# model_graph = torchview.draw_graph( +# contrastive_module.encoder, +# torch.randn(1, 1, 15, 200, 200), +# depth=3, +# device="cpu", +# ) +# model_graph.visual_graph + + +# %% Define the main function for training +def main(hparams): + # Seed for reproducibility + # seed_everything(42, workers=True) + + num_gpus = torch.cuda.device_count() + print(f"Number of GPUs available: {num_gpus}") + + print("Starting data module..") + # Initialize the data module + data_module = TripletDataModule( + data_path=data_path, + tracks_path=tracks_path, + source_channel=source_channel, + z_range=z_range, + initial_yx_patch_size=(512, 512), + final_yx_patch_size=(224, 224), + batch_size=batch_size, + num_workers=hparams.num_workers, + normalizations=normalizations, + augmentations=augmentations, + ) + + print("data module set up!") + + # Setup the data module for training, val and testing + data_module.setup(stage="fit") + + print( + f"Total dataset size: {len(data_module.train_dataset) + len(data_module.val_dataset)}" + ) + print(f"Training dataset size: {len(data_module.train_dataset)}") + print(f"Validation dataset size: {len(data_module.val_dataset)}") + + # Initialize the model + model = ContrastiveModule( + backbone=hparams.backbone, + loss_function=torch.nn.TripletMarginLoss(), + margin=hparams.margin, + lr=hparams.lr, + schedule=hparams.schedule, + log_batches_per_epoch=2, # total 6 images per epoch are logged + log_samples_per_batch=3, + in_channels=len(source_channel), + in_stack_depth=z_range[1] - z_range[0], + stem_kernel_size=(5, 3, 3), + embedding_len=hparams.embedding_len, + ) + + # set for each run to avoid overwritting! + #custom_folder_name = "test" + checkpoint_callback = ModelCheckpoint( + #dirpath=os.path.join(model_dir, custom_folder_name), + filename="contrastive_model-test-{epoch:02d}-{val_loss:.2f}", + save_top_k=3, + mode="min", + monitor="val/loss_epoch", + ) + + trainer = Trainer( + max_epochs=hparams.max_epochs, + # limit_train_batches=2, + # limit_val_batches=2, + callbacks=[checkpoint_callback], + logger=TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/test_tb", + log_graph=True, + default_hp_metric=True, + ), + accelerator=hparams.accelerator, + devices=hparams.devices, + num_nodes=hparams.num_nodes, + strategy=DDPStrategy(), + log_every_n_steps=hparams.log_every_n_steps, + num_sanity_val_steps=0, + ) + + # train_loader = data_module.train_dataloader() + # example_batch = next(iter(train_loader)) + # example_input = example_batch[0] + + # wandb_logger.watch(model, log="all", log_graph=(example_input,)) + + # Fetches batches from the training dataloader, + # Calls the training_step method on the model for each batch + # Aggregates the losses and performs optimization steps + trainer.fit(model, datamodule=data_module) + + # # Validate the model + trainer.validate(model, datamodule=data_module) + + # # Test the model + # trainer.test(model, datamodule=data_module) + +# Argument parser for command-line options +# to-do: need to clean up to always use the same args +parser = ArgumentParser() +parser.add_argument("--backbone", type=str, default="resnet50") +parser.add_argument("--margin", type=float, default=0.5) +parser.add_argument("--lr", type=float, default=1e-3) +parser.add_argument("--schedule", type=str, default="Constant") +parser.add_argument("--log_steps_per_epoch", type=int, default=10) +parser.add_argument("--embedding_len", type=int, default=256) +parser.add_argument("--max_epochs", type=int, default=100) +parser.add_argument("--accelerator", type=str, default="gpu") +parser.add_argument("--devices", type=int, default=1) # 4 GPUs +parser.add_argument("--num_nodes", type=int, default=1) +parser.add_argument("--log_every_n_steps", type=int, default=1) +parser.add_argument("--num_workers", type=int, default=15) +args = parser.parse_args() + +main(args) + diff --git a/viscy/_log_images.py b/viscy/_log_images.py new file mode 100644 index 00000000..cc6b0fe4 --- /dev/null +++ b/viscy/_log_images.py @@ -0,0 +1,38 @@ +from typing import Sequence + +import numpy as np +from matplotlib.pyplot import get_cmap +from skimage.exposure import rescale_intensity +from torch import Tensor + + +def detach_sample(imgs: Sequence[Tensor], log_samples_per_batch: int): + num_samples = min(imgs[0].shape[0], log_samples_per_batch) + samples = [] + for i in range(num_samples): + patches = [] + for img in imgs: + patch = img[i].detach().cpu().numpy() + patch = np.squeeze(patch[:, patch.shape[1] // 2]) + patches.append(patch) + samples.append(patches) + return samples + + +def render_images(imgs: Sequence[Sequence[np.ndarray]], cmaps: list[str] = []): + images_grid = [] + for sample_images in imgs: + images_row = [] + for i, image in enumerate(sample_images): + if cmaps: + cm_name = cmaps[i] + else: + cm_name = "gray" if i == 0 else "inferno" + if image.ndim == 2: + image = image[np.newaxis] + for channel in image: + channel = rescale_intensity(channel, out_range=(0, 1)) + render = get_cmap(cm_name)(channel, bytes=True)[..., :3] + images_row.append(render) + images_grid.append(np.concatenate(images_row, axis=1)) + return np.concatenate(images_grid, axis=0) diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py new file mode 100644 index 00000000..af8fb517 --- /dev/null +++ b/viscy/cli/cli.py @@ -0,0 +1,64 @@ +import logging +import os +import sys +from datetime import datetime + +import torch +from jsonargparse import lazy_instance +from lightning.pytorch import LightningModule +from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule +from viscy.translation.engine import VSUNet +from viscy.translation.trainer import VSTrainer + + +class VSLightningCLI(LightningCLI): + """Extending lightning CLI arguments and defualts.""" + + @staticmethod + def subcommands() -> dict[str, set[str]]: + subcommands = LightningCLI.subcommands() + subcommands["preprocess"] = {"model", "dataloaders", "datamodule"} + subcommands["export"] = {"model", "dataloaders", "datamodule"} + return subcommands + + def add_arguments_to_parser(self, parser): + if "preprocess" not in sys.argv: + parser.link_arguments("data.yx_patch_size", "model.example_input_yx_shape") + parser.link_arguments("model.architecture", "data.architecture") + parser.set_defaults( + { + "trainer.logger": lazy_instance( + TensorBoardLogger, + save_dir="", + version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), + log_graph=True, + ) + } + ) + + +def main(): + """Main Lightning CLI entry point.""" + log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO) + logging.getLogger("lightning.pytorch").setLevel(log_level) + torch.set_float32_matmul_precision("high") + model_class = VSUNet + datamodule_class = HCSDataModule + seed = True + if "preprocess" in sys.argv: + seed = False + model_class = LightningModule + datamodule_class = None + _ = VSLightningCLI( + model_class=model_class, + datamodule_class=datamodule_class, + trainer_class=VSTrainer, + seed_everything_default=seed, + ) + + +if __name__ == "__main__": + main() diff --git a/viscy/cli/contrastive_triplet.py b/viscy/cli/contrastive_triplet.py new file mode 100644 index 00000000..f890b59c --- /dev/null +++ b/viscy/cli/contrastive_triplet.py @@ -0,0 +1,41 @@ +import logging +import os +from datetime import datetime + +import torch +from jsonargparse import lazy_instance +from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.loggers import TensorBoardLogger + +from viscy.data.triplet import TripletDataModule +from viscy.representation.engine import ContrastiveModule + + +class ContrastiveLightningCLI(LightningCLI): + """Lightning CLI with default logger.""" + + def add_arguments_to_parser(self, parser): + parser.set_defaults( + { + "trainer.logger": lazy_instance( + TensorBoardLogger, + save_dir="", + version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), + log_graph=True, + ) + } + ) + + +def main(): + """Main Lightning CLI entry point.""" + log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO) + logging.getLogger("lightning.pytorch").setLevel(log_level) + torch.set_float32_matmul_precision("high") + _ = ContrastiveLightningCLI( + model_class=ContrastiveModule, datamodule_class=TripletDataModule + ) + + +if __name__ == "__main__": + main() diff --git a/viscy/cli/metrics_script.py b/viscy/cli/metrics_script.py new file mode 100644 index 00000000..b4739534 --- /dev/null +++ b/viscy/cli/metrics_script.py @@ -0,0 +1,132 @@ +# %% script to generate your ground truth directory for viscy prediction evaluation +# After inference, the predictions generated are stored as zarr store. +# Evaluation metrics can be computed by comparison of prediction to +# human proof read ground truth. + +import argparse +import os + +import imageio as iio +import iohub.ngff as ngff +import pandas as pd + +import viscy.translation.evaluation_metrics as metrics +import viscy.utils.aux_utils as aux_utils + +# %% read the below details from the config file + + +def parse_args(): + """ + Parse command line arguments + In python namespaces are implemented as dictionaries + + :return: namespace containing the arguments passed. + """ + parser = argparse.ArgumentParser() + + parser.add_argument( + "--config", + type=str, + help="path to yaml configuration file", + ) + args = parser.parse_args() + return args + + +def main(config): + """ + pick focus slice mask from pred_zarr from slice number stored on png mask name + input pred mask & corrected ground truth mask to metrics computation + store the metrics values as csv file to corresponding positions in list + Info to be stored: + 1. position no, + 2. eval metrics values + """ + + torch_config = aux_utils.read_config(config) + + pred_dir = torch_config["evaluation_metrics"]["pred_dir"] + metric_channel = torch_config["evaluation_metrics"]["metric_channel"] + PosList = torch_config["evaluation_metrics"]["PosList"] + z_list = torch_config["evaluation_metrics"]["z_list"] + metrics_list = torch_config["evaluation_metrics"]["metrics"] + ground_truth_chans = torch_config["data"]["target_channel"] + ground_truth_subdir = "ground_truth" + + d_pod = [ + "OD_true_positives", + "OD_false_positives", + "OD_false_negatives", + "OD_precision", + "OD_recall", + "OD_f1_score", + ] + + metric_map = { + "ssim": metrics.ssim_metric, + "corr": metrics.corr_metric, + "r2": metrics.r2_metric, + "mse": metrics.mse_metric, + "mae": metrics.mae_metric, + "dice": metrics.dice_metric, + "IoU": metrics.IOU_metric, + "VI": metrics.VOI_metric, + "POD": metrics.POD_metric, + } + + path_split_head_tail = os.path.split(pred_dir) + target_zarr_dir = path_split_head_tail[0] + pred_plate = ngff.open_ome_zarr( + store_path=os.path.join(target_zarr_dir, metric_channel + "_pred.zarr"), + mode="r+", + ) + chan_names = pred_plate.channel_names + metric_chan_mask = metric_channel + "_cp_mask" + ground_truth_dir = os.path.join(target_zarr_dir, ground_truth_subdir) + + col_val = metrics_list[:] + if "POD" in col_val: + col_val.remove("POD") + for i in range(len(d_pod)): + col_val.insert(i + metrics_list.index("POD"), d_pod[i]) + df_metrics = pd.DataFrame(columns=col_val, index=PosList) + + for position, pos_data in pred_plate.positions(): + pos = int(position.split("/")[-1]) + + if pos in PosList: + idx = PosList.index(pos) + raw_data = pos_data.data + pred_mask = raw_data[0, chan_names.index(metric_chan_mask)] + + z_slice_no = z_list[idx] + gt_mask_save_name = ( + ground_truth_chans[0] + + "_p" + + str(format(pos, "03d")) + + "_z" + + str(z_slice_no) + + "_cp_mask.png" + ) + + gt_mask = iio.imread(os.path.join(ground_truth_dir, gt_mask_save_name)) + + pos_metric_list = [] + for metric_name in metrics_list: + metric_fn = metric_map[metric_name] + cur_metric_list = metric_fn( + gt_mask, + pred_mask[0], + ) + pos_metric_list = pos_metric_list + cur_metric_list + + df_metrics.loc[pos] = pos_metric_list + + csv_filename = os.path.join(ground_truth_dir, "GT_metrics.csv") + df_metrics.to_csv(csv_filename) + + +if __name__ == "__main__": + args = parse_args() + main(args.config) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 9ee37185..54837bb3 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -3,6 +3,7 @@ import os import re import tempfile +import warnings from pathlib import Path from typing import Callable, Literal, Sequence @@ -26,6 +27,11 @@ from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample +warnings.filterwarnings( + "ignore", + category=UserWarning, + message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", +) _logger = logging.getLogger("lightning.pytorch") diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index b816b28c..7f564a9b 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -367,7 +367,6 @@ def _align_tracks_tables_with_positions( next((self.tracks_path / fov_name).glob("*.csv")) ).astype(int) tracks_tables.append(tracks_df) - return positions, tracks_tables @property diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index 8edeb862..dd4186fc 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -1,3 +1,4 @@ +import warnings from typing import Literal import timm @@ -6,6 +7,8 @@ from viscy.unet.networks.unext2 import StemDepthtoChannels +warnings.filterwarnings("ignore", category=UserWarning, module="torch") + class ContrastiveEncoder(nn.Module): """ diff --git a/viscy/representation/evalutation/clustering.py b/viscy/representation/evalutation/clustering.py index d87d3968..b7065339 100644 --- a/viscy/representation/evalutation/clustering.py +++ b/viscy/representation/evalutation/clustering.py @@ -1,14 +1,84 @@ """Methods for evaluating clustering performance.""" +import numpy as np from sklearn.cluster import DBSCAN from sklearn.metrics import ( accuracy_score, adjusted_rand_score, normalized_mutual_info_score, ) +from sklearn.mixture import GaussianMixture from sklearn.neighbors import KNeighborsClassifier +class GMMClustering: + def __init__(self, features_data, n_clusters_range=np.arange(2, 10)): + self.features_data = features_data + self.n_clusters_range = n_clusters_range + self.best_n_clusters = None + self.best_gmm = None + self.aic_scores = None + self.bic_scores = None + + def find_best_n_clusters(self): + """Find the best number of clusters using AIC/BIC scores.""" + aic_scores = [] + bic_scores = [] + for n in self.n_clusters_range: + gmm = GaussianMixture(n_components=n, random_state=42) + gmm.fit(self.features_data) + aic_scores.append(gmm.aic(self.features_data)) + bic_scores.append(gmm.bic(self.features_data)) + + self.aic_scores = aic_scores + self.bic_scores = bic_scores + + return aic_scores, bic_scores + + def fit_best_model(self, criterion="bic", n_clusters=None): + """ + Fit the best GMM model based on AIC or BIC scores, or a user-specified number of clusters. + + Parameters: + - criterion: 'aic' or 'bic' to select the best model based on the chosen criterion. + - n_clusters: Specify a fixed number of clusters (overrides the 'best' search). + """ + # Case 1: If the user provides n_clusters, use it directly + if n_clusters is not None: + self.best_n_clusters = n_clusters + + # Case 2: If no n_clusters is provided but find_best_n_clusters was run, use stored AIC/BIC results + elif self.aic_scores is not None and self.bic_scores is not None: + if criterion == "bic": + self.best_n_clusters = self.n_clusters_range[np.argmin(self.bic_scores)] + else: + self.best_n_clusters = self.n_clusters_range[np.argmin(self.aic_scores)] + + # Case 3: If find_best_n_clusters hasn't been run, compute AIC/BIC scores now + else: + aic_scores, bic_scores = self.find_best_n_clusters() + if criterion == "bic": + self.best_n_clusters = self.n_clusters_range[np.argmin(bic_scores)] + else: + self.best_n_clusters = self.n_clusters_range[np.argmin(aic_scores)] + + self.best_gmm = GaussianMixture( + n_components=self.best_n_clusters, random_state=42 + ) + self.best_gmm.fit(self.features_data) + + return self.best_gmm + + def predict_clusters(self): + """Run prediction on the fitted best GMM model.""" + if self.best_gmm is None: + raise Exception( + "No GMM model is fitted yet. Please run fit_best_model() first." + ) + cluster_labels = self.best_gmm.predict(self.features_data) + return cluster_labels + + def knn_accuracy(embeddings, annotations, k=5): """ Evaluate the k-NN classification accuracy. diff --git a/viscy/representation/evalutation/dimensionality_reduction.py b/viscy/representation/evalutation/dimensionality_reduction.py index 0a906bf4..130c8634 100644 --- a/viscy/representation/evalutation/dimensionality_reduction.py +++ b/viscy/representation/evalutation/dimensionality_reduction.py @@ -8,7 +8,7 @@ from xarray import Dataset -def compute_pca(embedding_dataset, n_components=None, normalize_features=True): +def compute_pca(embedding_dataset, n_components=None, normalize_features=False): features = embedding_dataset["features"] projections = embedding_dataset["projections"] @@ -19,31 +19,21 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True): scaled_projections = projections.values scaled_features = features.values - # Compute PCA with specified number of components PCA_features = PCA(n_components=n_components, random_state=42) PCA_projection = PCA(n_components=n_components, random_state=42) pc_features = PCA_features.fit_transform(scaled_features) pc_projection = PCA_projection.fit_transform(scaled_projections) - # Prepare DataFrame with id and PCA coordinates - pca_df = pd.DataFrame( - { - "id": embedding_dataset["id"].values, - "fov_name": embedding_dataset["fov_name"].values, - "PCA1": pc_features[:, 0], - "PCA2": pc_features[:, 1], - "PCA3": pc_features[:, 2], - "PCA4": pc_features[:, 3], - "PCA5": pc_features[:, 4], - "PCA6": pc_features[:, 5], - "PCA1_proj": pc_projection[:, 0], - "PCA2_proj": pc_projection[:, 1], - "PCA3_proj": pc_projection[:, 2], - "PCA4_proj": pc_projection[:, 3], - "PCA5_proj": pc_projection[:, 4], - "PCA6_proj": pc_projection[:, 5], - } - ) + pca_df_dict = { + "id": embedding_dataset["id"].values, + "fov_name": embedding_dataset["fov_name"].values, + } + + for i in range(n_components): + pca_df_dict[f"PCA{i + 1}"] = pc_features[:, i] + pca_df_dict[f"PCA{i + 1}_proj"] = pc_projection[:, i] + + pca_df = pd.DataFrame(pca_df_dict) return PCA_features, PCA_projection, pca_df diff --git a/viscy/representation/lca.py b/viscy/representation/lca.py new file mode 100644 index 00000000..7c521619 --- /dev/null +++ b/viscy/representation/lca.py @@ -0,0 +1,190 @@ +"""Linear probing of trained encoder based on cell state labels.""" + +from typing import Mapping + +import pandas as pd +import torch +import torch.nn as nn +from captum.attr import IntegratedGradients, Occlusion +from numpy.typing import NDArray +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import classification_report +from sklearn.preprocessing import StandardScaler +from torch import Tensor +from xarray import DataArray + +from viscy.representation.contrastive import ContrastiveEncoder + + +def fit_logistic_regression( + features: DataArray, + annotations: pd.Series, + train_fovs: list[str], + remove_background_class: bool = True, + scale_features: bool = False, + class_weight: Mapping | str | None = "balanced", + random_state: int | None = None, + solver="liblinear", +) -> tuple[ + LogisticRegression, + tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]], +]: + """Fit a binary logistic regression classifier. + + Parameters + ---------- + features : DataArray + Xarray of features. + annotations : pd.Series + Categorical class annotations with label values starting from 0. + Must have 3 classes (when remove background is True) or 2 classes. + train_fovs : list[str] + List of FOVs to use for training. The rest will be used for testing. + remove_background_class : bool, optional + Remove background class (0), by default True + scale_features : bool, optional + Scale features, by default False + class_weight : Mapping | str | None, optional + Class weight for balancing, by default "balanced" + random_state : int | None, optional + Random state or seed, by default None + solver : str, optional + Solver for the regression problem, by default "liblinear" + + Returns + ------- + tuple[LogisticRegression, tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]]] + Trained classifier and data split [[X_train, y_train], [X_test, y_test]]. + """ + fov_selection = features["fov_name"].isin(train_fovs) + train_selection = fov_selection + test_selection = ~fov_selection + annotations = annotations.cat.codes.values.copy() + if remove_background_class: + label_selection = annotations != 0 + train_selection &= label_selection + test_selection &= label_selection + annotations -= 1 + train_features = features.values[train_selection] + test_features = features.values[test_selection] + if scale_features: + scaler = StandardScaler() + train_features = scaler.fit_transform(train_features) + test_features = scaler.fit_transform(test_features) + train_annotations = annotations[train_selection] + test_annotations = annotations[test_selection] + logistic_regression = LogisticRegression( + class_weight=class_weight, + random_state=random_state, + solver=solver, + ) + logistic_regression.fit(train_features, train_annotations) + prediction = logistic_regression.predict(test_features) + print("Trained logistic regression classifier.") + print( + "Training set accuracy:\n" + + classification_report( + logistic_regression.predict(train_features), train_annotations, digits=3 + ) + ) + print( + "Test set accuracy:\n" + + classification_report(prediction, test_annotations, digits=3) + ) + return logistic_regression, ( + (train_features, train_annotations), + (test_features, test_annotations), + ) + + +def linear_from_binary_logistic_regression( + logistic_regression: LogisticRegression, +) -> nn.Linear: + """Convert a binary logistic regression model to a ``torch.nn.Linear`` layer. + + Parameters + ---------- + logistic_regression : LogisticRegression + Trained logistic regression model. + + Returns + ------- + nn.Linear + Converted linear model. + """ + weights = torch.from_numpy(logistic_regression.coef_).float() + bias = torch.from_numpy(logistic_regression.intercept_).float() + model = nn.Linear(in_features=weights.shape[1], out_features=1) + model.weight.data = weights + model.bias.data = bias + model.eval() + return model + + +class AssembledClassifier(torch.nn.Module): + """Assemble a contrastive encoder with a linear classifier. + + Parameters + ---------- + backbone : ContrastiveEncoder + Encoder backbone. + classifier : nn.Linear + Classifier head. + """ + + def __init__(self, backbone: ContrastiveEncoder, classifier: nn.Linear) -> None: + super().__init__() + self.backbone = backbone + self.classifier = classifier + + @staticmethod + def scale_features(x: Tensor) -> Tensor: + m = x.mean(-2, keepdim=True) + s = x.std(-2, unbiased=False, keepdim=True) + return (x - m) / s + + def forward(self, x: Tensor, scale_features: bool = False) -> Tensor: + x = self.backbone.stem(x) + x = self.backbone.encoder(x) + if scale_features: + x = self.scale_features(x) + x = self.classifier(x) + return x + + def attribute_integrated_gradients(self, img: Tensor, **kwargs) -> Tensor: + """Compute integrated gradients for a binary classification task. + + Parameters + ---------- + img : Tensor + input image + **kwargs : Any + Keyword arguments for ``IntegratedGradients()``. + + Returns + ------- + attribution : Tensor + Integrated gradients attribution map. + """ + self.zero_grad() + ig = IntegratedGradients(self, **kwargs) + attribution = ig.attribute(img) + return attribution + + def attribute_occlusion(self, img: Tensor, **kwargs) -> Tensor: + """Compute occlusion-based attribution for a binary classification task. + + Parameters + ---------- + img : Tensor + input image + **kwargs : Any + Keyword arguments for the ``Occlusion.attribute()``. + + Returns + ------- + attribution : Tensor + Occlusion attribution map. + """ + oc = Occlusion(self) + return oc.attribute(img, **kwargs) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index aa7ac24b..4ba73491 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -36,6 +36,10 @@ except ImportError: CellposeModel = None +try: + import wandb +except ImportError: + wandb = None _UNET_ARCHITECTURE = { "2D": Unet2d, diff --git a/viscy/translation/trainer.py b/viscy/translation/trainer.py new file mode 100644 index 00000000..7670df19 --- /dev/null +++ b/viscy/translation/trainer.py @@ -0,0 +1,101 @@ +import logging +from pathlib import Path +from typing import Literal, Sequence, Union + +import torch +from iohub import open_ome_zarr +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized +from torch.onnx import OperatorExportTypes + +from viscy.utils.meta_utils import generate_normalization_metadata + + +class VSTrainer(Trainer): + def preprocess( + self, + data_path: Path, + channel_names: Union[list[str], Literal[-1]] = -1, + num_workers: int = 1, + block_size: int = 32, + model: LightningModule = None, + datamodule: LightningDataModule = None, + dataloaders: Sequence = None, + ): + """Compute dataset statistics before training or testing for normalization. + + :param Path data_path: Path to the HCS OME-Zarr dataset + :param Union[list[str], Literal[ channel_names: channel names, + defaults to -1 (all channels) + :param int num_workers: number of workers, defaults to 1 + :param int block_size: sampling block size, defaults to 32 + :param LightningModule model: place holder for model, ignored + :param LightningDataModule datamodule: place holder for datamodule, ignored + :param Sequence dataloaders: place holder for dataloaders, ignored + """ + if model or dataloaders or datamodule: + logging.debug("Ignoring model and data configs during preprocessing.") + with open_ome_zarr(data_path, layout="hcs", mode="r") as dataset: + channel_indices = ( + [dataset.channel_names.index(c) for c in channel_names] + if channel_names != -1 + else channel_names + ) + generate_normalization_metadata( + zarr_dir=data_path, + num_workers=num_workers, + channel_ids=channel_indices, + grid_spacing=block_size, + ) + + def export( + self, + model: LightningModule, + export_path: str, + ckpt_path: str, + format="onnx", + datamodule: LightningDataModule = None, + dataloaders: Sequence = None, + ): + """Export the model for deployment (currently only ONNX is supported). + + :param LightningModule model: module to export + :param str export_path: output file name + :param str ckpt_path: model checkpoint + :param str format: format (currently only ONNX is supported), defaults to "onnx" + :param LightningDataModule datamodule: placeholder for datamodule, + defaults to None + :param Sequence dataloaders: placeholder for dataloaders, defaults to None + """ + if dataloaders or datamodule: + logging.debug("Ignoring datamodule and dataloaders during export.") + if not format.lower() == "onnx": + raise NotImplementedError(f"Export format '{format}'") + model = _maybe_unwrap_optimized(model) + self.strategy._lightning_module = model + model.load_state_dict(torch.load(ckpt_path)["state_dict"]) + model.eval() + model.to_onnx( + export_path, + input_sample=model.example_input_array, + export_params=True, + opset_version=18, + operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": { + 0: "batch_size", + 1: "channels", + 3: "num_rows", + 4: "num_cols", + }, + "output": { + 0: "batch_size", + 1: "channels", + 3: "num_rows", + 4: "num_cols", + }, + }, + ) + logging.info(f"ONNX exported at {export_path}")