diff --git a/applications/contrastive_phenotyping/demo_fit.py b/applications/contrastive_phenotyping/demo_fit.py
new file mode 100644
index 00000000..27f9f532
--- /dev/null
+++ b/applications/contrastive_phenotyping/demo_fit.py
@@ -0,0 +1,44 @@
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.loggers import TensorBoardLogger
+
+
+from viscy.data.triplet import TripletDataModule
+from viscy.light.engine import ContrastiveModule
+
+
+def main():
+    dm = TripletDataModule(
+        data_path="/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr",
+        tracks_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr",
+        source_channel=["Phase3D", "RFP"],
+        z_range=(20, 35),
+        batch_size=16,
+        num_workers=10,
+        initial_yx_patch_size=(384, 384),
+        final_yx_patch_size=(224, 224),
+    )
+    model = ContrastiveModule(
+        backbone="convnext_tiny",
+        in_channels=2,
+        log_batches_per_epoch=2,
+        log_samples_per_batch=3,
+    )
+    trainer = Trainer(
+        max_epochs=5,
+        limit_train_batches=10,
+        limit_val_batches=5,
+        logger=TensorBoardLogger(
+            "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/test_tb",
+            log_graph=True,
+            default_hp_metric=True,
+        ),
+        log_every_n_steps=1,
+        callbacks=[ModelCheckpoint()],
+        profiler="simple",  # other options: "advanced" uses cProfile, "pytorch" uses the PyTorch profiler
+    )
+    trainer.fit(model, dm)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py
index 2dce7afe..0251fefc 100644
--- a/tests/preprocessing/test_pixel_ratio.py
+++ b/tests/preprocessing/test_pixel_ratio.py
@@ -6,10 +6,8 @@
 def test_sematic_class_weights(small_hcs_dataset):
     weights = sematic_class_weights(small_hcs_dataset, "GFP")
     assert weights.shape == (3,)
-    assert_allclose(weights[0], 1.0)
+    assert_allclose(weights[0], 1.0, atol=1e-5)  # infinity
     assert weights[1] > 1.0
     assert weights[2] > 1.0
-    assert sematic_class_weights(
-        small_hcs_dataset, "GFP", num_classes=2
-    ).shape == (2,)
+    assert sematic_class_weights(small_hcs_dataset, "GFP", num_classes=2).shape == (2,)
diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py
index 5f915d3a..e8ba12fa 100644
--- a/viscy/data/hcs.py
+++ b/viscy/data/hcs.py
@@ -88,9 +88,11 @@ def _read_norm_meta(fov: Position) -> NormMeta | None:
     for channel, channel_values in norm_meta.items():
         for level, level_values in channel_values.items():
             for stat, value in level_values.items():
-                norm_meta[channel][level][stat] = torch.tensor(
-                    value, dtype=torch.float32
-                )
+                if isinstance(value, Tensor):
+                    value = value.clone().float()
+                else:
+                    value = torch.tensor(value, dtype=torch.float32)
+                norm_meta[channel][level][stat] = value
     return norm_meta
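Note on the `_read_norm_meta` hunk above: `torch.tensor(value)` emits a UserWarning and makes an extra copy when `value` is already a tensor, which happens if the same metadata dictionary is converted more than once. A minimal standalone sketch of the guarded conversion; the nested-dict layout and key names here are illustrative, not the actual `NormMeta` schema:

```python
import torch
from torch import Tensor


def to_float32(value):
    # Clone existing tensors instead of re-wrapping them with torch.tensor(),
    # which warns ("To copy construct from a tensor...") and copies anyway.
    if isinstance(value, Tensor):
        return value.clone().float()
    return torch.tensor(value, dtype=torch.float32)


norm_meta = {"RFP": {"fov_statistics": {"mean": 128.0, "std": 16.0}}}
for _ in range(2):  # applying the conversion twice is now warning-free
    for channel, levels in norm_meta.items():
        for level, stats in levels.items():
            for stat, value in stats.items():
                stats[stat] = to_float32(value)
```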
diff --git a/viscy/light/engine.py b/viscy/light/engine.py
index 1a4cd710..8ffc89e9 100644
--- a/viscy/light/engine.py
+++ b/viscy/light/engine.py
@@ -11,12 +11,8 @@
 from matplotlib.pyplot import get_cmap
 from monai.optimizers import WarmupCosineSchedule
 from monai.transforms import DivisiblePad, Rotate90
-from pytorch_lightning.utilities import rank_zero_only
 from skimage.exposure import rescale_intensity
 from torch import Tensor, nn
-
-# from lightning.pytorch import LightningModule
-# from lightning import LightningModule
 from torch.optim import Adam
 from torch.optim.lr_scheduler import ConstantLR
 from torchmetrics.functional import (
@@ -44,10 +40,6 @@
 except ImportError:
     CellposeModel = None
 
-try:
-    import wandb
-except ImportError:
-    wandb = None
 
 _UNET_ARCHITECTURE = {
     "2D": Unet2d,
@@ -57,6 +49,40 @@
     "UNeXt2_2D": FullyConvolutionalMAE,
 }
 
+_logger = logging.getLogger("lightning.pytorch")
+
+
+def _detach_sample(imgs: Sequence[Tensor], log_samples_per_batch: int):
+    num_samples = min(imgs[0].shape[0], log_samples_per_batch)
+    samples = []
+    for i in range(num_samples):
+        patches = []
+        for img in imgs:
+            patch = img[i].detach().cpu().numpy()
+            patch = np.squeeze(patch[:, patch.shape[1] // 2])
+            patches.append(patch)
+        samples.append(patches)
+    return samples
+
+
+def _render_images(imgs: Sequence[Sequence[np.ndarray]], cmaps: Sequence[str] | None = None):
+    images_grid = []
+    for sample_images in imgs:
+        images_row = []
+        for i, image in enumerate(sample_images):
+            if cmaps:
+                cm_name = cmaps[i]
+            else:
+                cm_name = "gray" if i == 0 else "inferno"
+            if image.ndim == 2:
+                image = image[np.newaxis]
+            for channel in image:
+                channel = rescale_intensity(channel, out_range=(0, 1))
+                render = get_cmap(cm_name)(channel, bytes=True)[..., :3]
+                images_row.append(render)
+        images_grid.append(np.concatenate(images_row, axis=1))
+    return np.concatenate(images_grid, axis=0)
+
 
 class MixedLoss(nn.Module):
     """Mixed reconstruction loss.
@@ -206,7 +232,7 @@ def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int):
             batch_size += source.shape[0]
             if batch_idx < self.log_batches_per_epoch:
                 self.training_step_outputs.extend(
-                    self._detach_sample((source, target, pred))
+                    _detach_sample((source, target, pred), self.log_samples_per_batch)
                 )
         loss_step = torch.stack(losses).mean()
         self.log(
@@ -237,7 +263,7 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
         )
         if batch_idx < self.log_batches_per_epoch:
             self.validation_step_outputs.extend(
-                self._detach_sample((source, target, pred))
+                _detach_sample((source, target, pred), self.log_samples_per_batch)
            )
 
     def test_step(self, batch: Sample, batch_idx: int):
@@ -309,7 +335,7 @@ def _log_segmentation_metrics(
         pred_binary = pred_labels > 0
         target_binary = target_labels > 0
         coco_metrics = mean_average_precision(pred_labels, target_labels)
-        logging.debug(coco_metrics)
+        _logger.debug(coco_metrics)
         self.log_dict(
             {
                 # semantic segmentation
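Taken together, the two new module-level helpers replace the per-module sample-logging methods: `_detach_sample` keeps up to N samples per batch and extracts the middle Z slice of each image, and `_render_images` colormaps and tiles them into a single HWC grid for TensorBoard. A rough usage sketch with random tensors; the (B, C, Z, Y, X) shapes are assumptions, and the helpers are assumed importable (module-level, if private) from `viscy.light.engine`:

```python
import numpy as np
import torch

from viscy.light.engine import _detach_sample, _render_images

# Stand-ins for source/target/prediction batches of shape (B, C, Z, Y, X).
source = torch.rand(8, 1, 5, 64, 64)
target = torch.rand(8, 1, 5, 64, 64)
pred = torch.rand(8, 1, 5, 64, 64)

# Keep at most 3 samples; each entry is a list of 2D numpy slices (middle Z).
samples = _detach_sample((source, target, pred), log_samples_per_batch=3)

# One row per sample, one colormapped panel per image, RGB uint8.
grid = _render_images(samples)
assert grid.dtype == np.uint8 and grid.shape == (3 * 64, 3 * 64, 3)
# SummaryWriter.add_image(key, grid, step, dataformats="HWC") accepts this layout.
```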
" "Please install the metrics dependency with " '`pip install viscy"[metrics]"`' @@ -441,27 +467,8 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] - def _detach_sample(self, imgs: Sequence[Tensor]): - num_samples = min(imgs[0].shape[0], self.log_samples_per_batch) - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] - for sample_images in imgs: - images_row = [] - for i, image in enumerate(sample_images): - cm_name = "gray" if i == 0 else "inferno" - if image.ndim == 2: - image = image[np.newaxis] - for channel in image: - channel = rescale_intensity(channel, out_range=(0, 1)) - render = get_cmap(cm_name)(channel, bytes=True)[..., :3] - images_row.append(render) - images_grid.append(np.concatenate(images_row, axis=1)) - grid = np.concatenate(images_grid, axis=0) + grid = _render_images(imgs) self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) @@ -519,7 +526,10 @@ def training_step(self, batch: Sequence[Sample], batch_idx: int): batch_size += source.shape[0] if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( - self._detach_sample((source, target * mask.unsqueeze(2), pred)) + _detach_sample( + (source, target * mask.unsqueeze(2), pred), + self.log_samples_per_batch, + ) ) loss_step = torch.stack(losses).mean() self.log( @@ -547,7 +557,10 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( - self._detach_sample((source, target * mask.unsqueeze(2), pred)) + _detach_sample( + (source, target * mask.unsqueeze(2), pred), + self.log_samples_per_batch, + ) ) @@ -563,7 +576,8 @@ def __init__( margin: float = 0.5, lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", - log_steps_per_epoch: int = 8, + log_batches_per_epoch: int = 8, + log_samples_per_batch: int = 1, in_channels: int = 1, example_input_yx_shape: Sequence[int] = (256, 256), in_stack_depth: int = 15, @@ -573,15 +587,12 @@ def __init__( tracks_path: str = "data/tracks", ) -> None: super().__init__() - if wandb is None: - raise ImportError( - f"wandb is required for logging of {type(self).__name__}." - ) self.loss_function = loss_function self.margin = margin self.lr = lr self.schedule = schedule - self.log_steps_per_epoch = log_steps_per_epoch + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch self.training_step_outputs = [] self.validation_step_outputs = [] self.test_step_outputs = [] @@ -591,7 +602,6 @@ def __init__( self.processed_order = [] self.predictions = [] self.tracks_path = tracks_path - self.model = ContrastiveEncoder( backbone=backbone, in_channels=in_channels, @@ -600,177 +610,96 @@ def __init__( embedding_len=embedding_len, predict=predict, ) - - # commented because not logging the graph. 
@@ -563,7 +576,8 @@ def __init__(
         margin: float = 0.5,
         lr: float = 1e-3,
         schedule: Literal["WarmupCosine", "Constant"] = "Constant",
-        log_steps_per_epoch: int = 8,
+        log_batches_per_epoch: int = 8,
+        log_samples_per_batch: int = 1,
         in_channels: int = 1,
         example_input_yx_shape: Sequence[int] = (256, 256),
         in_stack_depth: int = 15,
@@ -573,15 +587,12 @@
         tracks_path: str = "data/tracks",
     ) -> None:
         super().__init__()
-        if wandb is None:
-            raise ImportError(
-                f"wandb is required for logging of {type(self).__name__}."
-            )
         self.loss_function = loss_function
         self.margin = margin
         self.lr = lr
         self.schedule = schedule
-        self.log_steps_per_epoch = log_steps_per_epoch
+        self.log_batches_per_epoch = log_batches_per_epoch
+        self.log_samples_per_batch = log_samples_per_batch
         self.training_step_outputs = []
         self.validation_step_outputs = []
         self.test_step_outputs = []
@@ -591,7 +602,6 @@
         self.processed_order = []
         self.predictions = []
         self.tracks_path = tracks_path
-
         self.model = ContrastiveEncoder(
             backbone=backbone,
             in_channels=in_channels,
@@ -600,177 +610,96 @@
             embedding_len=embedding_len,
             predict=predict,
         )
-
-        # commented because not logging the graph.
-        # self.example_input_array = torch.rand(
-        #     1,
-        #     in_channels,
-        #     in_stack_depth,
-        #     *example_input_yx_shape,  # batch size
-        # )
-
-        self.images_to_log = []
-        self.train_batch_counter = 0
-        self.val_batch_counter = 0
+        self.example_input_array = torch.rand(
+            1, in_channels, in_stack_depth, *example_input_yx_shape
+        )
+        self.training_step_outputs = []
+        self.validation_step_outputs = []
 
     def forward(self, x: Tensor) -> Tensor:
-        """Forward pass of the model."""
-        _, projections = self.model(x)
-        return projections
-        # features is without projection head and projects is with projection head
+        """Projected embeddings."""
+        return self.model(x)[1]
 
     def log_feature_statistics(self, embeddings: Tensor, prefix: str):
         mean = torch.mean(embeddings, dim=0).detach().cpu().numpy()
         std = torch.std(embeddings, dim=0).detach().cpu().numpy()
-
-        print(f"{prefix}_mean: {mean}")
-        print(f"{prefix}_std: {std}")
+        _logger.debug(f"{prefix}_mean: {mean}")
+        _logger.debug(f"{prefix}_std: {std}")
 
     def print_embedding_norms(self, anchor, positive, negative, phase):
         anchor_norm = torch.norm(anchor, dim=1).mean().item()
         positive_norm = torch.norm(positive, dim=1).mean().item()
         negative_norm = torch.norm(negative, dim=1).mean().item()
+        _logger.debug(f"{phase}/anchor_norm: {anchor_norm}")
+        _logger.debug(f"{phase}/positive_norm: {positive_norm}")
+        _logger.debug(f"{phase}/negative_norm: {negative_norm}")
-
-        print(f"{phase}/anchor_norm: {anchor_norm}")
-        print(f"{phase}/positive_norm: {positive_norm}")
-        print(f"{phase}/negative_norm: {negative_norm}")
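With `forward` now returning only the projected embeddings (`ContrastiveEncoder` yields a `(features, projections)` tuple), a trained module can be called directly for embedding extraction. A hypothetical inference sketch; the checkpoint path is a placeholder, the batch shape mirrors `example_input_array`, and loading this way assumes the constructor arguments were saved with the checkpoint:

```python
import torch

from viscy.light.engine import ContrastiveModule

# Hypothetical checkpoint path; hyperparameters are assumed restorable from it.
module = ContrastiveModule.load_from_checkpoint("checkpoints/last.ckpt")
module.eval()

with torch.inference_mode():
    x = torch.rand(4, 2, 15, 256, 256)  # (B, C, D, H, W)
    projections = module(x)  # same as module.model(x)[1]
    features, _ = module.model(x)  # pre-projection features, if needed
```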
-
-    # logs over all steps
-    @rank_zero_only
-    def log_metrics(self, anchor, positive, negative, phase):
-        cosine_sim_pos = F.cosine_similarity(anchor, positive, dim=1).mean().item()
-        cosine_sim_neg = F.cosine_similarity(anchor, negative, dim=1).mean().item()
-
-        euclidean_dist_pos = F.pairwise_distance(anchor, positive).mean().item()
-        euclidean_dist_neg = F.pairwise_distance(anchor, negative).mean().item()
-
-        metrics = {
-            f"{phase}/cosine_similarity_positive": cosine_sim_pos,
-            f"{phase}/cosine_similarity_negative": cosine_sim_neg,
-            f"{phase}/euclidean_distance_positive": euclidean_dist_pos,
-            f"{phase}/euclidean_distance_negative": euclidean_dist_neg,
-        }
-
-        wandb.log(metrics)
-
-        if phase == "train":
-            self.training_metrics.append(metrics)
-        elif phase == "val":
-            self.validation_metrics.append(metrics)
-        elif phase == "test":
-            self.test_metrics.append(metrics)
-
-    @rank_zero_only
-    # logs only one sample from the first batch per epoch
-    def log_images(self, anchor, positive, negative, epoch, step_name):
-        z_idx = 7  # middle of z_slice
-
-        anchor_img_rfp = anchor[0, 0, z_idx, :, :].cpu().numpy()
-        positive_img_rfp = positive[0, 0, z_idx, :, :].cpu().numpy()
-        negative_img_rfp = negative[0, 0, z_idx, :, :].cpu().numpy()
-
-        anchor_img_phase = anchor[0, 1, z_idx, :, :].cpu().numpy()
-        positive_img_phase = positive[0, 1, z_idx, :, :].cpu().numpy()
-        negative_img_phase = negative[0, 1, z_idx, :, :].cpu().numpy()
-
-        def normalize(image):
-            min_val = image.min()
-            max_val = image.max()
-            return (image - min_val) / (max_val - min_val) * 255
-
-        anchor_img_rfp = normalize(anchor_img_rfp)
-        positive_img_rfp = normalize(positive_img_rfp)
-        negative_img_rfp = normalize(negative_img_rfp)
-
-        anchor_img_phase = normalize(anchor_img_phase)
-        positive_img_phase = normalize(positive_img_phase)
-        negative_img_phase = normalize(negative_img_phase)
-
-        # combine the images side by side
-        combined_img_rfp = np.concatenate(
-            (anchor_img_rfp, positive_img_rfp, negative_img_rfp), axis=1
+    def _log_metrics(
+        self, loss, anchor, positive, negative, stage: Literal["train", "val"]
+    ):
+        self.log(
+            f"loss/{stage}",
+            loss.to(self.device),
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
         )
-        combined_img_phase = np.concatenate(
-            (anchor_img_phase, positive_img_phase, negative_img_phase), axis=1
+        cosine_sim_pos = F.cosine_similarity(anchor, positive, dim=1).mean()
+        cosine_sim_neg = F.cosine_similarity(anchor, negative, dim=1).mean()
+        euclidean_dist_pos = F.pairwise_distance(anchor, positive).mean()
+        euclidean_dist_neg = F.pairwise_distance(anchor, negative).mean()
+        self.log_dict(
+            {
+                f"metrics/cosine_similarity_positive/{stage}": cosine_sim_pos,
+                f"metrics/cosine_similarity_negative/{stage}": cosine_sim_neg,
+                f"metrics/euclidean_distance_positive/{stage}": euclidean_dist_pos,
+                f"metrics/euclidean_distance_negative/{stage}": euclidean_dist_neg,
+            },
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
        )
-        combined_img = np.concatenate((combined_img_rfp, combined_img_phase), axis=0)
 
-        self.images_to_log.append(
-            wandb.Image(
-                combined_img, caption=f"Anchor | Positive | Negative (Epoch {epoch})"
-            )
+    def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
+        grid = _render_images(imgs, cmaps=["gray"] * 3)
+        self.logger.experiment.add_image(
+            key, grid, self.current_epoch, dataformats="HWC"
         )
-
-        wandb.log({f"{step_name}": self.images_to_log})
-        self.images_to_log = []
-
     def training_step(
         self,
         batch: TripletSample,
         batch_idx: int,
     ) -> Tensor:
         """Training step of the model."""
-
-        anchor = batch["anchor"]
+        stage = "train"
+        anchor_img = batch["anchor"]
         pos_img = batch["positive"]
         neg_img = batch["negative"]
-        _, anchor_projection = self.model(anchor)
+        _, anchor_projection = self.model(anchor_img)
         _, negative_projection = self.model(neg_img)
         _, positive_projection = self.model(pos_img)
         loss = self.loss_function(
             anchor_projection, positive_projection, negative_projection
         )
-
-        self.log("train/loss_step", loss, on_step=True, prog_bar=True, logger=True)
-
-        self.train_batch_counter += 1
-        if self.train_batch_counter % self.log_steps_per_epoch == 0:
-            self.log_images(
-                anchor, pos_img, neg_img, self.current_epoch, "training_images"
-            )
-
-        self.log_metrics(
-            anchor_projection, positive_projection, negative_projection, "train"
+        self._log_metrics(
+            loss, anchor_projection, positive_projection, negative_projection, stage
         )
-
-        self.training_step_outputs.append(loss)
-        return {"loss": loss}
+        if batch_idx < self.log_batches_per_epoch:
+            self.training_step_outputs.extend(
+                _detach_sample(
+                    (anchor_img, pos_img, neg_img), self.log_samples_per_batch
+                )
+            )
+        return loss
 
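The four retrieval metrics logged by `_log_metrics` above are plain `torch.nn.functional` reductions. A toy sketch of the separation they are meant to capture; the embedding dimensions and noise scale are illustrative only:

```python
import torch
import torch.nn.functional as F

anchor = F.normalize(torch.rand(16, 256), dim=1)
positive = F.normalize(anchor + 0.05 * torch.randn(16, 256), dim=1)
negative = F.normalize(torch.randn(16, 256), dim=1)

# Batch means, as logged per epoch by _log_metrics.
cos_pos = F.cosine_similarity(anchor, positive, dim=1).mean()
cos_neg = F.cosine_similarity(anchor, negative, dim=1).mean()
dist_pos = F.pairwise_distance(anchor, positive).mean()
dist_neg = F.pairwise_distance(anchor, negative).mean()

# A well-trained encoder should satisfy both inequalities.
print(cos_pos > cos_neg, dist_pos < dist_neg)
```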
avg_metrics["train/euclidean_distance_negative"], - on_epoch=True, - logger=True, - ) - self.training_metrics.clear() - self.train_batch_counter = 0 + super().on_train_epoch_end() + self._log_samples("train_samples", self.training_step_outputs) + self.training_step_outputs = [] def validation_step( self, @@ -778,7 +707,6 @@ def validation_step( batch_idx: int, ) -> Tensor: """Validation step of the model.""" - anchor = batch["anchor"] pos_img = batch["positive"] neg_img = batch["negative"] @@ -788,141 +716,24 @@ def validation_step( loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - - self.log("val/loss_step", loss, on_step=True, prog_bar=True, logger=True) - - self.val_batch_counter += 1 - if self.val_batch_counter % self.log_steps_per_epoch == 0: - self.log_images( - anchor, pos_img, neg_img, self.current_epoch, "validation_images" - ) - - self.log_metrics( - anchor_projection, positive_projection, negative_projection, "val" + self._log_metrics( + loss, anchor_projection, positive_projection, negative_projection, "val" ) - - self.validation_step_outputs.append(loss) - return {"loss": loss} - - def on_validation_epoch_end(self) -> None: - epoch_loss = torch.stack(self.validation_step_outputs).mean() - self.log( - "val/loss_epoch", epoch_loss, on_epoch=True, prog_bar=True, logger=True - ) - self.validation_step_outputs.clear() - - if self.validation_metrics: - avg_metrics = self.aggregate_metrics(self.validation_metrics, "val") - self.log( - "val/avg_cosine_similarity_positive", - avg_metrics["val/cosine_similarity_positive"], - on_epoch=True, - logger=True, - ) - self.log( - "val/avg_cosine_similarity_negative", - avg_metrics["val/cosine_similarity_negative"], - on_epoch=True, - logger=True, - ) - self.log( - "val/avg_euclidean_distance_positive", - avg_metrics["val/euclidean_distance_positive"], - on_epoch=True, - logger=True, - ) - self.log( - "val/avg_euclidean_distance_negative", - avg_metrics["val/euclidean_distance_negative"], - on_epoch=True, - logger=True, + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + _detach_sample((anchor, pos_img, neg_img), self.log_samples_per_batch) ) - self.validation_metrics.clear() - self.val_batch_counter = 0 - - def test_step( - self, - batch: TripletSample, - batch_idx: int, - ) -> Tensor: - """Test step of the model.""" - - anchor = batch["anchor"] - pos_img = batch["positive"] - neg_img = batch["negative"] - _, anchor_projection = self.model(anchor) - _, negative_projection = self.model(neg_img) - _, positive_projection = self.model(pos_img) - loss = self.loss_function( - anchor_projection, positive_projection, negative_projection - ) - - self.log("test/loss_step", loss, on_step=True, prog_bar=True, logger=True) - - self.log_metrics( - anchor_projection, positive_projection, negative_projection, "test" - ) - - self.test_step_outputs.append(loss) - return {"loss": loss} + return loss - @rank_zero_only - def on_test_epoch_end(self) -> None: - epoch_loss = torch.stack(self.test_step_outputs).mean() - self.log( - "test/loss_epoch", epoch_loss, on_epoch=True, prog_bar=True, logger=True - ) - self.test_step_outputs.clear() - - if self.test_metrics: - avg_metrics = self.aggregate_metrics(self.test_metrics, "test") - self.log( - "test/avg_cosine_similarity_positive", - avg_metrics["test/cosine_similarity_positive"], - on_epoch=True, - logger=True, - ) - self.log( - "test/avg_cosine_similarity_negative", - avg_metrics["test/cosine_similarity_negative"], - on_epoch=True, - 
-    @rank_zero_only
-    def on_test_epoch_end(self) -> None:
-        epoch_loss = torch.stack(self.test_step_outputs).mean()
-        self.log(
-            "test/loss_epoch", epoch_loss, on_epoch=True, prog_bar=True, logger=True
-        )
-        self.test_step_outputs.clear()
-
-        if self.test_metrics:
-            avg_metrics = self.aggregate_metrics(self.test_metrics, "test")
-            self.log(
-                "test/avg_cosine_similarity_positive",
-                avg_metrics["test/cosine_similarity_positive"],
-                on_epoch=True,
-                logger=True,
-            )
-            self.log(
-                "test/avg_cosine_similarity_negative",
-                avg_metrics["test/cosine_similarity_negative"],
-                on_epoch=True,
-                logger=True,
-            )
-            self.log(
-                "test/avg_euclidean_distance_positive",
-                avg_metrics["test/euclidean_distance_positive"],
-                on_epoch=True,
-                logger=True,
-            )
-            self.log(
-                "test/avg_euclidean_distance_negative",
-                avg_metrics["test/euclidean_distance_negative"],
-                on_epoch=True,
-                logger=True,
-            )
-            self.test_metrics.clear()
+    def on_validation_epoch_end(self) -> None:
+        super().on_validation_epoch_end()
+        self._log_samples("val_samples", self.validation_step_outputs)
+        self.validation_step_outputs = []
 
     def configure_optimizers(self):
         optimizer = Adam(self.parameters(), lr=self.lr)
         return optimizer
 
-    def aggregate_metrics(self, metrics, phase):
-        avg_metrics = {}
-        if metrics:
-            avg_metrics[f"{phase}/cosine_similarity_positive"] = sum(
-                m[f"{phase}/cosine_similarity_positive"] for m in metrics
-            ) / len(metrics)
-            avg_metrics[f"{phase}/cosine_similarity_negative"] = sum(
-                m[f"{phase}/cosine_similarity_negative"] for m in metrics
-            ) / len(metrics)
-            avg_metrics[f"{phase}/euclidean_distance_positive"] = sum(
-                m[f"{phase}/euclidean_distance_positive"] for m in metrics
-            ) / len(metrics)
-            avg_metrics[f"{phase}/euclidean_distance_negative"] = sum(
-                m[f"{phase}/euclidean_distance_negative"] for m in metrics
-            ) / len(metrics)
-        return avg_metrics
-
     def predict_step(self, batch: TripletSample, batch_idx, dataloader_idx=0):
         print("running predict step!")
         """Prediction step for extracting embeddings."""