From 2c0db743516eaaa085c8a8e44293414dfa371b89 Mon Sep 17 00:00:00 2001 From: Daria Doubine Date: Thu, 12 Jan 2023 11:34:30 +0100 Subject: [PATCH 1/2] passing custom mean_std at instanciation --- aloscene/frame.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/aloscene/frame.py b/aloscene/frame.py index 304fc6fd..13ff973e 100644 --- a/aloscene/frame.py +++ b/aloscene/frame.py @@ -465,6 +465,9 @@ def norm_minmax_sym(self): tensor = 2 * tensor - 1.0 elif tensor.normalization == "255": tensor = 2 * (tensor / 255.0) - 1.0 + elif tensor.mean_std is not None: + tensor = tensor.norm01() + tensor = 2 * tensor - 1.0 else: raise Exception(f"Can't convert from {tensor.normalization} to norm255") tensor.mean_std = None @@ -484,7 +487,7 @@ def mean_std_norm(self, mean, std, name) -> Frame: """ tensor = self mean_tensor, std_tensor = self._get_mean_std_tensor( - tensor.shape, tensor.names, tensor._resnet_mean_std, device=tensor.device + tensor.shape, tensor.names, (mean, std), device=tensor.device ) if tensor.normalization == "01": tensor = tensor - mean_tensor @@ -507,15 +510,23 @@ def mean_std_norm(self, mean, std, name) -> Frame: return tensor - def norm_resnet(self) -> Frame: - """Normalized the current frame based on the normalized use on resnet on pytorch. This method will - simply call `frame.mean_std_norm()` with the resnet mean/std and the name `resnet`. - + def norm_meanstd(self, mean_std=None, name=None) -> Frame: + """Returns z-norm of the current frame. + This method will simply call `frame.mean_std_norm()` with the mean/std property of the frame and the selected name. + Instead of a custom mean/std, you can use the resnet norm based on the normalized use of resnet on pytorch. Examples -------- - >>> frame_resnet = frame.norm_resnet() + >>> frame_resnet = frame.norm_resnet(name="resnet") + >>> frame_custom_norm = frame.norm_resnet(name="custom") """ - return self.mean_std_norm(mean=self._resnet_mean_std[0], std=self._resnet_mean_std[1], name="resnet") + if name == "resnet": + return self.mean_std_norm(mean=self._resnet_mean_std[0], std=self._resnet_mean_std[1], name="resnet") + elif mean_std is not None and len(mean_std) == 2: + return self.mean_std_norm(mean=mean_std[0], std=mean_std[1], name=name) + elif self.mean_std is not None: + return self.mean_std_norm(mean=self.mean_std[0], std=self.mean_std[1], name=name) + else: + raise Exception("Please pass a mean_std tuple or use the resnet norm") def __get_view__(self, title=None): """Create a view of the frame""" From a43235e27e56e55cab54e4db30e7a5759aa4a9da Mon Sep 17 00:00:00 2001 From: Daria Doubine Date: Thu, 12 Jan 2023 13:28:17 +0100 Subject: [PATCH 2/2] brought back norm_resnet() --- aloscene/frame.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/aloscene/frame.py b/aloscene/frame.py index 13ff973e..ba71711a 100644 --- a/aloscene/frame.py +++ b/aloscene/frame.py @@ -520,7 +520,7 @@ def norm_meanstd(self, mean_std=None, name=None) -> Frame: >>> frame_custom_norm = frame.norm_resnet(name="custom") """ if name == "resnet": - return self.mean_std_norm(mean=self._resnet_mean_std[0], std=self._resnet_mean_std[1], name="resnet") + return self.norm_resnet() elif mean_std is not None and len(mean_std) == 2: return self.mean_std_norm(mean=mean_std[0], std=mean_std[1], name=name) elif self.mean_std is not None: @@ -528,6 +528,15 @@ def norm_meanstd(self, mean_std=None, name=None) -> Frame: else: raise Exception("Please pass a mean_std tuple or use the resnet norm") + def norm_resnet(self) -> Frame: + """Normalized the current frame based on the normalized use on resnet on pytorch. This method will + simply call `frame.mean_std_norm()` with the resnet mean/std and the name `resnet`. + Examples + -------- + >>> frame_resnet = frame.norm_resnet() + """ + return self.mean_std_norm(mean=self._resnet_mean_std[0], std=self._resnet_mean_std[1], name="resnet") + def __get_view__(self, title=None): """Create a view of the frame""" assert self.names[0] != "T" and self.names[1] != "B"