diff --git a/odak/learn/models/components.py b/odak/learn/models/components.py index 56fc5bad..a63518d7 100644 --- a/odak/learn/models/components.py +++ b/odak/learn/models/components.py @@ -112,6 +112,8 @@ def __init__( output_channels = 2, kernel_size = 3, bias = False, + stride = 1, + normalization = True, activation = torch.nn.ReLU() ): """ @@ -126,40 +128,45 @@ def __init__( Number of output channels. kernel_size : int Kernel size. - bias : bool + bias : bool Set to True to let convolutional layers have bias term. + normalization : bool + If True, adds a Batch Normalization layer after the convolutional layer. activation : torch.nn Nonlinear activation layer to be used. If None, uses torch.nn.ReLU(). """ super().__init__() - self.activation = activation - self.model = torch.nn.Sequential( - torch.nn.Conv2d( - input_channels, - output_channels, - kernel_size = kernel_size, - padding = kernel_size // 2, - bias = bias - ), - torch.nn.BatchNorm2d(output_channels), - self.activation - ) + layers = [ + torch.nn.Conv2d( + input_channels, + output_channels, + kernel_size = kernel_size, + stride = stride, + padding = kernel_size // 2, + bias = bias + ) + ] + if normalization: + layers.append(torch.nn.BatchNorm2d(output_channels)) + if activation: + layers.append(activation) + self.model = torch.nn.Sequential(*layers) def forward(self, x): """ Forward model. - + Parameters ---------- x : torch.tensor Input data. - - + + Returns ---------- result : torch.tensor - Estimated output. + Estimated output. """ result = self.model(x) return result @@ -176,6 +183,7 @@ def __init__( output_channels = 2, kernel_size = 3, bias = False, + normalization = True, activation = torch.nn.ReLU() ): """ @@ -194,6 +202,8 @@ def __init__( Kernel size. bias : bool Set to True to let convolutional layers have bias term. + normalization : bool + If True, adds a Batch Normalization layer after the convolutional layer. activation : torch.nn Nonlinear activation layer to be used. If None, uses torch.nn.ReLU(). """ @@ -207,6 +217,7 @@ def __init__( output_channels = mid_channels, kernel_size = kernel_size, bias = bias, + normalization = normalization, activation = self.activation ), convolution_layer( @@ -214,6 +225,7 @@ def __init__( output_channels = output_channels, kernel_size = kernel_size, bias = bias, + normalization = normalization, activation = self.activation ) ) @@ -815,4 +827,468 @@ def forward(self, x): results = torch.cat(results, dim=2) results = results.permute(0, 2, 1) results = results.reshape(B, -1) - return results + return results + + +class upsample_convtranspose2d_layer(torch.nn.Module): + """ + An upsampling convtranspose2d layer. + """ + def __init__( + self, + input_channels, + output_channels, + kernel_size = 2, + stride = 2, + bias = False, + ): + """ + A downscaling component with a double convolution. + + Parameters + ---------- + input_channels : int + Number of input channels. + output_channels : int + Number of output channels. + kernel_size : int + Kernel size. + bias : bool + Set to True to let convolutional layers have bias term. + activation : torch.nn + Nonlinear activation layer to be used. If None, uses torch.nn.ReLU(). + bilinear : bool + If set to True, bilinear sampling is used. + """ + super().__init__() + self.up = torch.nn.ConvTranspose2d( + in_channels = input_channels, + out_channels = output_channels, + bias = bias, + kernel_size = kernel_size, + stride = stride + ) + + def forward(self, x1, x2): + """ + Forward model. + + Parameters + ---------- + x1 : torch.tensor + First input data. + x2 : torch.tensor + Second input data. + + + Returns + ---------- + result : torch.tensor + Result of the forward operation + """ + x1 = self.up(x1) + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + result = x1 + x2 + return result + + +class global_transformations(torch.nn.Module): + """ + A global feature layer that processes global features from input channels and + applies learned transformations to another input tensor. + + This implementation is adapted from RSGUnet: + https://github.com/MTLab/rsgunet_image_enhance. + + Reference: + J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices." + """ + def __init__( + self, + input_channels, + output_channels + ): + """ + A global feature layer. + + Parameters + ---------- + input_channels : int + Number of input channels. + output_channels : int + Number of output channels. + """ + super().__init__() + self.global_feature_1 = torch.nn.Sequential( + torch.nn.Linear(input_channels, output_channels), + torch.nn.LeakyReLU(0.2, inplace = True), + ) + self.global_feature_2 = torch.nn.Sequential( + torch.nn.Linear(output_channels, output_channels), + torch.nn.LeakyReLU(0.2, inplace = True) + ) + + + def forward(self, x1, x2): + """ + Forward model. + + Parameters + ---------- + x1 : torch.tensor + First input data. + x2 : torch.tensor + Second input data. + + Returns + ---------- + result : torch.tensor + Estimated output. + """ + y = torch.mean(x2, dim = (2, 3)) + y1 = self.global_feature_1(y) + y2 = self.global_feature_2(y1) + y1 = y1.unsqueeze(2).unsqueeze(3) + y2 = y2.unsqueeze(2).unsqueeze(3) + result = x1 * y1 + y2 + return result + + +class global_feature_module(torch.nn.Module): + """ + A global feature layer that processes global features from input channels and + applies them to another input tensor via learned transformations. + """ + def __init__( + self, + input_channels, + mid_channels, + output_channels, + kernel_size, + bias = False, + normalization = False, + activation = torch.nn.ReLU() + ): + """ + A global feature layer. + + Parameters + ---------- + input_channels : int + Number of input channels. + mid_channels : int + Number of mid channels. + output_channels : int + Number of output channels. + kernel_size : int + Kernel size. + bias : bool + Set to True to let convolutional layers have bias term. + normalization : bool + If True, adds a Batch Normalization layer after the convolutional layer. + activation : torch.nn + Nonlinear activation layer to be used. If None, uses torch.nn.ReLU(). + """ + super().__init__() + self.transformations_1 = global_transformations(input_channels, output_channels) + self.global_features_1 = double_convolution( + input_channels = input_channels, + mid_channels = mid_channels, + output_channels = output_channels, + kernel_size = kernel_size, + bias = bias, + normalization = normalization, + activation = activation + ) + self.global_features_2 = double_convolution( + input_channels = input_channels, + mid_channels = mid_channels, + output_channels = output_channels, + kernel_size = kernel_size, + bias = bias, + normalization = normalization, + activation = activation + ) + self.transformations_2 = global_transformations(input_channels, output_channels) + + + def forward(self, x1, x2): + """ + Forward model. + + Parameters + ---------- + x1 : torch.tensor + First input data. + x2 : torch.tensor + Second input data. + + Returns + ---------- + result : torch.tensor + Estimated output. + """ + global_tensor_1 = self.transformations_1(x1, x2) + y1 = self.global_features_1(global_tensor_1) + y2 = self.global_features_2(y1) + global_tensor_2 = self.transformations_2(y1, y2) + return global_tensor_2 + + +class spatially_adaptive_convolution(torch.nn.Module): + """ + A spatially adaptive convolution layer. + + References + ---------- + + C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." + C, Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation." + C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement." + """ + def __init__( + self, + input_channels = 2, + output_channels = 2, + kernel_size = 3, + stride = 1, + padding = 1, + bias = False, + activation = torch.nn.LeakyReLU(0.2, inplace = True) + ): + """ + Initializes a spatially adaptive convolution layer. + + Parameters + ---------- + input_channels : int + Number of input channels. + output_channels : int + Number of output channels. + kernel_size : int + Size of the convolution kernel. + stride : int + Stride of the convolution. + padding : int + Padding added to both sides of the input. + bias : bool + If True, includes a bias term in the convolution. + activation : torch.nn.Module + Activation function to apply. If None, no activation is applied. + """ + super(spatially_adaptive_convolution, self).__init__() + self.kernel_size = kernel_size + self.input_channels = input_channels + self.output_channels = output_channels + self.stride = stride + self.padding = padding + self.standard_convolution = torch.nn.Conv2d( + in_channels = input_channels, + out_channels = self.output_channels, + kernel_size = kernel_size, + stride = stride, + padding = padding, + bias = bias + ) + self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True) + self.activation = activation + + + def forward(self, x, sv_kernel_feature): + """ + Forward pass for the spatially adaptive convolution layer. + + Parameters + ---------- + x : torch.tensor + Input data tensor. + Dimension: (1, C, H, W) + sv_kernel_feature : torch.tensor + Spatially varying kernel features. + Dimension: (1, C_i * kernel_size * kernel_size, H, W) + + Returns + ------- + sa_output : torch.tensor + Estimated output tensor. + Dimension: (1, output_channels, H_out, W_out) + """ + # Pad input and sv_kernel_feature if necessary + if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size( + -2) * self.stride != x.size(-2): + diffY = sv_kernel_feature.size(-2) % self.stride + diffX = sv_kernel_feature.size(-1) % self.stride + sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2)) + diffY = x.size(-2) % self.stride + diffX = x.size(-1) % self.stride + x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2)) + + # Unfold the input tensor for matrix multiplication + input_feature = torch.nn.functional.unfold( + x, + kernel_size = (self.kernel_size, self.kernel_size), + stride = self.stride, + padding = self.padding + ) + + # Resize sv_kernel_feature to match the input feature + sv_kernel = sv_kernel_feature.reshape( + 1, + self.input_channels * self.kernel_size * self.kernel_size, + (x.size(-2) // self.stride) * (x.size(-1) // self.stride) + ) + + # Resize weight to match the input channels and kernel size + si_kernel = self.weight.reshape( + self.weight_output_channels, + self.input_channels * self.kernel_size * self.kernel_size + ) + + # Apply spatially varying kernels + sv_feature = input_feature * sv_kernel + + # Perform matrix multiplication + sa_output = torch.matmul(si_kernel, sv_feature).reshape( + 1, self.weight_output_channels, + (x.size(-2) // self.stride), + (x.size(-1) // self.stride) + ) + return sa_output + + +class spatially_adaptive_module(torch.nn.Module): + """ + A spatially adaptive module that combines learned spatially adaptive convolutions. + + References + ---------- + + Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak{\c{s}}it}. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24),December,2024. + """ + def __init__( + self, + input_channels = 2, + output_channels = 2, + kernel_size = 3, + stride = 1, + padding = 1, + bias = False, + activation = torch.nn.LeakyReLU(0.2, inplace = True) + ): + """ + Initializes a spatially adaptive module. + + Parameters + ---------- + input_channels : int + Number of input channels. + output_channels : int + Number of output channels. + kernel_size : int + Size of the convolution kernel. + stride : int + Stride of the convolution. + padding : int + Padding added to both sides of the input. + bias : bool + If True, includes a bias term in the convolution. + activation : torch.nn + Nonlinear activation layer to be used. If None, uses torch.nn.ReLU(). + """ + super(spatially_adaptive_module, self).__init__() + self.kernel_size = kernel_size + self.input_channels = input_channels + self.output_channels = output_channels + self.stride = stride + self.padding = padding + self.weight_output_channels = self.output_channels - 1 + self.standard_convolution = torch.nn.Conv2d( + in_channels = input_channels, + out_channels = self.weight_output_channels, + kernel_size = kernel_size, + stride = stride, + padding = padding, + bias = bias + ) + self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True) + self.activation = activation + + + def forward(self, x, sv_kernel_feature): + """ + Forward pass for the spatially adaptive module. + + Parameters + ---------- + x : torch.tensor + Input data tensor. + Dimension: (1, C, H, W) + sv_kernel_feature : torch.tensor + Spatially varying kernel features. + Dimension: (1, C_i * kernel_size * kernel_size, H, W) + + Returns + ------- + output : torch.tensor + Combined output tensor from standard and spatially adaptive convolutions. + Dimension: (1, output_channels, H_out, W_out) + """ + # Pad input and sv_kernel_feature if necessary + if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size( + -2) * self.stride != x.size(-2): + diffY = sv_kernel_feature.size(-2) % self.stride + diffX = sv_kernel_feature.size(-1) % self.stride + sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2)) + diffY = x.size(-2) % self.stride + diffX = x.size(-1) % self.stride + x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2)) + + # Unfold the input tensor for matrix multiplication + input_feature = torch.nn.functional.unfold( + x, + kernel_size = (self.kernel_size, self.kernel_size), + stride = self.stride, + padding = self.padding + ) + + # Resize sv_kernel_feature to match the input feature + sv_kernel = sv_kernel_feature.reshape( + 1, + self.input_channels * self.kernel_size * self.kernel_size, + (x.size(-2) // self.stride) * (x.size(-1) // self.stride) + ) + + # Apply sv_kernel to the input_feature + sv_feature = input_feature * sv_kernel + + # Original spatially varying convolution output + sv_output = torch.sum(sv_feature, dim = 1).reshape( + 1, + 1, + (x.size(-2) // self.stride), + (x.size(-1) // self.stride) + ) + + # Reshape weight for spatially adaptive convolution + si_kernel = self.weight.reshape( + self.weight_output_channels, + self.input_channels * self.kernel_size * self.kernel_size + ) + + # Apply si_kernel on sv convolution output + sa_output = torch.matmul(si_kernel, sv_feature).reshape( + 1, self.weight_output_channels, + (x.size(-2) // self.stride), + (x.size(-1) // self.stride) + ) + + # Combine the outputs and apply activation function + output = self.activation(torch.cat((sv_output, sa_output), dim = 1)) + return output diff --git a/odak/learn/models/models.py b/odak/learn/models/models.py index 09efd6ec..f7750552 100644 --- a/odak/learn/models/models.py +++ b/odak/learn/models/models.py @@ -1,5 +1,5 @@ import torch -from .components import double_convolution, downsample_layer, upsample_layer, swish, gaussian +from .components import * class multi_layer_perceptron(torch.nn.Module): @@ -194,3 +194,397 @@ def forward(self, x): result = self.outc(x_up) return result + +class spatially_varying_kernel_generation_model(torch.nn.Module): + """ + Spatially_varying_kernel_generation_model revised from RSGUnet: + https://github.com/MTLab/rsgunet_image_enhance. + + Refer to: + J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices. + """ + + def __init__( + self, + depth = 3, + dimensions = 8, + input_channels = 7, + kernel_size = 3, + bias = True, + normalization = False, + activation = torch.nn.LeakyReLU(0.2, inplace = True) + ): + """ + U-Net model. + + Parameters + ---------- + depth : int + Number of upsampling and downsampling layers. + dimensions : int + Number of dimensions. + input_channels : int + Number of input channels. + bias : bool + Set to True to let convolutional layers learn a bias term. + normalization : bool + If True, adds a Batch Normalization layer after the convolutional layer. + activation : torch.nn + Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()). + """ + super().__init__() + self.depth = depth + self.inc = convolution_layer( + input_channels = input_channels, + output_channels = dimensions, + kernel_size = kernel_size, + bias = bias, + normalization = normalization, + activation = activation + ) + self.encoder = torch.nn.ModuleList() + for i in range(depth + 1): # downsampling layers + if i == 0: + in_channels = dimensions * (2 ** i) + out_channels = dimensions * (2 ** i) + elif i == depth: + in_channels = dimensions * (2 ** (i - 1)) + out_channels = dimensions * (2 ** (i - 1)) + else: + in_channels = dimensions * (2 ** (i - 1)) + out_channels = 2 * in_channels + pooling_layer = torch.nn.AvgPool2d(2) + double_convolution_layer = double_convolution( + input_channels = in_channels, + mid_channels = in_channels, + output_channels = out_channels, + kernel_size = kernel_size, + bias = bias, + normalization = normalization, + activation = activation + ) + self.encoder.append(pooling_layer) + self.encoder.append(double_convolution_layer) + self.spatially_varying_feature = torch.nn.ModuleList() # for kernel generation + for i in range(depth, -1, -1): + if i == 1: + svf_in_channels = dimensions + 2 ** (self.depth + i) + 1 + else: + svf_in_channels = 2 ** (self.depth + i) + 1 + svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size) + svf_mid_channels = dimensions * (2 ** (self.depth - 1)) + spatially_varying_kernel_generation = torch.nn.ModuleList() + for j in range(i, -1, -1): + pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1)) + spatially_varying_kernel_generation.append(pooling_layer) + kernel_generation_block = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels = svf_in_channels, + out_channels = svf_mid_channels, + kernel_size = kernel_size, + padding = kernel_size // 2, + bias = bias + ), + activation, + torch.nn.Conv2d( + in_channels = svf_mid_channels, + out_channels = svf_mid_channels, + kernel_size = kernel_size, + padding = kernel_size // 2, + bias = bias + ), + activation, + torch.nn.Conv2d( + in_channels = svf_mid_channels, + out_channels = svf_out_channels, + kernel_size = kernel_size, + padding = kernel_size // 2, + bias = bias + ), + ) + spatially_varying_kernel_generation.append(kernel_generation_block) + self.spatially_varying_feature.append(spatially_varying_kernel_generation) + self.decoder = torch.nn.ModuleList() + global_feature_layer = global_feature_module( # global feature layer + input_channels = dimensions * (2 ** (depth - 1)), + mid_channels = dimensions * (2 ** (depth - 1)), + output_channels = dimensions * (2 ** (depth - 1)), + kernel_size = kernel_size, + bias = bias, + activation = torch.nn.LeakyReLU(0.2, inplace = True) + ) + self.decoder.append(global_feature_layer) + for i in range(depth, 0, -1): + if i == 2: + up_in_channels = (dimensions // 2) * (2 ** i) + up_out_channels = up_in_channels + up_mid_channels = up_in_channels + elif i == 1: + up_in_channels = dimensions * 2 + up_out_channels = dimensions + up_mid_channels = up_out_channels + else: + up_in_channels = (dimensions // 2) * (2 ** i) + up_out_channels = up_in_channels // 2 + up_mid_channels = up_in_channels + upsample_layer = upsample_convtranspose2d_layer( + input_channels = up_in_channels, + output_channels = up_mid_channels, + kernel_size = 2, + stride = 2, + bias = bias, + ) + conv_layer = double_convolution( + input_channels = up_mid_channels, + output_channels = up_out_channels, + kernel_size = kernel_size, + bias = bias, + normalization = normalization, + activation = activation, + ) + self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer])) + + + def forward(self, focal_surface, field): + """ + Forward model. + + Parameters + ---------- + focal_surface : torch.tensor + Input focal surface data. + Dimension: (1, 1, H, W) + + field : torch.tensor + Input field data. + Dimension: (1, 6, H, W) + + Returns + ------- + sv_kernel : list of torch.tensor + Learned spatially varying kernels. + Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), + where C_i, H_i, and W_i represent the channel, height, and width + of each feature at a certain scale. + """ + x = self.inc(torch.cat((focal_surface, field), dim = 1)) + downsampling_outputs = [focal_surface] + downsampling_outputs.append(x) + for i, down_layer in enumerate(self.encoder): + x_down = down_layer(downsampling_outputs[-1]) + downsampling_outputs.append(x_down) + sv_kernels = [] + for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)): + if i == 0: + global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1]) + downsampling_outputs[-1] = global_feature + sv_feature = [global_feature, downsampling_outputs[0]] + for j in range(self.depth - i + 1): + sv_feature[1] = svf_layer[self.depth - i](sv_feature[1]) + if j > 0: + sv_feature.append(svf_layer[j](downsampling_outputs[2 * j])) + sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2], + sv_feature[3]] + sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1)) + sv_kernels.append(sv_kernel) + else: + x_up = up_layer[0](downsampling_outputs[-1], + downsampling_outputs[2 * (self.depth + 1 - i) + 1]) + x_up = up_layer[1](x_up) + downsampling_outputs[-1] = x_up + sv_feature = [x_up, downsampling_outputs[0]] + for j in range(self.depth - i + 1): + sv_feature[1] = svf_layer[self.depth - i](sv_feature[1]) + if j > 0: + sv_feature.append(svf_layer[j](downsampling_outputs[2 * j])) + if i == 1: + sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]] + sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1)) + sv_kernels.append(sv_kernel) + return sv_kernels + + +class spatially_adaptive_unet(torch.nn.Module): + """ + Spatially varying U-Net model based on spatially adaptive convolution. + + References + ---------- + + Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak{\c{s}}it}. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24),December,2024. + """ + def __init__( + self, + depth=3, + dimensions=8, + input_channels=6, + out_channels=6, + kernel_size=3, + bias=True, + normalization=False, + activation=torch.nn.LeakyReLU(0.2, inplace=True) + ): + """ + U-Net model. + + Parameters + ---------- + depth : int + Number of upsampling and downsampling layers. + dimensions : int + Number of dimensions. + input_channels : int + Number of input channels. + out_channels : int + Number of output channels. + bias : bool + Set to True to let convolutional layers learn a bias term. + normalization : bool + If True, adds a Batch Normalization layer after the convolutional layer. + activation : torch.nn + Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()). + """ + super().__init__() + self.depth = depth + self.out_channels = out_channels + self.inc = convolution_layer( + input_channels=input_channels, + output_channels=dimensions, + kernel_size=kernel_size, + bias=bias, + normalization=normalization, + activation=activation + ) + + self.encoder = torch.nn.ModuleList() + for i in range(self.depth + 1): # Downsampling layers + down_in_channels = dimensions * (2 ** i) + down_out_channels = 2 * down_in_channels + pooling_layer = torch.nn.AvgPool2d(2) + double_convolution_layer = double_convolution( + input_channels=down_in_channels, + mid_channels=down_in_channels, + output_channels=down_in_channels, + kernel_size=kernel_size, + bias=bias, + normalization=normalization, + activation=activation + ) + sam = spatially_adaptive_module( + input_channels=down_in_channels, + output_channels=down_out_channels, + kernel_size=kernel_size, + bias=bias, + activation=activation + ) + self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam])) + self.global_feature_module = torch.nn.ModuleList() + double_convolution_layer = double_convolution( + input_channels=dimensions * (2 ** (depth + 1)), + mid_channels=dimensions * (2 ** (depth + 1)), + output_channels=dimensions * (2 ** (depth + 1)), + kernel_size=kernel_size, + bias=bias, + normalization=normalization, + activation=activation + ) + global_feature_layer = global_feature_module( + input_channels=dimensions * (2 ** (depth + 1)), + mid_channels=dimensions * (2 ** (depth + 1)), + output_channels=dimensions * (2 ** (depth + 1)), + kernel_size=kernel_size, + bias=bias, + activation=torch.nn.LeakyReLU(0.2, inplace=True) + ) + self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer])) + self.decoder = torch.nn.ModuleList() + for i in range(depth, -1, -1): + up_in_channels = dimensions * (2 ** (i + 1)) + up_mid_channels = up_in_channels // 2 + if i == 0: + up_out_channels = self.out_channels + upsample_layer = upsample_convtranspose2d_layer( + input_channels=up_in_channels, + output_channels=up_mid_channels, + kernel_size=2, + stride=2, + bias=bias, + ) + conv_layer = torch.nn.Sequential( + convolution_layer( + input_channels=up_mid_channels, + output_channels=up_mid_channels, + kernel_size=kernel_size, + bias=bias, + normalization=normalization, + activation=activation, + ), + convolution_layer( + input_channels=up_mid_channels, + output_channels=up_out_channels, + kernel_size=1, + bias=bias, + normalization=normalization, + activation=None, + ) + ) + self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer])) + else: + up_out_channels = up_in_channels // 2 + upsample_layer = upsample_convtranspose2d_layer( + input_channels=up_in_channels, + output_channels=up_mid_channels, + kernel_size=2, + stride=2, + bias=bias, + ) + conv_layer = double_convolution( + input_channels=up_mid_channels, + mid_channels=up_mid_channels, + output_channels=up_out_channels, + kernel_size=kernel_size, + bias=bias, + normalization=normalization, + activation=activation, + ) + self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer])) + + + def forward(self, sv_kernel, field): + """ + Forward model. + + Parameters + ---------- + sv_kernel : list of torch.tensor + Learned spatially varying kernels. + Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), + where C_i, H_i, and W_i represent the channel, height, and width + of each feature at a certain scale. + + field : torch.tensor + Input field data. + Dimension: (1, 6, H, W) + + Returns + ------- + target_field : torch.tensor + Estimated output. + Dimension: (1, 6, H, W) + """ + x = self.inc(field) + downsampling_outputs = [x] + for i, down_layer in enumerate(self.encoder): + x_down = down_layer[0](downsampling_outputs[-1]) + downsampling_outputs.append(x_down) + sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i]) + downsampling_outputs.append(sam_output) + global_feature = self.global_feature_module[0][0](downsampling_outputs[-1]) + global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature) + downsampling_outputs.append(global_feature) + x_up = downsampling_outputs[-1] + for i, up_layer in enumerate(self.decoder): + x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)]) + x_up = up_layer[1](x_up) + result = x_up + return result diff --git a/odak/learn/wave/models.py b/odak/learn/wave/models.py index acea5cda..0b98e98e 100644 --- a/odak/learn/wave/models.py +++ b/odak/learn/wave/models.py @@ -1,8 +1,10 @@ import torch import os +import json +import numpy as np from tqdm import tqdm -from ..models import unet -from .util import generate_complex_field, wavenumber +from ..models import * +from .util import generate_complex_field, wavenumber,calculate_amplitude class holobeam_multiholo(torch.nn.Module): @@ -133,3 +135,160 @@ def load_weights(self, filename = './weights.pt'): """ self.network.load_state_dict(torch.load(os.path.expanduser(filename))) self.network.eval() + + +class focal_surface_light_propagation(torch.nn.Module): + """ + focal_surface_light_propagation model. + + References + ---------- + + Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak{\c{s}}it}. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24),December,2024. + """ + def __init__( + self, + depth = 3, + dimensions = 8, + input_channels = 6, + out_channels = 6, + kernel_size = 3, + bias = True, + device = torch.device('cpu'), + activation = torch.nn.LeakyReLU(0.2, inplace = True) + ): + """ + Initializes the focal surface light propagation model. + + Parameters + ---------- + depth : int + Number of downsampling and upsampling layers. + dimensions : int + Number of dimensions/features in the model. + input_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int + Size of the convolution kernel. + bias : bool + If True, allows convolutional layers to learn a bias term. + device : torch.device + Default device is CPU. + activation : torch.nn.Module + Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()). + """ + super().__init__() + self.depth = depth + self.device = device + self.sv_kernel_generation = spatially_varying_kernel_generation_model( + depth = depth, + dimensions = dimensions, + input_channels = input_channels + 1, # +1 to account for an extra channel + kernel_size = kernel_size, + bias = bias, + activation = activation + ) + self.light_propagation = spatially_adaptive_unet( + depth = depth, + dimensions = dimensions, + input_channels = input_channels, + out_channels = out_channels, + kernel_size = kernel_size, + bias = bias, + activation = activation + ) + + + def forward(self, focal_surface, phase_only_hologram): + """ + Forward pass through the model. + + Parameters + ---------- + focal_surface : torch.Tensor + Input focal surface. + phase_only_hologram : torch.Tensor + Input phase-only hologram. + + Returns + ---------- + result : torch.Tensor + Output tensor after light propagation. + """ + input_field = self.generate_input_field(phase_only_hologram) + sv_kernel = self.sv_kernel_generation(focal_surface, input_field) + output_field = self.light_propagation(sv_kernel, input_field) + final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :]) + result = calculate_amplitude(final) ** 2 + return result + + + def generate_input_field(self, phase_only_hologram): + """ + Generates an input field by combining the real and imaginary parts. + + Parameters + ---------- + phase_only_hologram : torch.Tensor + Input phase-only hologram. + + Returns + ---------- + input_field : torch.Tensor + Concatenated real and imaginary parts of the complex field. + """ + [b, c, h, w] = phase_only_hologram.size() + input_phase = phase_only_hologram * 2 * np.pi + hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False) + field = generate_complex_field(hologram_amplitude, input_phase) + input_field = torch.cat((field.real, field.imag), dim = 1) + return input_field + + + def load_weights(self, weight_filename, key_mapping_filename): + """ + Function to load weights for this multi-layer perceptron from a file. + + Parameters + ---------- + weight_filename : str + Path to the old model's weight file. + key_mapping_filename : str + Path to the JSON file containing the key mappings. + """ + # Load old model weights + old_model_weights = torch.load(weight_filename, map_location = self.device) + + # Load key mappings from JSON file + with open(key_mapping_filename, 'r') as json_file: + key_mappings = json.load(json_file) + + # Extract the key mappings for sv_kernel_generation and light_prop + sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping'] + light_prop_key_mapping = key_mappings['light_prop_key_mapping'] + + # Initialize new state dicts + sv_kernel_generation_new_state_dict = {} + light_prop_new_state_dict = {} + + # Map and load sv_kernel_generation_model weights + for old_key, value in old_model_weights.items(): + if old_key in sv_kernel_generation_key_mapping: + # Map the old key to the new key + new_key = sv_kernel_generation_key_mapping[old_key] + sv_kernel_generation_new_state_dict[new_key] = value + + self.sv_kernel_generation.to(self.device) + self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict) + + # Map and load light_prop model weights + for old_key, value in old_model_weights.items(): + if old_key in light_prop_key_mapping: + # Map the old key to the new key + new_key = light_prop_key_mapping[old_key] + light_prop_new_state_dict[new_key] = value + self.light_propagation.to(self.device) + self.light_propagation.load_state_dict(light_prop_new_state_dict) + print("Weights loaded successfully into the new model.") diff --git a/test/data/sample_0343_focal_surface.png b/test/data/sample_0343_focal_surface.png new file mode 100644 index 00000000..d739096a Binary files /dev/null and b/test/data/sample_0343_focal_surface.png differ diff --git a/test/data/sample_0343_hologram.png b/test/data/sample_0343_hologram.png new file mode 100644 index 00000000..4cfdbb8d Binary files /dev/null and b/test/data/sample_0343_hologram.png differ diff --git a/test/test_learn_wave_focal_surface_light_propagation.py b/test/test_learn_wave_focal_surface_light_propagation.py new file mode 100644 index 00000000..e3e3ca9c --- /dev/null +++ b/test/test_learn_wave_focal_surface_light_propagation.py @@ -0,0 +1,84 @@ +import sys +import os +import odak +import torch +import requests + + +def test(output_directory = 'test_output'): + number_of_planes = 6 + location_offset = 0. + volume_depth = 5e-3 + device = torch.device('cpu') + + # Download the weight and key mapping files from GitHub + + weight_url = 'https://raw.githubusercontent.com/complight/focal_surface_holographic_light_transport/main/weight/model_0mm.pt' + key_mapping_url = 'https://raw.githubusercontent.com/complight/focal_surface_holographic_light_transport/main/weight/key_mappings.json' + weight_filename = os.path.join(output_directory, 'model_0mm.pt') + key_mapping_filename = os.path.join(output_directory, 'key_mappings.json') + download_file(weight_url, weight_filename) + download_file(key_mapping_url, key_mapping_filename) + + # Preparing focal surface + focal_surface_filename = os.path.join(output_directory, 'sample_0343_focal_surface.png') + focal_surface = odak.learn.tools.load_image( + focal_surface_filename, + normalizeby = 255., + torch_style = True + ).to(device) + distances = torch.linspace(-volume_depth / 2., volume_depth / 2., number_of_planes) + location_offset + y = (distances - torch.min(distances)) + distances = (y / torch.max(y)) + focal_surface = focal_surface * (number_of_planes - 1) + focal_surface = torch.round(focal_surface, decimals = 0) + for i in range(number_of_planes): + focal_surface = torch.where(focal_surface == i, distances[i], focal_surface) + focal_surface = focal_surface.unsqueeze(0).unsqueeze(0) + + # Preparing hologram + hologram_phases_filename = os.path.join(output_directory, 'sample_0343_hologram.png') + hologram_phases = odak.learn.tools.load_image( + hologram_phases_filename, + normalizeby = 255., + torch_style = True + ).to(device) + hologram_phases = hologram_phases.unsqueeze(0) + + # Load the focal surface light propagation model + focal_surface_light_propagation_model = odak.learn.wave.focal_surface_light_propagation(device = device) + focal_surface_light_propagation_model.load_weights( + weight_filename = weight_filename, + key_mapping_filename = key_mapping_filename + ) + + # Perform the focal surface light propagation model + result = focal_surface_light_propagation_model(focal_surface, hologram_phases) + odak.learn.tools.save_image( + '{}/reconstruction_image.png'.format(output_directory), + result, + cmin = 0., + cmax = 1. + ) + print("Reconstruction complete.") + return True + + +# Function to download a file from GitHub +def download_file(url, filename): + try: + print(f"Starting download: {url}") + response = requests.get(url, stream = True) + response.raise_for_status() + os.makedirs(os.path.dirname(filename), exist_ok = True) + with open(filename, 'wb') as file: + for chunk in response.iter_content(chunk_size = 8192): + file.write(chunk) + print(f"Downloaded: {filename}") + except requests.exceptions.RequestException as e: + print(f"Failed to download {url}. Error: {e}") + sys.exit(1) + + +if __name__ == '__main__': + sys.exit(test())