From 79dbc39b096185a6035db96439c87aca1282a43b Mon Sep 17 00:00:00 2001
From: Albertgary
Date: Sat, 22 Jun 2024 17:07:42 +0100
Subject: [PATCH 1/4] add wmse loss and unit test for losses

---
 odak/learn/models/models.py   |  6 +++---
 odak/learn/tools/loss.py      | 34 ++++++++++++++++++++++++++++++++--
 test/test_learn_components.py |  4 ++--
 test/test_tools_losses.py     | 29 +++++++++++++++++++++++++++++
 4 files changed, 66 insertions(+), 7 deletions(-)
 create mode 100644 test/test_tools_losses.py

diff --git a/odak/learn/models/models.py b/odak/learn/models/models.py
index 4d39a9c7..f50130d0 100644
--- a/odak/learn/models/models.py
+++ b/odak/learn/models/models.py
@@ -1,6 +1,6 @@
 import torch
 from .components import double_convolution, downsample_layer, upsample_layer, swish, gaussian
-
+import segmentation_models_pytorch as smp
 
 class multi_layer_perceptron(torch.nn.Module):
     """
@@ -140,8 +140,8 @@ def __init__(
                                              kernel_size = kernel_size,
                                              bias = bias,
                                              activation = activation
-                                            )
-        factor = 2 if bilinear else 1
+                                             )
+
         self.downsampling_layers = torch.nn.ModuleList()
         self.upsampling_layers = torch.nn.ModuleList()
         for i in range(depth): # downsampling layers
diff --git a/odak/learn/tools/loss.py b/odak/learn/tools/loss.py
index b83de396..06f2e4f9 100644
--- a/odak/learn/tools/loss.py
+++ b/odak/learn/tools/loss.py
@@ -125,7 +125,7 @@ def histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.]):
     if len(frame.shape) == 3:
         frame = frame.unsqueeze(0)
     histogram_frame = torch.zeros(frame.shape[1], bins).to(frame.device)
-    histogram_ground_truth = torch.zeros(frame.shape[1], bins).to(frame.device)
+    histogram_ground_truth = torch.zeros(ground_truth.shape[1], bins).to(frame.device)
     l2 = torch.nn.MSELoss()
     for i in range(frame.shape[1]):
         histogram_frame[i] = torch.histc(frame[:, i].flatten(), bins = bins, min = limits[0], max = limits[1])
@@ -141,7 +141,7 @@ def weber_contrast(image, roi_high, roi_low):
     Parameters
     ----------
     image         : torch.tensor
-                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].
+                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].
     roi_high      : torch.tensor
                     Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
     roi_low       : torch.tensor
@@ -192,3 +192,33 @@ def michelson_contrast(image, roi_high, roi_low):
     low = torch.mean(region_low, dim = (2, 3))
     result = (high - low) / (high + low)
     return result.squeeze(0)
+
+
+def wrapped_mean_squared_error(image, ground_truth, reduction = 'mean'):
+    """
+    A function to calculate the wrapped mean squared error between predicted and target angles.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].
+    ground_truth  : torch.tensor
+                    Ground truth to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].
+    reduction     : str
+                    Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.
+
+    Returns
+    -------
+    wmse          : torch.tensor
+                    The calculated wrapped mean squared error.
+    """
+    sin_diff = torch.sin(image) - torch.sin(ground_truth)
+    cos_diff = torch.cos(image) - torch.cos(ground_truth)
+    loss = (sin_diff**2 + cos_diff**2)
+
+    if reduction == 'mean':
+        return loss.mean()
+    elif reduction == 'sum':
+        return loss.sum()
+    else:
+        raise ValueError("Invalid reduction type. Choose 'mean' or 'sum'.")
\ No newline at end of file
diff --git a/test/test_learn_components.py b/test/test_learn_components.py
index e5784f7f..2fcecc74 100644
--- a/test/test_learn_components.py
+++ b/test/test_learn_components.py
@@ -5,7 +5,7 @@ def test():
     # test residual block
     x = torch.randn(1, 2, 32, 32)
-    residual_inference= components.residual_layer()
+    residual_inference = components.residual_layer()
     y = residual_inference(x)
     # test convolution layer
     convolution_inference = components.convolution_layer()
@@ -16,7 +16,7 @@ def test():
     # test normalization layer
     normalization_inference = components.normalization()
     y = normalization_inference(x)
-    # test attention layer
+    # test attention layerf
     residual_attention_layer_inference = components.residual_attention_layer()
     y = residual_attention_layer_inference(x , x)
     # test self-attention layer
diff --git a/test/test_tools_losses.py b/test/test_tools_losses.py
new file mode 100644
index 00000000..8f6e3de3
--- /dev/null
+++ b/test/test_tools_losses.py
@@ -0,0 +1,29 @@
+import torch
+import odak
+import sys
+import random
+import os
+import odak.learn.tools.loss as loss
+
+
+def test():
+    # test residual block
+    image = [torch.randn(1, 3, 32, 32), torch.randn(1, 32, 32), torch.randn(32, 32), torch.randn(3, 32, 32)]
+    ground_truth = [torch.randn(1, 3, 32, 32), torch.randn(1, 32, 32), torch.randn(32, 32), torch.randn(3, 32, 32)]
+    for idx, (img, pred) in enumerate(zip(image, ground_truth)):
+        print(f'Running test {idx}, input shape: {img.size()}...')
+        y = loss.psnr(img, pred, peak_value = 2.0)
+        y = loss.multi_scale_total_variation_loss(img, levels = 4)
+        y = loss.total_variation_loss(img)
+        y = loss.histogram_loss(img, pred, bins = 16, limits = [0., 1.])
+        roi_high = [0, 16, 0, 16]
+        roi_low = [16, 32, 16, 32]
+        y = loss.weber_contrast(img, roi_high, roi_low)
+        y = loss.michelson_contrast(img, roi_high, roi_low)
+        y = loss.wrapped_mean_squared_error(img, pred, reduction='sum')
+        value = torch.tensor(1.0, dtype=torch.float)
+        y = loss.radial_basis_function(value = value, epsilon = 0.5)
+    assert True == True
+
+if __name__ == '__main__':
+    sys.exit(test())

From 8bff1dc4f5f804568c55e7fe589f692bd75a520d Mon Sep 17 00:00:00 2001
From: Albertgary
Date: Sat, 22 Jun 2024 21:31:38 +0100
Subject: [PATCH 2/4] minor update

---
 odak/learn/models/models.py | 2 +-
 test/test_tools_losses.py   | 3 ---
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/odak/learn/models/models.py b/odak/learn/models/models.py
index f50130d0..09efd6ec 100644
--- a/odak/learn/models/models.py
+++ b/odak/learn/models/models.py
@@ -1,6 +1,6 @@
 import torch
 from .components import double_convolution, downsample_layer, upsample_layer, swish, gaussian
-import segmentation_models_pytorch as smp
+
 
 class multi_layer_perceptron(torch.nn.Module):
     """
diff --git a/test/test_tools_losses.py b/test/test_tools_losses.py
index 8f6e3de3..a8d9310b 100644
--- a/test/test_tools_losses.py
+++ b/test/test_tools_losses.py
@@ -1,8 +1,5 @@
 import torch
-import odak
 import sys
-import random
-import os
 import odak.learn.tools.loss as loss
 
 

From d23bd02439bc6fd997a896c9d4ed4e2d46bd5440 Mon Sep 17 00:00:00 2001
From: Albertgary
Date: Sat, 22 Jun 2024 21:35:00 +0100
Subject: [PATCH 3/4] typo

---
 test/test_learn_components.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_learn_components.py b/test/test_learn_components.py
index 2fcecc74..054c22b6 100644
--- a/test/test_learn_components.py
+++ b/test/test_learn_components.py
@@ -16,7 +16,7 @@ def test():
     # test normalization layer
     normalization_inference = components.normalization()
     y = normalization_inference(x)
-    # test attention layerf
+    # test attention layer
    residual_attention_layer_inference = components.residual_attention_layer()
     y = residual_attention_layer_inference(x , x)
     # test self-attention layer

From 83e6359025fbe33c40eddf7de8f7bc84d73bc224 Mon Sep 17 00:00:00 2001
From: Albertgary
Date: Sat, 22 Jun 2024 22:06:11 +0100
Subject: [PATCH 4/4] update according to pytest

---
 odak/learn/tools/loss.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/odak/learn/tools/loss.py b/odak/learn/tools/loss.py
index 06f2e4f9..f6bd4c91 100644
--- a/odak/learn/tools/loss.py
+++ b/odak/learn/tools/loss.py
@@ -121,18 +121,28 @@ def histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.]):
                     Loss from evaluation.
     """
     if len(frame.shape) == 2:
+        frame = frame.unsqueeze(0).unsqueeze(0)
+    elif len(frame.shape) == 3:
         frame = frame.unsqueeze(0)
-    if len(frame.shape) == 3:
-        frame = frame.unsqueeze(0)
+
+    if len(ground_truth.shape) == 2:
+        ground_truth = ground_truth.unsqueeze(0).unsqueeze(0)
+    elif len(ground_truth.shape) == 3:
+        ground_truth = ground_truth.unsqueeze(0)
+
     histogram_frame = torch.zeros(frame.shape[1], bins).to(frame.device)
     histogram_ground_truth = torch.zeros(ground_truth.shape[1], bins).to(frame.device)
+
     l2 = torch.nn.MSELoss()
+
     for i in range(frame.shape[1]):
-        histogram_frame[i] = torch.histc(frame[:, i].flatten(), bins = bins, min = limits[0], max = limits[1])
-        histogram_ground_truth[i] = torch.histc(frame[:, i].flatten(), bins = bins, min = limits[0], max = limits[1])
+        histogram_frame[i] = torch.histc(frame[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])
+        histogram_ground_truth[i] = torch.histc(ground_truth[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])
+
     loss = l2(histogram_frame, histogram_ground_truth)
-    return loss
+    return loss
+
 
 def weber_contrast(image, roi_high, roi_low):
     """
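A quick usage sketch of the wrapped mean squared error introduced in [PATCH 1/4]. This is not part of the patch series; it assumes a checkout with the series applied so that odak.learn.tools.loss.wrapped_mean_squared_error is importable, and the tensor shape and phase values are made up for illustration. Two phase maps that differ only by a full 2 * pi wrap should give a near-zero wrapped MSE, while a plain MSE penalises the wrap.

import torch
import odak.learn.tools.loss as loss

prediction = torch.ones(1, 1, 32, 32) * 3.0        # phase values close to +pi
ground_truth = prediction - 2 * torch.pi           # same angles, wrapped by one full turn

plain_mse = torch.nn.functional.mse_loss(prediction, ground_truth)       # roughly (2 * pi)^2, about 39.5
wrapped_mse = loss.wrapped_mean_squared_error(prediction, ground_truth)  # close to zero
print(plain_mse.item(), wrapped_mse.item())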