From 0f85c4ebbd5b16709a9262a3cb1309254704cafe Mon Sep 17 00:00:00 2001 From: Alishba Imran <44557946+alishbaimran@users.noreply.github.com> Date: Thu, 8 Aug 2024 22:11:05 -0700 Subject: [PATCH] Updated code (contrastive learning) (#130) * updated prediction code to test specific cells * updated kernel and stride * updated kernel and stride * new training script and predict takes in output path as parameter * passing ci requirements * fixed error for output path * fixed error for output path * fixed formatting * format resolved * docstrings for methods that parse tracks --------- Co-authored-by: Shalin Mehta --- applications/contrastive_phenotyping/pca.py | 393 ------------------ .../contrastive_phenotyping/predict.py | 114 +++-- .../training_script.py | 99 +---- viscy/data/triplet.py | 50 +++ viscy/light/engine.py | 29 +- viscy/representation/contrastive.py | 2 +- viscy/unet/networks/unext2.py | 12 +- 7 files changed, 168 insertions(+), 531 deletions(-) delete mode 100644 applications/contrastive_phenotyping/pca.py diff --git a/applications/contrastive_phenotyping/pca.py b/applications/contrastive_phenotyping/pca.py deleted file mode 100644 index b4395a2a..00000000 --- a/applications/contrastive_phenotyping/pca.py +++ /dev/null @@ -1,393 +0,0 @@ -# %% -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import plotly.express as px -import plotly.io as pio -import seaborn as sns -from iohub import open_ome_zarr -from scipy.stats import spearmanr -from sklearn.decomposition import PCA - -# Set Plotly default renderer for VSCode -pio.renderers.default = "vscode" - -# Load predicted features and projections -predicted_features = np.load("updated_epoch66_predicted_features.npy") -predicted_projections = np.load("updated_epoch66_predicted_projections.npy") - -print(predicted_features.shape) -print(predicted_projections.shape) - -# Load the CSV file -csv_path = "epoch66_processed_order.csv" -df = pd.read_csv(csv_path) - -# Load ground truth masks -base_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/all_annotations_patch.zarr" -ds = open_ome_zarr(base_path, layout="hcs", mode="r") - -background_mask_index = ds.channel_names.index("background_mask") -uninfected_mask_index = ds.channel_names.index("uninfected_mask") -infected_mask_index = ds.channel_names.index("infected_mask") - - -# %% -# Assuming all masks have the same shape -# TO-DO: -# tie the image with projected embeddings -# test with ER - -# Initialize arrays to store the sums -num_cells = len(df) -background_sums = np.zeros(num_cells) -uninfected_sums = np.zeros(num_cells) -infected_sums = np.zeros(num_cells) - -# %% -for idx, row in df.iterrows(): - position_key = f"{row['Row']}/{row['Column']}/fov{row['FOV']}cell{row['Cell ID']}/0" - zarr_array = ds[position_key] - t = row["Timestep"] - - # Load a single z-slice, for example the first one - background_mask = zarr_array[t, background_mask_index, 0, :, :] - uninfected_mask = zarr_array[t, uninfected_mask_index, 0, :, :] - infected_mask = zarr_array[t, infected_mask_index, 0, :, :] - - # Sum values across each mask - background_sums[idx] = np.sum(background_mask) - uninfected_sums[idx] = np.sum(uninfected_mask) - infected_sums[idx] = np.sum(infected_mask) - -# %% -# Normalize the sums -max_background = np.max(background_sums) -max_uninfected = np.max(uninfected_sums) -max_infected = np.max(infected_sums) - -background_sums /= max_background -uninfected_sums /= max_uninfected -infected_sums /= max_infected - -# %% -# Combine the sums into a single array and apply softmax -combined_sums = np.stack([background_sums, uninfected_sums, infected_sums], axis=1) -softmax_sums = np.exp(combined_sums) / np.sum( - np.exp(combined_sums), axis=1, keepdims=True -) - -# Separate the softmax values -background_softmax = softmax_sums[:, 0] -uninfected_softmax = softmax_sums[:, 1] -infected_softmax = softmax_sums[:, 2] - -# %% -# Check for NaN values in the softmax results -print("NaN values in combined_sums:", np.isnan(combined_sums).any()) -print("NaN values in softmax_sums:", np.isnan(softmax_sums).any()) -print("Infinite values in combined_sums:", np.isinf(combined_sums).any()) -print("Infinite values in softmax_sums:", np.isinf(softmax_sums).any()) - -# %% -# Check for NaN values in the softmax results -print("NaN values in background_softmax:", np.isnan(background_softmax).any()) -print("NaN values in uninfected_softmax:", np.isnan(uninfected_softmax).any()) -print("NaN values in infected_softmax:", np.isnan(infected_softmax).any()) - -# Check for zero variance in the softmax results -print("Variance in background_softmax:", np.var(background_softmax)) -print("Variance in uninfected_softmax:", np.var(uninfected_softmax)) -print("Variance in infected_softmax:", np.var(infected_softmax)) - -# %% -# Determine the number of principal components to keep -# reshaped_features = predicted_features.reshape(predicted_features.shape[0], -1) - -pca = PCA() -pca.fit(predicted_features) -explained_variance_ratio = np.cumsum(pca.explained_variance_ratio_) - -# %% -# Plot the explained variance ratio -plt.figure(figsize=(12, 6)) -plt.plot( - range(1, len(explained_variance_ratio) + 1), - explained_variance_ratio, - marker="o", - linestyle="--", -) -plt.xlabel("Number of Components") -plt.ylabel("Cumulative Explained Variance") -plt.title("Explained Variance by Number of Components") -plt.grid(True) -plt.show() - -# %% -# Choose the number of components that explain a significant amount of variance (e.g., 90%) -n_components = np.argmax(explained_variance_ratio >= 0.90) + 1 -print(f"Number of components selected: {n_components}") - -# %% -# Perform PCA with the selected number of components -pca = PCA(n_components=2) -reduced_projections = pca.fit_transform(predicted_projections) - -# %% -df["PC1"] = reduced_projections[:, 0] -df["PC2"] = reduced_projections[:, 1] -df["Infected Softmax Score"] = infected_softmax - -print(df.head()) - - -# %% -# Calculate rank correlations -correlations = [] - -for i in range(reduced_projections.shape[1]): - pc = reduced_projections[:, i] - - background_corr, _ = spearmanr(pc, background_softmax) - uninfected_corr, _ = spearmanr(pc, uninfected_softmax) - infected_corr, _ = spearmanr(pc, infected_softmax) - - correlations.append( - { - "PC": i + 1, - "Background Correlation": background_corr, - "Uninfected Correlation": uninfected_corr, - "Infected Correlation": infected_corr, - } - ) - -correlation_df = pd.DataFrame(correlations) -print(correlation_df) - -# %% -# Create an interactive scatter plot -fig = px.scatter( - df, - x="PC1", - y="PC2", - color="Infected Softmax Score", - hover_data=["Row", "Column", "FOV", "Cell ID", "Timestep"], -) - -# Show the plot -fig.show() - -# %% -# Function to get cell data and plot the images - -rfp_index = ds.channel_names.index("RFP") -phase3d_index = ds.channel_names.index("Phase3D") - - -def get_cell_data_and_plot(row, col, fov, cell_id, timestep): - position_key = f"{row}/{col}/fov{fov}cell{cell_id}/0" - zarr_array = ds[position_key] - - phase_img = zarr_array[timestep, phase3d_index, 32, :, :] - rfp_img = zarr_array[timestep, rfp_index, 32, :, :] - - fig, axes = plt.subplots(1, 2, figsize=(12, 6)) - axes[0].imshow(phase_img, cmap="gray") - axes[0].set_title("Phase3D Image") - axes[1].imshow(rfp_img, cmap="gray") - axes[1].set_title("RFP Image") - plt.show() - - return phase_img, rfp_img - - -# example: get data for a specific cell and plot -row = "B" -col = "3" -fov = 5 -cell_id = 14 -timestep = 4 - -phase_img, rfp_img = get_cell_data_and_plot(row, col, fov, cell_id, timestep) - -# %% -# Visualize the PCA results with cells colored based on their infected softmax scores -plt.figure(figsize=(12, 6)) -sc = plt.scatter( - reduced_projections[:, 0], - reduced_projections[:, 1], - c=infected_softmax, - cmap="viridis", - label="Cells", -) -plt.colorbar(sc, label="Infected Softmax Score") -plt.xlabel("Principal Component 1") -plt.ylabel("Principal Component 2") -plt.title("PCA of Predicted Projections (Colored by Infected Softmax Score)") -plt.legend() -plt.show() - -# %% -# PC1 vs PC3, PC1 vs PC4, etc. -n_components = 5 -if n_components > 2: - for i in range(2, n_components): - plt.figure(figsize=(12, 6)) - sc = plt.scatter( - reduced_projections[:, 0], - reduced_projections[:, i], - c=infected_softmax, - cmap="viridis", - label="Cells", - ) - plt.colorbar(sc, label="Infected Softmax Score") - plt.xlabel("Principal Component 1") - plt.ylabel(f"Principal Component {i + 1}") - plt.title( - f"PCA of Predicted Projections: PC1 vs PC{i + 1} (Colored by Infected Softmax Score)" - ) - plt.legend() - plt.show() - -# %% - -correlations = np.zeros(n_components) -for i in range(n_components): - pc = reduced_projections[:, i] - correlation, _ = spearmanr(pc, infected_softmax) - correlations[i] = correlation - - -# %% -# Visualize the PCA results with cells colored based on their principal component values -for i in range(n_components): - plt.figure(figsize=(12, 6)) - sc = plt.scatter( - reduced_projections[:, 0], - reduced_projections[:, 1], - c=reduced_projections[:, i], - cmap="viridis", - label=f"PC{i+1} Correlation: {correlations[i]:.2f}", - ) - plt.colorbar(sc, label="Principal Component Value") - plt.xlabel("Principal Component 1") - plt.ylabel("Principal Component 2") - plt.title(f"PCA of Predicted Projections (Colored by PC{i+1} Values)") - plt.legend() - plt.show() - -# %% -# Additional components if n_components > 2 -if n_components > 2: - for i in range(2, n_components): - plt.figure(figsize=(12, 6)) - sc = plt.scatter( - reduced_projections[:, 0], - reduced_projections[:, i], - c=reduced_projections[:, i], - cmap="viridis", - label=f"PC{i+1} Correlation: {correlations[i]:.2f}", - ) - plt.colorbar(sc, label="Principal Component Value") - plt.xlabel("Principal Component 1") - plt.ylabel(f"Principal Component {i + 1}") - plt.title( - f"PCA of Predicted Projections: PC1 vs PC{i + 1} (Colored by PC{i+1} Values)" - ) - plt.legend() - plt.show() - -# %% - -# Visualize the PCA results with color based on correlation with infected softmax score -for i in range(reduced_projections.shape[1]): - pc = reduced_projections[:, i] - infected_corr, _ = spearmanr(pc, infected_softmax) - - plt.figure(figsize=(12, 6)) - sc = plt.scatter( - reduced_projections[:, 0], - reduced_projections[:, 1], - c=pc, - cmap="viridis", - label=f"PC{i+1} Correlation: {infected_corr:.2f}", - ) - plt.colorbar(sc, label="Principal Component Value") - plt.xlabel("Principal Component 1") - plt.ylabel("Principal Component 2") - plt.title( - f"PCA of Predicted Projections (Colored by PC{i+1} Correlation with Infected Softmax Score)" - ) - plt.legend() - plt.show() - - -# %% -# Visualize additional PCs if needed -# PC1 vs PC3, PC1 vs PC4, etc. -if n_components > 2: - for i in range(2, n_components): - plt.figure(figsize=(12, 6)) - sc = plt.scatter( - reduced_projections[:, 0], - reduced_projections[:, i], - c=infected_softmax, - cmap="viridis", - label="Cells", - ) - plt.colorbar(sc, label="Infected Softmax Score") - plt.xlabel("Principal Component 1") - plt.ylabel(f"Principal Component {i + 1}") - plt.title(f"PCA of Predicted Projections: PC1 vs PC{i + 1}") - plt.legend() - plt.show() - - -# %% -# Visualize the rank correlations -plt.figure(figsize=(12, 6)) -sns.barplot( - x="PC", - y="Background Correlation", - data=correlation_df, - color="blue", - label="Background", -) -sns.barplot( - x="PC", - y="Uninfected Correlation", - data=correlation_df, - color="green", - label="Uninfected", -) -sns.barplot( - x="PC", y="Infected Correlation", data=correlation_df, color="red", label="Infected" -) -plt.xlabel("Principal Component") -plt.ylabel("Spearman Correlation") -plt.title("Rank Correlations of Principal Components with Ground Truth Masks") -plt.legend() -plt.show() - -# %% -components = pca.components_ - -# Assuming your original features are named, you can list them -feature_names = [ - f"Feature {i}" for i in range(predicted_features.shape[1]) -] # Replace with actual feature names if available - -fig, axes = plt.subplots(n_components, 1, figsize=(12, 3 * n_components)) -for i, (component, ax) in enumerate(zip(components[:n_components], axes)): - sns.heatmap( - component.reshape(1, -1), - cmap="viridis", - ax=ax, - cbar=False, - xticklabels=feature_names, - ) - ax.set_title(f"Principal Component {i + 1}") - ax.set_xlabel("Features") - ax.set_ylabel("Component Value") -plt.tight_layout() -plt.show() diff --git a/applications/contrastive_phenotyping/predict.py b/applications/contrastive_phenotyping/predict.py index 6055eed2..a2136491 100644 --- a/applications/contrastive_phenotyping/predict.py +++ b/applications/contrastive_phenotyping/predict.py @@ -17,22 +17,31 @@ RandScaleIntensityd, RandWeightedCropd, ) +from monai.transforms import NormalizeIntensityd, ScaleIntensityRangePercentilesd +# Updated normalizations normalizations = [ - # Normalization for Phase3D using mean and std - NormalizeSampled( - keys=["Phase3D"], - level="fov_statistics", - subtrahend="mean", - divisor="std", - ), - # Normalization for RFP using median and IQR - NormalizeSampled( - keys=["RFP"], - level="fov_statistics", - subtrahend="median", - divisor="iqr", - ), + NormalizeIntensityd( + keys=["Phase3D"], + subtrahend=None, + divisor=None, + nonzero=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False + ), + ScaleIntensityRangePercentilesd( + keys=["MultiCam_GFP_mCherry_BF-Prime BSI Express"], + lower=50, + upper=99, + b_min=0.0, + b_max=1.0, + clip=False, + relative=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False + ), ] def main(hparams): @@ -40,18 +49,47 @@ def main(hparams): # /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/expanded_final_track_timesteps.csv # /hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/uninfected_cells.csv # /hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/expanded_transitioning_cells_metadata.csv - checkpoint_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/infection_score/contrastive_model-test-epoch=09-val_loss=0.00.ckpt" + checkpoint_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/test_tb/lightning_logs/copy/contrastive_model-test-epoch=24-val_loss=0.00.ckpt" # non-rechunked data - data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.1-register/registered.zarr" + # /hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/5-infection_pred/infection_score.zarr + # /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr + # /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.1-register/registered.zarr + data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/5-infection_pred/infection_score.zarr" # updated tracking data - tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/track_labels_final.zarr" + # /hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking/test_tracking_4.zarr + # /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr + tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking/test_tracking_4.zarr" - source_channel = ["RFP", "Phase3D"] - z_range = (26, 38) - batch_size = 15 # match the number of fovs being processed such that no data is left - # set to 15 for full, 12 for infected, and 8 for uninfected + # MultiCam_GFP_mCherry_BF-Prime BSI Express for June dataset + source_channel = ["MultiCam_GFP_mCherry_BF-Prime BSI Express", "Phase3D"] + z_range = (28, 43) + batch_size = 1 + + # infected cells - JUNE + # include_fov_names = ['/0/8/001001', '/0/8/001001', '/0/8/000001', '/0/6/002002', '/0/6/002002', '/0/6/00200'] + # include_track_ids = [31, 8, 21, 4, 2, 21] + + # # uninfected cells - JUNE + # include_fov_names = ['/0/1/000000', '/0/1/000000', '/0/1/000000', '/0/1/000000', '/0/8/000002', '/0/8/000002'] + # include_track_ids = [25, 36, 37, 48, 16, 17] + + # # dividing cells - JUNE + # include_fov_names = ['/0/1/000000', '/0/1/000000', '/0/1/000000'] + # include_track_ids = [18, 21, 50] + + # uninfected cells - FEB + # include_fov_names = ['/A/3/0', 'B/3/5', 'B/3/5', 'B/3/5', 'B/3/5', '/A/4/14', '/A/4/14'] + # include_track_ids = [15, 34, 32, 31, 26, 33, 30] + + # # infected cells - FEB + # include_fov_names = ['/A/4/13', '/A/4/14', '/B/4/4', '/B/4/5', '/B/4/6', '/B/4/6'] + # include_track_ids = [25, 19, 68, 11, 29, 35] + + # # dividing cells - FEB + # include_fov_names = ['/B/4/4', '/B/3/5'] + # include_track_ids = [71, 42] # Initialize the data module for prediction data_module = TripletDataModule( @@ -59,11 +97,14 @@ def main(hparams): tracks_path=tracks_path, source_channel=source_channel, z_range=z_range, - initial_yx_patch_size=(512, 512), + initial_yx_patch_size=(224, 224), final_yx_patch_size=(224, 224), batch_size=batch_size, num_workers=hparams.num_workers, normalizations=normalizations, + # predict_cells = True, + # include_fov_names=include_fov_names, + # include_track_ids=include_track_ids, ) data_module.setup(stage="predict") @@ -71,8 +112,8 @@ def main(hparams): print(f"Total prediction dataset size: {len(data_module.predict_dataset)}") # Load the model from checkpoint - backbone = "resnet50" - in_stack_depth = 12 + backbone = "convnext_tiny" + in_stack_depth = 15 stem_kernel_size = (5, 3, 3) model = ContrastiveModule.load_from_checkpoint( str(checkpoint_path), @@ -98,28 +139,9 @@ def main(hparams): # Run prediction trainer.predict(model, datamodule=data_module) - # # Collect features and projections - # features_list = [] - # projections_list = [] - - # for batch_idx, batch in enumerate(predictions): - # features, projections = batch - # features_list.append(features.cpu().numpy()) - # projections_list.append(projections.cpu().numpy()) - # all_features = np.concatenate(features_list, axis=0) - # all_projections = np.concatenate(projections_list, axis=0) - - # # for saving visualizations embeddings - # base_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/test_visualizations" - # features_path = os.path.join(base_dir, 'B', '4', '2', 'before_projected_embeddings', 'test_epoch88_predicted_features.npy') - # projections_path = os.path.join(base_dir, 'B', '4', '2', 'projected_embeddings', 'test_epoch88_predicted_projections.npy') - - # np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/embeddings/resnet_uninf_rfp_epoch99_predicted_features.npy", all_features) - # np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/embeddings/resnet_uninf_rfp_epoch99_predicted_projections.npy", all_projections) - if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--backbone", type=str, default="resnet50") + parser.add_argument("--backbone", type=str, default="convnext_tiny") parser.add_argument("--margin", type=float, default=0.5) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--schedule", type=str, default="Constant") @@ -130,6 +152,6 @@ def main(hparams): parser.add_argument("--devices", type=int, default=1) parser.add_argument("--num_nodes", type=int, default=1) parser.add_argument("--log_every_n_steps", type=int, default=1) - parser.add_argument("--num_workers", type=int, default=15) + parser.add_argument("--num_workers", type=int, default=8) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/applications/contrastive_phenotyping/training_script.py b/applications/contrastive_phenotyping/training_script.py index 3b67a1f9..a9f9a19e 100644 --- a/applications/contrastive_phenotyping/training_script.py +++ b/applications/contrastive_phenotyping/training_script.py @@ -8,7 +8,6 @@ from torch.utils.data import DataLoader from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.strategies import DDPStrategy from viscy.transforms import ( NormalizeSampled, @@ -19,38 +18,21 @@ RandScaleIntensityd, RandWeightedCropd, ) + from viscy.data.triplet import TripletDataModule, TripletDataset from viscy.light.engine import ContrastiveModule from viscy.representation.contrastive import ContrastiveEncoder import pandas as pd from pathlib import Path from monai.transforms import NormalizeIntensityd, ScaleIntensityRangePercentilesd +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.callbacks import DeviceStatsMonitor +from lightning.pytorch.callbacks import LearningRateMonitor -# Set W&B logging level to suppress warnings -logging.getLogger("wandb").setLevel(logging.ERROR) - -# %% Paths and constants -os.environ["WANDB_DIR"] = "/hpc/mydata/alishba.imran/wandb_logs/" - -# @rank_zero_only -# def init_wandb(): -# wandb.init(project="contrastive_model", dir="/hpc/mydata/alishba.imran/wandb_logs/") - -# init_wandb() - -# wandb.init(project="contrastive_model", dir="/hpc/mydata/alishba.imran/wandb_logs/") - -# input_zarr = top_dir / "2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/full_patch.zarr" -# input_zarr = "/hpc/projects/virtual_staining/viral_sensor_test_dataio/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/full_patch.zarr" top_dir = Path("/hpc/projects/intracellular_dashboard/viral-sensor/") model_dir = top_dir / "infection_classification/models/infection_score" -# checkpoint dir: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/infection_score/updated_multiple_channels -# timesteps_csv_path = ( -# top_dir / "2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/final_track_timesteps.csv" -# ) - # Data parameters # 15 for covnext backbone, 12 for resnet (z slices) # (28, 43) for covnext backbone, (26, 38) for resnet @@ -61,8 +43,8 @@ # updated tracking data tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" source_channel = ["RFP", "Phase3D"] -z_range = (26, 38) -batch_size = 32 +z_range = (28, 43) +batch_size = 64 # normalizations = [ # # Normalization for Phase3D using mean and std @@ -137,34 +119,9 @@ torch.set_float32_matmul_precision("medium") -# contra_model = ContrastiveEncoder(backbone="resnet50") -# print(contra_model) - -# model_graph = torchview.draw_graph( -# contra_model, -# torch.randn(1, 1, 15, 200, 200), -# depth=3, -# device="cpu", -# ) -# model_graph.visual_graph - -# contrastive_module = ContrastiveModule() -# print(contrastive_module.encoder) - -# model_graph = torchview.draw_graph( -# contrastive_module.encoder, -# torch.randn(1, 1, 15, 200, 200), -# depth=3, -# device="cpu", -# ) -# model_graph.visual_graph - # %% Define the main function for training def main(hparams): - # Seed for reproducibility - # seed_everything(42, workers=True) - num_gpus = torch.cuda.device_count() print(f"Number of GPUs available: {num_gpus}") @@ -201,30 +158,27 @@ def main(hparams): margin=hparams.margin, lr=hparams.lr, schedule=hparams.schedule, - log_steps_per_epoch=hparams.log_steps_per_epoch, + log_batches_per_epoch=1, # total 2 images per epoch are logged + log_samples_per_batch=2, in_channels=len(source_channel), in_stack_depth=z_range[1] - z_range[0], - stem_kernel_size=(5, 3, 3), + stem_kernel_size=(5, 4, 4), embedding_len=hparams.embedding_len, ) + print("Model initialized!") - # Initialize logger - wandb_logger = WandbLogger(project="contrastive_model", log_model="all") - - # set for each run to avoid overwritting! - custom_folder_name = "test" - checkpoint_callback = ModelCheckpoint( - dirpath=os.path.join(model_dir, custom_folder_name), - filename="contrastive_model-test-{epoch:02d}-{val_loss:.2f}", - save_top_k=3, - mode="min", - monitor="val/loss_epoch", - ) + lr_monitor = LearningRateMonitor(logging_interval='step') trainer = Trainer( max_epochs=hparams.max_epochs, - callbacks=[checkpoint_callback], - logger=wandb_logger, + # limit_train_batches=2, + # limit_val_batches=2, + callbacks=[ModelCheckpoint(), lr_monitor], + logger=TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/test_tb", + log_graph=True, + default_hp_metric=True, + ), accelerator=hparams.accelerator, devices=hparams.devices, num_nodes=hparams.num_nodes, @@ -233,30 +187,21 @@ def main(hparams): num_sanity_val_steps=0, ) - # train_loader = data_module.train_dataloader() - # example_batch = next(iter(train_loader)) - # example_input = example_batch[0] - # wandb_logger.watch(model, log="all", log_graph=(example_input,)) + print("Trainer initialized!") - # Fetches batches from the training dataloader, - # Calls the training_step method on the model for each batch - # Aggregates the losses and performs optimization steps trainer.fit(model, datamodule=data_module) # # Validate the model trainer.validate(model, datamodule=data_module) - # # Test the model - # trainer.test(model, datamodule=data_module) - # Argument parser for command-line options # to-do: need to clean up to always use the same args parser = ArgumentParser() parser.add_argument("--backbone", type=str, default="convnext_tiny") parser.add_argument("--margin", type=float, default=0.5) -parser.add_argument("--lr", type=float, default=1e-3) -parser.add_argument("--schedule", type=str, default="Constant") +parser.add_argument("--lr", type=float, default=0.00001) +parser.add_argument("--schedule", type=str, default="CosineAnnealingWarmRestarts") parser.add_argument("--log_steps_per_epoch", type=int, default=10) parser.add_argument("--embedding_len", type=int, default=256) parser.add_argument("--max_epochs", type=int, default=100) diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 4914bb55..45aab1fe 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -56,6 +56,9 @@ def __init__( positive_transform: DictTransform | None = None, negative_transform: DictTransform | None = None, fit: bool = True, + predict_cells: bool = False, + include_fov_names: list[str] | None = None, + include_track_ids: list[int] | None = None, ) -> None: self.positions = positions self.channel_names = channel_names @@ -68,9 +71,28 @@ def __init__( self.negative_transform = negative_transform self.fit = fit self.yx_patch_size = initial_yx_patch_size + self.predict_cells = predict_cells + self.include_fov_names = include_fov_names or [] + self.include_track_ids = include_track_ids or [] self.tracks = self._filter_tracks(tracks_tables) + self.tracks = ( + self._specific_cells(self.tracks) if self.predict_cells else self.tracks + ) def _filter_tracks(self, tracks_tables: list[pd.DataFrame]) -> pd.DataFrame: + """_filter_tracks Select tracks within positions that belong to this dataset and remove tracks that are too close to the border. + + Parameters + ---------- + tracks_tables : list[pd.DataFrame] + List of tracks_tables returned by TripletDataModule._align_tracks_tables_with_positions + + Returns + ------- + pd.DataFrame + Filtered tracks table + + """ filtered_tracks = [] y_exclude, x_exclude = (self.yx_patch_size[0] // 2, self.yx_patch_size[1] // 2) for pos, tracks in zip(self.positions, tracks_tables, strict=True): @@ -94,6 +116,17 @@ def _filter_tracks(self, tracks_tables: list[pd.DataFrame]) -> pd.DataFrame: ) return pd.concat(filtered_tracks).reset_index(drop=True) + def _specific_cells(self, tracks: pd.DataFrame) -> pd.DataFrame: + specific_tracks = pd.DataFrame() + print(self.include_fov_names) + print(self.include_track_ids) + for fov_name, track_id in zip(self.include_fov_names, self.include_track_ids): + filtered_tracks = tracks[ + (tracks["fov_name"] == fov_name) & (tracks["track_id"] == track_id) + ] + specific_tracks = pd.concat([specific_tracks, filtered_tracks]) + return specific_tracks.reset_index(drop=True) + def __len__(self): return len(self.tracks) @@ -182,6 +215,9 @@ def __init__( normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], caching: bool = False, + predict_cells: bool = False, + include_fov_names: list[str] | None = None, + include_track_ids: list[int] | None = None, ): """Lightning data module for triplet sampling of patches. @@ -220,10 +256,21 @@ def __init__( self.z_range = slice(*z_range) self.tracks_path = Path(tracks_path) self.initial_yx_patch_size = initial_yx_patch_size + self.predict_cells = predict_cells + self.include_fov_names = include_fov_names + self.include_track_ids = include_track_ids def _align_tracks_tables_with_positions( self, ) -> tuple[list[Position], list[pd.DataFrame]]: + """Parse positions in ome-zarr store containing tracking information + and assemble tracks tables for each position. + + Returns + ------- + tuple[list[Position], list[pd.DataFrame]] + List of positions and list of tracks tables for each position + """ positions = [] tracks_tables = [] images_plate = open_ome_zarr(self.data_path) @@ -290,6 +337,9 @@ def _setup_predict(self, dataset_settings: dict): initial_yx_patch_size=self.initial_yx_patch_size, anchor_transform=Compose(self.normalizations), fit=False, + predict_cells=self.predict_cells, + include_fov_names=self.include_fov_names, + include_track_ids=self.include_track_ids, **dataset_settings, ) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 8ffc89e9..326639d6 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -581,10 +581,13 @@ def __init__( in_channels: int = 1, example_input_yx_shape: Sequence[int] = (256, 256), in_stack_depth: int = 15, - stem_kernel_size: tuple[int, int, int] = (5, 3, 3), + stem_kernel_size: tuple[int, int, int] = (5, 4, 4), embedding_len: int = 256, predict: bool = False, tracks_path: str = "data/tracks", + features_output_path: str = "", + projections_output_path: str = "", + metadata_output_path: str = "", ) -> None: super().__init__() self.loss_function = loss_function @@ -602,6 +605,9 @@ def __init__( self.processed_order = [] self.predictions = [] self.tracks_path = tracks_path + self.features_output_path = features_output_path + self.projections_output_path = projections_output_path + self.metadata_output_path = metadata_output_path self.model = ContrastiveEncoder( backbone=backbone, in_channels=in_channels, @@ -734,8 +740,17 @@ def configure_optimizers(self): optimizer = Adam(self.parameters(), lr=self.lr) return optimizer + def on_predict_start(self) -> None: + if not ( + self.features_output_path + and self.projections_output_path + and self.metadata_output_path + ): + raise ValueError( + "Output paths for features, projections, and metadata must be provided." + ) + def predict_step(self, batch: TripletSample, batch_idx, dataloader_idx=0): - print("running predict step!") """Prediction step for extracting embeddings.""" features, projections = self.model(batch["anchor"]) index = batch["index"] @@ -780,12 +795,8 @@ def on_predict_epoch_end(self) -> None: combined_features = np.array(combined_features) combined_projections = np.array(combined_projections) - np.save("embeddings2/multi_resnet_predicted_features.npy", combined_features) - print("Saved features with shape", combined_features.shape) - np.save( - "embeddings2/multi_resnet_predicted_projections.npy", combined_projections - ) - print("Saved projections with shape", combined_projections.shape) + np.save(self.features_output_path, combined_features) + np.save(self.projections_output_path, combined_projections) rows, columns, fovs, track_ids, timesteps = zip(*accumulated_data) df = pd.DataFrame( @@ -798,4 +809,4 @@ def on_predict_epoch_end(self) -> None: } ) - df.to_csv("embeddings2/multi_resnet_predicted_metadata.csv", index=False) + df.to_csv(self.metadata_output_path, index=False) diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index 1fdb8d45..1dc269ce 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -11,7 +11,7 @@ def __init__( backbone: str = "convnext_tiny", in_channels: int = 2, in_stack_depth: int = 12, - stem_kernel_size: tuple[int, int, int] = (5, 3, 3), + stem_kernel_size: tuple[int, int, int] = (5, 4, 4), embedding_len: int = 256, stem_stride: int = 2, predict: bool = False, diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py index fc8a19a2..c2403fc9 100644 --- a/viscy/unet/networks/unext2.py +++ b/viscy/unet/networks/unext2.py @@ -99,12 +99,12 @@ def __init__( in_channels: int, in_stack_depth: int, in_channels_encoder: int, - stem_kernel_size: tuple[int, int, int] = (5, 3, 3), - stem_stride: int = 2, # stride for the kernel + stem_kernel_size: tuple[int, int, int] = (5, 4, 4), + stem_stride: tuple[int, int, int] = (5, 4, 4), # stride for the kernel ) -> None: super().__init__() stem3d_out_channels = self.compute_stem_channels( - in_stack_depth, stem_kernel_size, stem_stride, in_channels_encoder + in_stack_depth, stem_kernel_size, stem_stride[0], in_channels_encoder ) self.conv = nn.Conv3d( @@ -115,9 +115,11 @@ def __init__( ) def compute_stem_channels( - self, in_stack_depth, stem_kernel_size, stem_stride, in_channels_encoder + self, in_stack_depth, stem_kernel_size, stem_stride_depth, in_channels_encoder ): - stem3d_out_depth = (in_stack_depth - stem_kernel_size[0]) // stem_stride + 1 + stem3d_out_depth = ( + in_stack_depth - stem_kernel_size[0] + ) // stem_stride_depth + 1 stem3d_out_channels = in_channels_encoder // stem3d_out_depth channel_mismatch = in_channels_encoder - stem3d_out_depth * stem3d_out_channels if channel_mismatch != 0: