diff --git a/viscy/applications/contrastive_phenotyping/graphs_ConvNeXt_ResNet.py b/viscy/applications/contrastive_phenotyping/graphs_ConvNeXt_ResNet.py
index c6be84e4..3d406462 100644
--- a/viscy/applications/contrastive_phenotyping/graphs_ConvNeXt_ResNet.py
+++ b/viscy/applications/contrastive_phenotyping/graphs_ConvNeXt_ResNet.py
@@ -3,16 +3,74 @@
 import torch
 from viscy.representation.contrastive import ContrastiveEncoder
 import torchview
+import timm
+# uncomment if you are using jupyter and want to autoreload the updated code.
+# %load_ext autoreload
+# %autoreload 2
-%load_ext autoreload
-%autoreload 2
-# %% Initialize the model and log the graph.
-contra_model = ContrastiveEncoder(backbone = "convnext_tiny") # other options: convnext_tiny resnet50
-print(contra_model)
+# %% Explore model graphs returned by timm
+
+convnextv1 = timm.create_model(
+    "convnext_tiny", pretrained=False, features_only=False, num_classes=200
+)
+print(convnextv1)
+output = convnextv1(torch.randn(1, 3, 256, 256))
+print(output.shape)
+# %% Initialize the model and log the graph: convnext.
+in_channels = 1
+in_stack_depth = 15
+
+contrastive_convnext1 = ContrastiveEncoder(
+    backbone="convnext_tiny", in_channels=in_channels, in_stack_depth=in_stack_depth
+)
+print(contrastive_convnext1)
+
+
+embedding, projections = contrastive_convnext1(
+    torch.randn(1, in_channels, in_stack_depth, 256, 256)
+)
+print(
+    f"shape of projections: {projections.shape}, shape of embedding: {embedding.shape}"
+)
+# %%
+
+in_channels = 3
+in_stack_depth = 16  # (16 - 5) // 2 + 1 = 6 stem output slices, which divide the 96 encoder channels evenly (18 would leave 7, which does not).
+
+contrastive_convnext2 = ContrastiveEncoder(
+    backbone="convnextv2_tiny", in_channels=in_channels, in_stack_depth=in_stack_depth
+)
+print(contrastive_convnext2)
+embedding, projections = contrastive_convnext2(
+    torch.randn(1, in_channels, in_stack_depth, 256, 256)
+)
+print(
+    f"shape of projections: {projections.shape}, shape of embedding: {embedding.shape}"
+)
+
+# %%
+in_channels = 10
+in_stack_depth = 12
+contrastive_resnet = ContrastiveEncoder(
+    backbone="resnet50",
+    in_channels=in_channels,
+    in_stack_depth=in_stack_depth,
+    embedding_len=256,
+)
+print(contrastive_resnet)
+embedding, projections = contrastive_resnet(
+    torch.randn(1, in_channels, in_stack_depth, 256, 256)
+)
+print(
+    f"shape of projections: {projections.shape}, shape of embedding: {embedding.shape}"
+)
+
+# %%
+plot_model = contrastive_resnet
 model_graph = torchview.draw_graph(
-    contra_model,
-    torch.randn(1, 2, 15, 224, 224),
+    plot_model,
+    input_size=(20, in_channels, in_stack_depth, 224, 224),
     depth=3,  # adjust depth to zoom in.
     device="cpu",
 )
@@ -21,7 +79,9 @@
 model_graph.visual_graph

 # %% Initialize a resent50 model and log the graph.
-contra_model = ContrastiveEncoder(backbone = "resnet50", in_stack_depth = 16, stem_kernel_size = (4, 3, 3)) # note that the resnet first layer takes 64 channels (so we can't have multiples of 3)
+contra_model = ContrastiveEncoder(
+    backbone="resnet50", in_stack_depth=16, stem_kernel_size=(4, 3, 3)
+)  # note that the resnet first layer takes 64 channels (so we can't have multiples of 3)
 print(contra_model)
 model_graph = torchview.draw_graph(
     contra_model,
@@ -41,7 +101,7 @@
 # %%
 model_graph = torchview.draw_graph(
     contrastive_module.encoder,
-    torch.randn(1, 2, 15, 200, 200),
+    torch.randn(1, in_channels, in_stack_depth, 200, 200),
     depth=3,  # adjust depth to zoom in.
device="cpu", ) @@ -49,7 +109,6 @@ model_graph.visual_graph # %% Playground -import timm available_models = timm.list_models(pretrained=True) diff --git a/viscy/applications/contrastive_phenotyping/predict.py b/viscy/applications/contrastive_phenotyping/predict.py index a5447735..adfe6525 100644 --- a/viscy/applications/contrastive_phenotyping/predict.py +++ b/viscy/applications/contrastive_phenotyping/predict.py @@ -6,7 +6,8 @@ from lightning.pytorch.strategies import DDPStrategy from viscy.data.hcs import ContrastiveDataModule from viscy.light.engine import ContrastiveModule -import os +import os + def main(hparams): # Set paths @@ -35,7 +36,7 @@ def main(hparams): batch_size=batch_size, z_range=z_range, predict_base_path=predict_base_path, - analysis=True, # for self-supervised results + analysis=True, # for self-supervised results ) data_module.setup(stage="predict") @@ -54,7 +55,7 @@ def main(hparams): # Run prediction predictions = trainer.predict(model, datamodule=data_module) - + # Collect features and projections features_list = [] projections_list = [] @@ -66,13 +67,33 @@ def main(hparams): all_features = np.concatenate(features_list, axis=0) all_projections = np.concatenate(projections_list, axis=0) - # for saving visualizations embeddings + # for saving visualizations embeddings base_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/test_visualizations" - features_path = os.path.join(base_dir, 'B', '4', '2', 'before_projected_embeddings', 'test_epoch88_predicted_features.npy') - projections_path = os.path.join(base_dir, 'B', '4', '2', 'projected_embeddings', 'test_epoch88_predicted_projections.npy') + features_path = os.path.join( + base_dir, + "B", + "4", + "2", + "before_projected_embeddings", + "test_epoch88_predicted_features.npy", + ) + projections_path = os.path.join( + base_dir, + "B", + "4", + "2", + "projected_embeddings", + "test_epoch88_predicted_projections.npy", + ) - np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_features.npy", all_features) - np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_projections.npy", all_projections) + np.save( + "/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_features.npy", + all_features, + ) + np.save( + "/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_projections.npy", + all_projections, + ) if __name__ == "__main__": @@ -89,4 +110,4 @@ def main(hparams): parser.add_argument("--num_nodes", type=int, default=2) parser.add_argument("--log_every_n_steps", type=int, default=1) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index d62a4e13..dcbc7dd7 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -6,7 +6,8 @@ from glob import glob from pathlib import Path from typing import Callable, Literal, Optional, Sequence, Union -#import pytorch_lightning as pl + +# import pytorch_lightning as pl from monai.transforms import MapTransform import random @@ -15,10 +16,22 @@ import zarr from imageio import imread from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr -#from lightning.pytorch import LightningDataModule + +# from lightning.pytorch import LightningDataModule from monai.data import set_track_meta from monai.data.utils import collate_meta_tensor -from monai.transforms 
-from monai.transforms import Compose, RandAdjustContrastd, RandAffined, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, RandShiftIntensityd, RandZoomd, Rand3DElasticd, RandGaussianSharpend
+from monai.transforms import (
+    Compose,
+    RandAdjustContrastd,
+    RandAffined,
+    RandGaussianNoised,
+    RandGaussianSmoothd,
+    RandScaleIntensityd,
+    RandShiftIntensityd,
+    RandZoomd,
+    Rand3DElasticd,
+    RandGaussianSharpend,
+)
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
 from viscy.data.typing import ChannelMap, HCSStackIndex, NormMeta, Sample
@@ -27,11 +40,13 @@
 import pandas as pd
 import warnings
 from lightning.pytorch import LightningDataModule, LightningModule, Trainer
+
 # from viscy.data.typing import Optional
 from pathlib import Path

 _logger = logging.getLogger("lightning.pytorch")

+
 def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]:
     """
     Ensure channel argument is a list of strings.
@@ -591,6 +606,7 @@ def _train_transform(self) -> list[Callable]:
         _logger.debug(f"Training augmentations: {self.augmentations}")
         return list(self.augmentations)

+
 # dataloader for organelle phenotyping
 class ContrastiveDataset(Dataset):
     def __init__(
@@ -614,16 +630,21 @@ def __init__(
         self.ds = self.open_zarr_store(self.base_path)
         self.positions = list(self.ds.positions())
         self.timesteps_df = pd.read_csv(timesteps_csv_path)
-        self.channel_indices = [self.ds.channel_names.index(channel) for channel in self.channel_names]
+        self.channel_indices = [
+            self.ds.channel_names.index(channel) for channel in self.channel_names
+        ]
         print("channel indices!")
         print(self.channel_indices)
         print(f"Initialized dataset with {len(self.positions)} positions.")

         # self.statistics = self.compute_statistics()
         # print("Channel Statistics:", self.statistics)
-
+
     def compute_statistics(self):
-        stats = {channel: {'mean': 0, 'sum_sq_diff': 0, 'min': np.inf, 'max': -np.inf} for channel in self.channel_names}
+        stats = {
+            channel: {"mean": 0, "sum_sq_diff": 0, "min": np.inf, "max": -np.inf}
+            for channel in self.channel_names
+        }
         count = 0
         total_elements = 0
@@ -633,23 +654,25 @@ def compute_statistics(self):
         for i, channel in enumerate(self.channel_names):
             channel_data = data[i]
             mean = np.mean(channel_data)
-            stats[channel]['mean'] += mean
-            stats[channel]['min'] = min(stats[channel]['min'], np.min(channel_data))
-            stats[channel]['max'] = max(stats[channel]['max'], np.max(channel_data))
-            stats[channel]['sum_sq_diff'] += np.sum((channel_data - mean) ** 2)
+            stats[channel]["mean"] += mean
+            stats[channel]["min"] = min(stats[channel]["min"], np.min(channel_data))
+            stats[channel]["max"] = max(stats[channel]["max"], np.max(channel_data))
+            stats[channel]["sum_sq_diff"] += np.sum((channel_data - mean) ** 2)
             count += 1
             total_elements += np.prod(channel_data.shape)

         for channel in self.channel_names:
-            stats[channel]['mean'] /= count
-            stats[channel]['std'] = np.sqrt(stats[channel]['sum_sq_diff'] / total_elements)
-            del stats[channel]['sum_sq_diff']
-
+            stats[channel]["mean"] /= count
+            stats[channel]["std"] = np.sqrt(
+                stats[channel]["sum_sq_diff"] / total_elements
+            )
+            del stats[channel]["sum_sq_diff"]
+
         print("done!")
         return stats

     def open_zarr_store(self, path, layout="hcs", mode="r"):
-        #print(f"Opening Zarr store at {path} with layout '{layout}' and mode '{mode}'")
+        # print(f"Opening Zarr store at {path} with layout '{layout}' and mode '{mode}'")
         return open_ome_zarr(path, layout=layout, mode=mode)

     def __len__(self):
@@ -671,7 +694,7 @@ def __getitem__(self, idx):
         negative_idx = random.randint(0, self.__len__() - 1)
         negative_position_path = self.positions[negative_idx][0]
         negative_data = self.load_data(negative_position_path)
-        negative_data = self.normalize_data(negative_data)
+        negative_data = self.normalize_data(negative_data)
         negative_data = self.apply_channel_transforms(negative_data)
         negative_data = self.normalize_data(negative_data)
@@ -698,8 +721,8 @@ def load_data(self, position_path):
         data = self.restructure_data(zarr_array, position_path)
         data = data[self.channel_indices, self.z_range[0] : self.z_range[1], :, :]
-        #print("shape after!")
-        #print(data.shape)
+        # print("shape after!")
+        # print(data.shape)
         return data

     def restructure_data(self, data, position_path):
@@ -743,16 +766,17 @@ def normalize_data(self, data):
             std = np.std(channel_data)
             normalized_data[i] = (channel_data - mean) / (std + 1e-6)
         return normalized_data
-
+
     def apply_channel_transforms(self, data):
         transformed_data = np.empty_like(data)
         for i, channel_name in enumerate(self.channel_names):
             channel_data = data[i]
             transform = self.transform[channel_name]
             transformed_data[i] = transform({"image": channel_data})["image"]
-            #print(f"transformed {channel_name}")
+            # print(f"transformed {channel_name}")
         return transformed_data

+
 def get_transforms():
     rfp_transforms = Compose(
         [
@@ -798,10 +822,8 @@ def get_transforms():
         ]
     )

-    return {
-        "RFP": rfp_transforms,
-        "Phase3D": phase_transforms
-    }
+    return {"RFP": rfp_transforms, "Phase3D": phase_transforms}
+

 class ContrastiveDataModule(LightningDataModule):
     def __init__(
@@ -817,7 +839,7 @@ def __init__(
         train_split_ratio: float = 0.64,
         val_split_ratio: float = 0.16,
         batch_size: int = 4,
-        num_workers: int = 15, #for analysis purposes reduced to 1
+        num_workers: int = 15,  # reduce to 1 for analysis runs
         z_range: tuple[int, int] = None,
         analysis: bool = False,
     ):
@@ -859,10 +881,12 @@ def setup(self, stage: str = None):
         test_size = len(dataset) - train_size - val_size

         self.train_dataset, self.val_dataset, self.test_dataset = (
-            torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+            torch.utils.data.random_split(
+                dataset, [train_size, val_size, test_size]
+            )
         )

-        # setup prediction dataset
+        # set up the prediction dataset
         if stage == "predict" and self.predict_base_path and not self.analysis:
             print("setting up!")
             self.predict_dataset = PredictDataset(
@@ -887,15 +911,15 @@ def setup(self, stage: str = None):
                 z_range=self.z_range,
                 analysis=True,
             )
-
+
     def train_dataloader(self):
         return DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers,
-            prefetch_factor=2,
-            persistent_workers=True
+            prefetch_factor=2,
+            persistent_workers=True,
         )

     def val_dataloader(self):
@@ -904,8 +928,8 @@
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers,
-            prefetch_factor=2,
-            persistent_workers=True
+            prefetch_factor=2,
+            persistent_workers=True,
         )

     def test_dataloader(self):
@@ -914,8 +938,8 @@
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers,
-            prefetch_factor=2,
-            persistent_workers=True
+            prefetch_factor=2,
+            persistent_workers=True,
         )

     def predict_dataloader(self):
@@ -928,10 +952,10 @@
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
-            shuffle=False, # False shuffle for prediction
+            shuffle=False,  # do not shuffle during prediction
             num_workers=self.num_workers,
-            prefetch_factor=2,
-            persistent_workers=True
+            prefetch_factor=2,
+            persistent_workers=True,
         )
@@ -956,8 +980,14 @@ def __init__(
         self.ds = self.open_zarr_store(self.base_path)
         self.timesteps_csv_path = timesteps_csv_path
         self.timesteps_df = pd.read_csv(timesteps_csv_path)
-        self.positions = list(self.ds.positions()) if not analysis else self.filter_positions_from_csv()
-        self.channel_indices = [self.ds.channel_names.index(channel) for channel in self.channel_names]
+        self.positions = (
+            list(self.ds.positions())
+            if not analysis
+            else self.filter_positions_from_csv()
+        )
+        self.channel_indices = [
+            self.ds.channel_names.index(channel) for channel in self.channel_names
+        ]
         self.current_position_idx = 0
         self.current_timestep_idx = 0
         self.analysis = analysis
@@ -966,24 +996,32 @@ def __init__(
         print(f"Initialized predict dataset with {len(self.positions)} positions.")

         self.position_to_timesteps = {
-            position: self.timesteps_df[self.timesteps_df.apply(
-                lambda x: f"{x['Row']}/{x['Column']}/fov{x['FOV']}cell{x['Cell ID']}",
-                axis=1) == position]['Timestep'].values
+            position: self.timesteps_df[
+                self.timesteps_df.apply(
+                    lambda x: f"{x['Row']}/{x['Column']}/fov{x['FOV']}cell{x['Cell ID']}",
+                    axis=1,
+                )
+                == position
+            ]["Timestep"].values
             for position in self.positions
         }
-        #print(self.positions[0])
+        # print(self.positions[0])

     def open_zarr_store(self, path, layout="hcs", mode="r"):
         return open_ome_zarr(path, layout=layout, mode=mode)

     def filter_positions_from_csv(self):
-        unique_positions = self.timesteps_df[['Row', 'Column', 'FOV', 'Cell ID']].drop_duplicates()
+        unique_positions = self.timesteps_df[
+            ["Row", "Column", "FOV", "Cell ID"]
+        ].drop_duplicates()
         valid_positions = []
         for idx, row in unique_positions.iterrows():
-            position_path = f"{row['Row']}/{row['Column']}/fov{row['FOV']}cell{row['Cell ID']}"
+            position_path = (
+                f"{row['Row']}/{row['Column']}/fov{row['FOV']}cell{row['Cell ID']}"
+            )
             valid_positions.append(position_path)
-        #print(valid_positions)
+        # print(valid_positions)
         return valid_positions

     # def get_positions_from_csv(self):
     #     positions = []
     #     for idx, row in self.timesteps_df.iterrows():
     #         position_path = f"{row['Row']}/{row['Column']}/fov{row['FOV']}cell{row['Cell ID']}"
     #         positions.append((position_path, row['Random Timestep']))
     #         #print(positions)
     #     return positions
-
+
@@ -994,10 +1032,12 @@ def filter_positions_from_csv(self):
     def __len__(self):
         if self.analysis:
-            return sum(len(timesteps) for timesteps in self.position_to_timesteps.values())
+            return sum(
+                len(timesteps) for timesteps in self.position_to_timesteps.values()
+            )
         else:
             return len(self.positions)
@@ -1016,13 +1056,17 @@ def __getitem__(self, idx):
                 accumulated_idx += len(timesteps)

         if timestep is None or position_path is None:
-            raise ValueError(f"Timestep or position_path could not be determined for index: {idx}")
+            raise ValueError(
+                f"Timestep or position_path could not be determined for index: {idx}"
+            )

-            #print(f"Analysis mode: Index: {idx}, Position: {position_path}, Timestep: {timestep}")
+            # print(f"Analysis mode: Index: {idx}, Position: {position_path}, Timestep: {timestep}")
             data = self.load_data(position_path, timestep)
         else:
             if idx >= len(self.positions):
-                raise IndexError(f"Index {idx} out of range for positions of length {len(self.positions)}")
+                raise IndexError(
+                    f"Index {idx} out of range for positions of length {len(self.positions)}"
+                )

             position_path = self.positions[idx]
             print(f"Non-analysis mode: Index: {idx}, Position: {position_path}")
@@ -1031,16 +1075,21 @@
         data = self.normalize_data(data)
         return torch.tensor(data, dtype=torch.float32), position_path

-    # double check printing order
+    # double check printing order
     def load_data(self, position_path, timestep=None):
         position = self.ds[position_path]
-        print(f"Loading data for position path: {position_path}" + (f" at Timestep: {timestep}" if timestep is not None else ""))
+        print(
+            f"Loading data for position path: {position_path}"
+            + (f" at Timestep: {timestep}" if timestep is not None else "")
+        )
         zarr_array = position["0"][:]

         if timestep is None:
             raise ValueError("Timestep must be provided for analysis")
-        data = zarr_array[timestep, self.channel_indices, self.z_range[0]:self.z_range[1], :, :]
+        data = zarr_array[
+            timestep, self.channel_indices, self.z_range[0] : self.z_range[1], :, :
+        ]
         return data

     def normalize_data(self, data):
diff --git a/viscy/light/engine.py b/viscy/light/engine.py
index bb4c89da..de4863ea 100644
--- a/viscy/light/engine.py
+++ b/viscy/light/engine.py
@@ -588,7 +588,7 @@ def __init__(
         self.test_metrics = []
         self.processed_order = []

-        self.encoder = ContrastiveEncoder(
+        self.model = ContrastiveEncoder(
             backbone=backbone,
             in_channels=in_channels,
             in_stack_depth=in_stack_depth,
@@ -611,7 +611,7 @@ def __init__(

     def forward(self, x: Tensor) -> Tensor:
         """Forward pass of the model."""
-        projections = self.encoder(x)
+        _, projections = self.model(x)
         return projections

     # features is without projection head and projects is with projection head
@@ -714,10 +714,12 @@ def training_step(
         """Training step of the model."""
         anchor, pos_img, neg_img = batch

-        emb_anchor = self.encoder(anchor)
-        emb_pos = self.encoder(pos_img)
-        emb_neg = self.encoder(neg_img)
-        loss = self.loss_function(emb_anchor, emb_pos, emb_neg)
+        _, anchorProjection = self.model(anchor)
+        _, negativeProjection = self.model(neg_img)
+        _, positiveProjection = self.model(pos_img)
+        loss = self.loss_function(
+            anchorProjection, positiveProjection, negativeProjection
+        )

         self.log("train/loss_step", loss, on_step=True, prog_bar=True, logger=True)

@@ -727,7 +729,9 @@
                 anchor, pos_img, neg_img, self.current_epoch, "training_images"
             )

-        self.log_metrics(emb_anchor, emb_pos, emb_neg, "train")
+        self.log_metrics(
+            anchorProjection, positiveProjection, negativeProjection, "train"
+        )
         # self.print_embedding_norms(emb_anchor, emb_pos, emb_neg, 'train')

         self.training_step_outputs.append(loss)
@@ -777,10 +781,12 @@ def validation_step(
         """Validation step of the model."""
         anchor, pos_img, neg_img = batch

-        emb_anchor = self.encoder(anchor)
-        emb_pos = self.encoder(pos_img)
-        emb_neg = self.encoder(neg_img)
-        loss = self.loss_function(emb_anchor, emb_pos, emb_neg)
+        _, anchorProjection = self.model(anchor)
+        _, positiveProjection = self.model(pos_img)
+        _, negativeProjection = self.model(neg_img)
+        loss = self.loss_function(
+            anchorProjection, positiveProjection, negativeProjection
+        )

         self.log("val/loss_step", loss, on_step=True, prog_bar=True, logger=True)

@@ -790,7 +796,9 @@
                 anchor, pos_img, neg_img, self.current_epoch, "validation_images"
             )

-        self.log_metrics(emb_anchor, emb_pos, emb_neg, "val")
+        self.log_metrics(
+            anchorProjection, positiveProjection, negativeProjection, "val"
+        )

         self.validation_step_outputs.append(loss)
         return {"loss": loss}
@@ -839,14 +847,16 @@ def test_step(
         """Test step of the model."""
         anchor, pos_img, neg_img = batch

-        emb_anchor = self.encoder(anchor)
-        emb_pos = self.encoder(pos_img)
-        emb_neg = self.encoder(neg_img)
-        loss = self.loss_function(emb_anchor, emb_pos, emb_neg)
+        _, anchorProjection = self.model(anchor)
+        _, positiveProjection = self.model(pos_img)
+        _, negativeProjection = self.model(neg_img)
+        loss = self.loss_function(
+            anchorProjection, positiveProjection, negativeProjection
+        )

         self.log("test/loss_step", loss, on_step=True, prog_bar=True, logger=True)

-        self.log_metrics(emb_anchor, emb_pos, emb_neg, "test")
+        self.log_metrics(anchorProjection, positiveProjection, negativeProjection, "test")

         self.test_step_outputs.append(loss)
         return {"loss": loss}
@@ -912,7 +922,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         print("running predict step!")
         """Prediction step for extracting embeddings."""
         x, position_info = batch
-        features, projections = self.encoder(x)
+        features, projections = self.model(x)
         self.processed_order.extend(position_info)
         return features, projections
diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py
index 56188214..71287b5d 100644
--- a/viscy/representation/contrastive.py
+++ b/viscy/representation/contrastive.py
@@ -5,7 +5,8 @@
 # from viscy.unet.networks.resnet import resnetStem
 # Currently identical to resnetStem, but could be different in the future.
 from viscy.unet.networks.unext2 import UNeXt2Stem
-from viscy.unet.networks.unext2 import UNeXt2StemResNet
+from viscy.unet.networks.unext2 import StemDepthtoChannels
+

 class ContrastiveEncoder(nn.Module):
     def __init__(
@@ -15,11 +16,13 @@ def __init__(
         in_stack_depth: int = 15,
         stem_kernel_size: tuple[int, int, int] = (5, 3, 3),
         embedding_len: int = 256,
+        stem_stride: int = 2,
         predict: bool = False,
     ):
         super().__init__()
         self.predict = predict
+        self.backbone = backbone

         """
         ContrastiveEncoder network that uses ConvNext and ResNet backbons from timm.
@@ -32,164 +35,60 @@ def __init__(
         - embedding_len (int): Length of the embedding. Default is 1000.
         """
-        if in_stack_depth % stem_kernel_size[0] != 0:
-            raise ValueError(
-                f"Input stack depth {in_stack_depth} is not divisible "
-                f"by stem kernel depth {stem_kernel_size[0]}."
-            )
-
-        # encoder
-        self.model = timm.create_model(
+        # encoder from timm
+        encoder = timm.create_model(
             backbone,
             pretrained=True,
             features_only=False,
             drop_path_rate=0.2,
-            num_classes=4 * embedding_len,
+            num_classes=3 * embedding_len,
         )

-        if "convnext_tiny" in backbone:
-            print("Using ConvNext backbone.")
-            # replace the stem designed for RGB images with a stem designed to handle 3D multi-channel input.
-            in_channels_encoder = self.model.stem[0].out_channels
-            stem = UNeXt2Stem(
-                in_channels=in_channels,
-                out_channels=in_channels_encoder,
-                kernel_size=stem_kernel_size,
-                in_stack_depth=in_stack_depth,
-            )
-            self.model.stem = stem
+        # Do encoder surgery and set up the stem and projection head.

-            self.model.head.fc = nn.Sequential(
-                self.model.head.fc,
-                nn.ReLU(inplace=True),
-                nn.Linear(4 * embedding_len, embedding_len),
-            )
-
-            """
-            head of convnext
-            -------------------
-            (head): NormMlpClassifierHead(
-                (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
-                (norm): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
-                (flatten): Flatten(start_dim=1, end_dim=-1)
-                (pre_logits): Identity()
-                (drop): Dropout(p=0.0, inplace=False)
-                (fc): Linear(in_features=768, out_features=1024, bias=True)
+        if "convnext" in backbone:
+            # replace the stem designed for RGB images with a stem designed to handle 3D multi-channel input.
+            in_channels_encoder = encoder.stem[0].out_channels
+            # in_channels_encoder can be 96 or 64, and in_channels can be 1, 2, or 3.
+            # Remove the convolution layer of stem, but keep the layernorm.
+            encoder.stem[0] = nn.Identity()

-            head of convnext for contrastive learning
-            ----------------------------
-            (head): NormMlpClassifierHead(
-                (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
-                (norm): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
-                (flatten): Flatten(start_dim=1, end_dim=-1)
-                (pre_logits): Identity()
-                (drop): Dropout(p=0.0, inplace=False)
-                (fc): Sequential(
-                    (0): Linear(in_features=768, out_features=1024, bias=True)
-                    (1): ReLU(inplace=True)
-                    (2): Linear(in_features=1024, out_features=256, bias=True)
             )
-            """
+            # Save the projection head separately and erase the projection head contained within the encoder.
+            projection = nn.Sequential(
+                nn.Linear(encoder.head.fc.in_features, 3 * embedding_len),
+                nn.ReLU(inplace=True),
+                nn.Linear(3 * embedding_len, embedding_len),
+            )
+            encoder.head.fc = nn.Identity()

         # TO-DO: need to debug further
         elif "resnet" in backbone:
-            print("Using ResNet backbone.")
-            # Adapt stem and projection head of resnet here.
-            # replace the stem designed for RGB images with a stem designed to handle 3D multi-channel input.
-            in_channels_encoder = self.model.conv1.out_channels
-            print("in_channels_encoder", in_channels_encoder)
+            in_channels_encoder = encoder.conv1.out_channels
+            encoder.conv1 = nn.Identity()

-            out_channels_encoder = self.model.bn1.num_features
-            print("out_channels_bn", out_channels_encoder)
-
-            stem = UNeXt2StemResNet(
-                in_channels=in_channels,
-                out_channels=out_channels_encoder,
-                kernel_size=stem_kernel_size,
-                in_stack_depth=in_stack_depth,
-            )
-            self.model.conv1 = stem
-
-            self.model.bn1 = nn.BatchNorm2d(out_channels_encoder)
-
-            print(f'Updated out_channels_encoder: {out_channels_encoder}')
-
-            self.model.fc = nn.Sequential(
-                nn.Linear(self.model.fc.in_features, 4 * embedding_len),
+            projection = nn.Sequential(
+                nn.Linear(encoder.fc.in_features, 3 * embedding_len),
                 nn.ReLU(inplace=True),
-                nn.Linear(4 * embedding_len, embedding_len),
+                nn.Linear(3 * embedding_len, embedding_len),
             )
+            encoder.fc = nn.Identity()

-            # self.model.fc = nn.Sequential(
-            #     nn.Linear(self.model.fc.in_features, 1024),
-            #     nn.ReLU(inplace=True),
-            #     nn.Linear(1024, embedding_len),
-            # )
-
-            """
-            head of resnet
-            -------------------
-            (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
-            (fc): Linear(in_features=2048, out_features=1024, bias=True)
-
-
-            head of resnet for contrastive learning
-            ----------------------------
-            (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
-            (fc): Sequential(
-                (0): Linear(in_features=2048, out_features=1024, bias=True)
-                (1): ReLU(inplace=True)
-                (2): Linear(in_features=1024, out_features=256, bias=True)
-            """
+        # Create a new stem that can handle 3D multi-channel input,
+        # honoring the kernel size and stride passed to this constructor.
+        self.stem = StemDepthtoChannels(
+            in_channels, in_stack_depth, in_channels_encoder, stem_kernel_size, stem_stride
+        )
+        # Append the modified encoder.
+        self.encoder = encoder
+        # Append the modified projection head.
+        self.projection = projection

     def forward(self, x):
-        if self.predict:
-            print("running predict forward!")
-            x = self.model.stem(x)
-            x = self.model.stages[0](x)
-            x = self.model.stages[1](x)
-            x = self.model.stages[2](x)
-            x = self.model.stages[3](x)
-            x = self.model.head.global_pool(x)
-            x = self.model.head.norm(x)
-            x = self.model.head.flatten(x)
-            features_before_projection = self.model.head.drop(x)
-            projections = self.model.head.fc(features_before_projection)
-            features_before_projection = F.normalize(
-                features_before_projection, p=2, dim=1
-            )
-            projections = F.normalize(projections, p=2, dim=1)  # L2 normalization
-            print(features_before_projection.shape, projections.shape)
-            return features_before_projection, projections
-            # feature is without projection head
-        else:
-            print("running forward without predict!")
-            print("Running forward without predict!")
-            x = self.model.conv1(x)
-            print(f'After conv1: {x.shape}')  # Debugging statement
-            x = self.model.bn1(x)
-            print(f'After bn1: {x.shape}')  # Debugging statement
-            x = self.model.act1(x)
-            print(f'After act1: {x.shape}')  # Debugging statement
-            x = self.model.maxpool(x)
-            print(f'After maxpool: {x.shape}')  # Debugging statement
-            x = self.model.layer1(x)
-            print(f'After layer1: {x.shape}')  # Debugging statement
-            x = self.model.layer2(x)
-            print(f'After layer2: {x.shape}')  # Debugging statement
-            x = self.model.layer3(x)
-            print(f'After layer3: {x.shape}')  # Debugging statement
-            x = self.model.layer4(x)
-            print(f'After layer4: {x.shape}')  # Debugging statement
-            x = self.model.global_pool(x)
-            print(f'After global_pool: {x.shape}')  # Debugging statement
-            x = x.flatten(1)
-            x = self.model.fc(x)
-            print(f'After fc: {x.shape}')  # Debugging statement
-            x = F.normalize(x, p=2, dim=1)  # L2 normalization
-            return x
-
-        # projections = self.model(x)
-        # projections = F.normalize(projections, p=2, dim=1)  # L2 normalization
-        # return projections
+        x = self.stem(x)
+        embedding = self.encoder(x)
+        projections = self.projection(embedding)
+        projections = F.normalize(projections, p=2, dim=1)
+        return (
+            embedding,
+            projections,
+        )  # Compute the loss on projections, analyze the embeddings.
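Reviewer note on the refactored `ContrastiveEncoder` above: the model is now an explicit `stem -> encoder -> projection` pipeline that returns `(embedding, projections)`, with only the projections L2-normalized. A minimal smoke test of that contract (a sketch, not part of the diff; shapes assume the defaults shown here, and `pretrained=True` downloads weights on first use):

```python
import torch

from viscy.representation.contrastive import ContrastiveEncoder

# 15 slices -> (15 - 5) // 2 + 1 = 6 output slices; 6 * 16 = 96 channels,
# matching the convnext_tiny stem width.
model = ContrastiveEncoder(backbone="convnext_tiny", in_channels=2, in_stack_depth=15)
model.eval()
with torch.no_grad():
    embedding, projections = model(torch.randn(1, 2, 15, 256, 256))

assert embedding.shape == (1, 768)  # pooled convnext_tiny features (head.fc is Identity)
assert projections.shape == (1, 256)  # embedding_len
# Only the projections are L2-normalized; embeddings stay raw for analysis.
assert torch.allclose(projections.norm(dim=1), torch.ones(1), atol=1e-4)
```

This matches the usage in `engine.py`: the triplet loss is computed on the normalized projections, while the unnormalized embeddings are what `predict_step` returns for downstream analysis.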
diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py
index b3a685cb..fc8a19a2 100644
--- a/viscy/unet/networks/unext2.py
+++ b/viscy/unet/networks/unext2.py
@@ -90,28 +90,47 @@ def forward(self, x: Tensor):
         # return a view when possible (contiguous)
         return x.reshape(b, c * d, h, w)

-class UNeXt2StemResNet(nn.Module):
-    """Stem for ResNet in ContrastiveEncoder networks."""
+
+class StemDepthtoChannels(nn.Module):
+    """Stem with 3D convolution that maps depth to channels."""

     def __init__(
         self,
         in_channels: int,
-        out_channels: int,
-        kernel_size: tuple[int, int, int],
         in_stack_depth: int,
+        in_channels_encoder: int,
+        stem_kernel_size: tuple[int, int, int] = (5, 3, 3),
+        stem_stride: int = 2,  # stride of the stem convolution
     ) -> None:
         super().__init__()
+        stem3d_out_channels = self.compute_stem_channels(
+            in_stack_depth, stem_kernel_size, stem_stride, in_channels_encoder
+        )
+
         self.conv = nn.Conv3d(
             in_channels=in_channels,
-            out_channels=out_channels,  # matches the expected BatchNorm2d input channels
-            kernel_size=kernel_size,
-            stride=kernel_size,
+            out_channels=stem3d_out_channels,
+            kernel_size=stem_kernel_size,
+            stride=stem_stride,
         )

+    def compute_stem_channels(
+        self, in_stack_depth, stem_kernel_size, stem_stride, in_channels_encoder
+    ):
+        stem3d_out_depth = (in_stack_depth - stem_kernel_size[0]) // stem_stride + 1
+        stem3d_out_channels = in_channels_encoder // stem3d_out_depth
+        channel_mismatch = in_channels_encoder - stem3d_out_depth * stem3d_out_channels
+        if channel_mismatch != 0:
+            raise ValueError(
+                f"Stem needs to output {channel_mismatch} more channels to match the encoder. Adjust the in_stack_depth."
+            )
+        return stem3d_out_channels
+
     def forward(self, x: Tensor):
         x = self.conv(x)
         b, c, d, h, w = x.shape
-        print(f'After Conv3d: {x.shape}')
+        # project Z/depth into channels
+        # return a view when possible (contiguous)
         return x.reshape(b, c * d, h, w)
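The channel arithmetic in `compute_stem_channels` is the constraint behind the `in_stack_depth` values used in the example script: the stem's 3D convolution leaves `(in_stack_depth - kernel_depth) // stride + 1` depth slices, and that slice count must divide the encoder's expected input channels (96 for the ConvNeXt-Tiny stem, 64 for ResNet50's `conv1`). A standalone recap of the same arithmetic (a hypothetical helper mirroring the method above, useful for sanity-checking candidate depths):

```python
def stem_out_channels(in_stack_depth: int, in_channels_encoder: int,
                      kernel_depth: int = 5, stride: int = 2) -> int:
    """Mirror StemDepthtoChannels.compute_stem_channels for quick checks."""
    out_depth = (in_stack_depth - kernel_depth) // stride + 1
    out_channels, mismatch = divmod(in_channels_encoder, out_depth)
    if mismatch:
        raise ValueError(
            f"{out_depth} output slices cannot evenly tile {in_channels_encoder} channels"
        )
    return out_channels

assert stem_out_channels(15, 96) == 16  # convnext_tiny: 6 slices * 16 channels = 96
assert stem_out_channels(16, 96) == 16  # convnextv2_tiny example above: 6 slices
assert stem_out_channels(12, 64) == 16  # resnet50: 4 slices * 16 channels = 64
# in_stack_depth=18 would leave 7 slices, and 96 % 7 != 0, so __init__ raises.
```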