From b4bbf3973c9420750741c26ae8bd39946e47fbbd Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:11:26 -0700 Subject: [PATCH] Rename UNeXt2 (#84) * rename file * rename the architecture * fix merge --- viscy/light/engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index a338497b..7e4f1118 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -136,7 +136,7 @@ def __init__( test_cellpose_diameter: float = None, test_evaluate_cellpose: bool = False, test_time_augmentations: bool = False, - tta_type: Literal["mean", "median","product"] = "mean", + tta_type: Literal["mean", "median", "product"] = "mean", ) -> None: super().__init__() net_class = _UNET_ARCHITECTURE.get(architecture) @@ -364,7 +364,9 @@ def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): prediction = torch.stack(predictions).cpu().median(dim=0).values elif self.tta_type == "product": # Perform multiplication of predictions in logarithmic space for numerical stability adding epsion to avoid log(0) case - log_predictions = torch.stack([torch.log(p + 1e-9) for p in predictions]) + log_predictions = torch.stack( + [torch.log(p + 1e-9) for p in predictions] + ) log_prediction_sum = log_predictions.sum(dim=0) prediction = torch.exp(log_prediction_sum) # Put back to GPU