Commit

Merge branch 'CharlieZCJ-master'
kaanaksit committed Oct 24, 2024
2 parents 3771c15 + 726372e commit 809dd7a
Showing 8 changed files with 1,268 additions and 21 deletions.
512 changes: 494 additions & 18 deletions odak/learn/models/components.py

Large diffs are not rendered by default.

396 changes: 395 additions & 1 deletion odak/learn/models/models.py

Large diffs are not rendered by default.

162 changes: 160 additions & 2 deletions odak/learn/wave/models.py
@@ -1,8 +1,10 @@
import torch
import os
import json
import numpy as np
from tqdm import tqdm
from ..models import unet
from .util import generate_complex_field, wavenumber
from ..models import *
from .util import generate_complex_field, wavenumber, calculate_amplitude


class holobeam_multiholo(torch.nn.Module):
@@ -133,3 +135,159 @@ def load_weights(self, filename = './weights.pt'):
"""
self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
self.network.eval()


class focal_surface_light_propagation(torch.nn.Module):
"""
focal_surface_light_propagation model.
References
----------
Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December 2024.
"""
def __init__(
self,
depth = 3,
dimensions = 8,
input_channels = 6,
out_channels = 6,
kernel_size = 3,
bias = True,
device = torch.device('cpu'),
activation = torch.nn.LeakyReLU(0.2, inplace = True)
):
"""
Initializes the focal surface light propagation model.
Parameters
----------
depth : int
Number of downsampling and upsampling layers.
dimensions : int
Number of dimensions/features in the model.
input_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int
Size of the convolution kernel.
bias : bool
If True, allows convolutional layers to learn a bias term.
device : torch.device
Default device is CPU.
activation : torch.nn.Module
Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
"""
super().__init__()
self.depth = depth
self.device = device
self.sv_kernel_generation = spatially_varying_kernel_generation_model(
depth = depth,
dimensions = dimensions,
input_channels = input_channels + 1, # +1 input channel for the focal surface
kernel_size = kernel_size,
bias = bias,
activation = activation
)
self.light_propagation = spatially_adaptive_unet(
depth = depth,
dimensions = dimensions,
input_channels = input_channels,
out_channels = out_channels,
kernel_size = kernel_size,
bias = bias,
activation = activation
)


def forward(self, focal_surface, phase_only_hologram):
"""
Forward pass through the model.
Parameters
----------
focal_surface : torch.Tensor
Input focal surface.
phase_only_hologram : torch.Tensor
Input phase-only hologram.
Returns
----------
result : torch.Tensor
Output tensor after light propagation.
"""
input_field = self.generate_input_field(phase_only_hologram)
sv_kernel = self.sv_kernel_generation(focal_surface, input_field)
output_field = self.light_propagation(sv_kernel, input_field)
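# Recombine the six output channels into a three-channel complex field
# and return its intensity (squared amplitude).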
final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :])
result = calculate_amplitude(final) ** 2
return result


def generate_input_field(self, phase_only_hologram):
"""
Generates an input field by combining the real and imaginary parts.
Parameters
----------
phase_only_hologram : torch.Tensor
Input phase-only hologram.
Returns
----------
input_field : torch.Tensor
Concatenated real and imaginary parts of the complex field.
"""
[b, c, h, w] = phase_only_hologram.size()
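# Input phase is presumably normalized to [0, 1]; scale it to [0, 2π] radians.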
input_phase = phase_only_hologram * 2 * np.pi
hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False, device = phase_only_hologram.device)
field = generate_complex_field(hologram_amplitude, input_phase)
input_field = torch.cat((field.real, field.imag), dim = 1)
return input_field


def load_weights(self, weight_filename, key_mapping_filename):
"""
Function to load weights for this model from a file, remapping parameter names from an older checkpoint format.
Parameters
----------
weight_filename : str
Path to the old model's weight file.
key_mapping_filename : str
Path to the JSON file containing the key mappings.
"""
# Load old model weights
old_model_weights = torch.load(weight_filename, map_location = self.device)

# Load key mappings from JSON file
with open(key_mapping_filename, 'r') as json_file:
key_mappings = json.load(json_file)

# Extract the key mappings for sv_kernel_generation and light_prop
sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping']
light_prop_key_mapping = key_mappings['light_prop_key_mapping']

# Initialize new state dicts
sv_kernel_generation_new_state_dict = {}
light_prop_new_state_dict = {}

# Map and load sv_kernel_generation_model weights
for old_key, value in old_model_weights.items():
if old_key in sv_kernel_generation_key_mapping:
# Map the old key to the new key
new_key = sv_kernel_generation_key_mapping[old_key]
sv_kernel_generation_new_state_dict[new_key] = value

self.sv_kernel_generation.to(self.device)
self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict)

# Map and load light_prop model weights
for old_key, value in old_model_weights.items():
if old_key in light_prop_key_mapping:
# Map the old key to the new key
new_key = light_prop_key_mapping[old_key]
light_prop_new_state_dict[new_key] = value

self.light_propagation.to(self.device)
self.light_propagation.load_state_dict(light_prop_new_state_dict)
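
A minimal usage sketch for the class added above, assuming odak is installed and using the test assets introduced in this commit (test/data/focal_surface_sample_model.pt and test/data/key_mappings.json); the resolution and random inputs are purely illustrative:

import torch
from odak.learn.wave.models import focal_surface_light_propagation

device = torch.device('cpu')
model = focal_surface_light_propagation(device = device)
model.load_weights(
    weight_filename = 'test/data/focal_surface_sample_model.pt',
    key_mapping_filename = 'test/data/key_mappings.json'
)

# Illustrative inputs: a single-channel focal surface and a three-channel
# phase-only hologram with phase values normalized to [0, 1].
focal_surface = torch.rand(1, 1, 1080, 1920)
hologram = torch.rand(1, 3, 1080, 1920)
reconstruction = model(focal_surface, hologram) # intensity, shape [1, 3, 1080, 1920]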
Binary file added test/data/focal_surface_sample_model.pt
Binary file not shown.
156 changes: 156 additions & 0 deletions test/data/key_mappings.json
@@ -0,0 +1,156 @@
{
"sv_kernel_generation_key_mapping": {
"net_kernel.inc.0.conv.0.weight": "inc.model.0.weight",
"net_kernel.inc.0.conv.0.bias": "inc.model.0.bias",
"net_kernel.conv1.0.conv.0.weight": "encoder.1.model.0.model.0.weight",
"net_kernel.conv1.0.conv.0.bias": "encoder.1.model.0.model.0.bias",
"net_kernel.conv1.1.conv.0.weight": "encoder.1.model.1.model.0.weight",
"net_kernel.conv1.1.conv.0.bias": "encoder.1.model.1.model.0.bias",
"net_kernel.conv2.0.conv.0.weight": "encoder.3.model.0.model.0.weight",
"net_kernel.conv2.0.conv.0.bias": "encoder.3.model.0.model.0.bias",
"net_kernel.conv2.1.conv.0.weight": "encoder.3.model.1.model.0.weight",
"net_kernel.conv2.1.conv.0.bias": "encoder.3.model.1.model.0.bias",
"net_kernel.conv3.0.conv.0.weight": "encoder.5.model.0.model.0.weight",
"net_kernel.conv3.0.conv.0.bias": "encoder.5.model.0.model.0.bias",
"net_kernel.conv3.1.conv.0.weight": "encoder.5.model.1.model.0.weight",
"net_kernel.conv3.1.conv.0.bias": "encoder.5.model.1.model.0.bias",
"net_kernel.conv4.0.conv.0.weight": "encoder.7.model.0.model.0.weight",
"net_kernel.conv4.0.conv.0.bias": "encoder.7.model.0.model.0.bias",
"net_kernel.conv4.1.conv.0.weight": "encoder.7.model.1.model.0.weight",
"net_kernel.conv4.1.conv.0.bias": "encoder.7.model.1.model.0.bias",
"net_kernel.glo1_sv.model.0.weight": "spatially_varying_feature.0.4.0.weight",
"net_kernel.glo1_sv.model.0.bias": "spatially_varying_feature.0.4.0.bias",
"net_kernel.glo1_sv.model.2.weight": "spatially_varying_feature.0.4.2.weight",
"net_kernel.glo1_sv.model.2.bias": "spatially_varying_feature.0.4.2.bias",
"net_kernel.glo1_sv.model.4.weight": "spatially_varying_feature.0.4.4.weight",
"net_kernel.glo1_sv.model.4.bias": "spatially_varying_feature.0.4.4.bias",
"net_kernel.convup1_sv.model.0.weight": "spatially_varying_feature.1.3.0.weight",
"net_kernel.convup1_sv.model.0.bias": "spatially_varying_feature.1.3.0.bias",
"net_kernel.convup1_sv.model.2.weight": "spatially_varying_feature.1.3.2.weight",
"net_kernel.convup1_sv.model.2.bias": "spatially_varying_feature.1.3.2.bias",
"net_kernel.convup1_sv.model.4.weight": "spatially_varying_feature.1.3.4.weight",
"net_kernel.convup1_sv.model.4.bias": "spatially_varying_feature.1.3.4.bias",
"net_kernel.convup2_sv.model.0.weight": "spatially_varying_feature.2.2.0.weight",
"net_kernel.convup2_sv.model.0.bias": "spatially_varying_feature.2.2.0.bias",
"net_kernel.convup2_sv.model.2.weight": "spatially_varying_feature.2.2.2.weight",
"net_kernel.convup2_sv.model.2.bias": "spatially_varying_feature.2.2.2.bias",
"net_kernel.convup2_sv.model.4.weight": "spatially_varying_feature.2.2.4.weight",
"net_kernel.convup2_sv.model.4.bias": "spatially_varying_feature.2.2.4.bias",
"net_kernel.convup3_sv.model.0.weight": "spatially_varying_feature.3.1.0.weight",
"net_kernel.convup3_sv.model.0.bias": "spatially_varying_feature.3.1.0.bias",
"net_kernel.convup3_sv.model.2.weight": "spatially_varying_feature.3.1.2.weight",
"net_kernel.convup3_sv.model.2.bias": "spatially_varying_feature.3.1.2.bias",
"net_kernel.convup3_sv.model.4.weight": "spatially_varying_feature.3.1.4.weight",
"net_kernel.convup3_sv.model.4.bias": "spatially_varying_feature.3.1.4.bias",
"net_kernel.convup1.0.conv.0.weight": "decoder.1.1.model.0.model.0.weight",
"net_kernel.convup1.0.conv.0.bias": "decoder.1.1.model.0.model.0.bias",
"net_kernel.convup1.1.conv.0.weight": "decoder.1.1.model.1.model.0.weight",
"net_kernel.convup1.1.conv.0.bias": "decoder.1.1.model.1.model.0.bias",
"net_kernel.convup2.0.conv.0.weight": "decoder.2.1.model.0.model.0.weight",
"net_kernel.convup2.0.conv.0.bias": "decoder.2.1.model.0.model.0.bias",
"net_kernel.convup2.1.conv.0.weight": "decoder.2.1.model.1.model.0.weight",
"net_kernel.convup2.1.conv.0.bias": "decoder.2.1.model.1.model.0.bias",
"net_kernel.convup3.0.conv.0.weight": "decoder.3.1.model.0.model.0.weight",
"net_kernel.convup3.0.conv.0.bias": "decoder.3.1.model.0.model.0.bias",
"net_kernel.convup3.1.conv.0.weight": "decoder.3.1.model.1.model.0.weight",
"net_kernel.convup3.1.conv.0.bias": "decoder.3.1.model.1.model.0.bias",
"net_kernel.glo.global_feature.0.weight": "decoder.0.transformations_1.global_feature_1.0.weight",
"net_kernel.glo.global_feature.0.bias": "decoder.0.transformations_1.global_feature_1.0.bias",
"net_kernel.glo.global_feature_1.0.weight": "decoder.0.transformations_1.global_feature_2.0.weight",
"net_kernel.glo.global_feature_1.0.bias": "decoder.0.transformations_1.global_feature_2.0.bias",
"net_kernel.glo1.global_feature.0.weight": "decoder.0.transformations_2.global_feature_1.0.weight",
"net_kernel.glo1.global_feature.0.bias": "decoder.0.transformations_2.global_feature_1.0.bias",
"net_kernel.glo1.global_feature_1.0.weight": "decoder.0.transformations_2.global_feature_2.0.weight",
"net_kernel.glo1.global_feature_1.0.bias": "decoder.0.transformations_2.global_feature_2.0.bias",
"net_kernel.convglo.0.conv.0.weight": "decoder.0.global_features_1.model.0.model.0.weight",
"net_kernel.convglo.0.conv.0.bias": "decoder.0.global_features_1.model.0.model.0.bias",
"net_kernel.convglo.1.conv.0.weight": "decoder.0.global_features_1.model.1.model.0.weight",
"net_kernel.convglo.1.conv.0.bias": "decoder.0.global_features_1.model.1.model.0.bias",
"net_kernel.convglo1.0.conv.0.weight": "decoder.0.global_features_2.model.0.model.0.weight",
"net_kernel.convglo1.0.conv.0.bias": "decoder.0.global_features_2.model.0.model.0.bias",
"net_kernel.convglo1.1.conv.0.weight": "decoder.0.global_features_2.model.1.model.0.weight",
"net_kernel.convglo1.1.conv.0.bias": "decoder.0.global_features_2.model.1.model.0.bias",
"net_kernel.up1.up.weight": "decoder.1.0.up.weight",
"net_kernel.up1.up.bias": "decoder.1.0.up.bias",
"net_kernel.up2.up.weight": "decoder.2.0.up.weight",
"net_kernel.up2.up.bias": "decoder.2.0.up.bias",
"net_kernel.up3.up.weight": "decoder.3.0.up.weight",
"net_kernel.up3.up.bias": "decoder.3.0.up.bias"
},
"light_prop_key_mapping": {
"net_feature.inc.0.conv.0.weight": "inc.model.0.weight",
"net_feature.inc.0.conv.0.bias": "inc.model.0.bias",
"net_feature.conv1.0.conv.0.weight": "encoder.0.1.model.0.model.0.weight",
"net_feature.conv1.0.conv.0.bias": "encoder.0.1.model.0.model.0.bias",
"net_feature.conv1.1.conv.0.weight": "encoder.0.1.model.1.model.0.weight",
"net_feature.conv1.1.conv.0.bias": "encoder.0.1.model.1.model.0.bias",
"net_feature.conv2.0.conv.0.weight": "encoder.1.1.model.0.model.0.weight",
"net_feature.conv2.0.conv.0.bias": "encoder.1.1.model.0.model.0.bias",
"net_feature.conv2.1.conv.0.weight": "encoder.1.1.model.1.model.0.weight",
"net_feature.conv2.1.conv.0.bias": "encoder.1.1.model.1.model.0.bias",
"net_feature.conv3.0.conv.0.weight": "encoder.2.1.model.0.model.0.weight",
"net_feature.conv3.0.conv.0.bias": "encoder.2.1.model.0.model.0.bias",
"net_feature.conv3.1.conv.0.weight": "encoder.2.1.model.1.model.0.weight",
"net_feature.conv3.1.conv.0.bias": "encoder.2.1.model.1.model.0.bias",
"net_feature.conv4.0.conv.0.weight": "encoder.3.1.model.0.model.0.weight",
"net_feature.conv4.0.conv.0.bias": "encoder.3.1.model.0.model.0.bias",
"net_feature.conv4.1.conv.0.weight": "encoder.3.1.model.1.model.0.weight",
"net_feature.conv4.1.conv.0.bias": "encoder.3.1.model.1.model.0.bias",
"net_feature.sv1.weight": "encoder.0.2.weight",
"net_feature.sv1.net.weight": "encoder.0.2.standard_convolution.weight",
"net_feature.sv1.net.bias": "encoder.0.2.standard_convolution.bias",
"net_feature.sv2.weight": "encoder.1.2.weight",
"net_feature.sv2.net.weight": "encoder.1.2.standard_convolution.weight",
"net_feature.sv2.net.bias": "encoder.1.2.standard_convolution.bias",
"net_feature.sv3.weight": "encoder.2.2.weight",
"net_feature.sv3.net.weight": "encoder.2.2.standard_convolution.weight",
"net_feature.sv3.net.bias": "encoder.2.2.standard_convolution.bias",
"net_feature.sv4.weight": "encoder.3.2.weight",
"net_feature.sv4.net.weight": "encoder.3.2.standard_convolution.weight",
"net_feature.sv4.net.bias": "encoder.3.2.standard_convolution.bias",
"net_feature.convglo0.0.conv.0.weight": "global_feature_module.0.0.model.0.model.0.weight",
"net_feature.convglo0.0.conv.0.bias": "global_feature_module.0.0.model.0.model.0.bias",
"net_feature.convglo0.1.conv.0.weight": "global_feature_module.0.0.model.1.model.0.weight",
"net_feature.convglo0.1.conv.0.bias": "global_feature_module.0.0.model.1.model.0.bias",
"net_feature.glo.global_feature.0.weight": "global_feature_module.0.1.transformations_1.global_feature_1.0.weight",
"net_feature.glo.global_feature.0.bias": "global_feature_module.0.1.transformations_1.global_feature_1.0.bias",
"net_feature.glo.global_feature_1.0.weight": "global_feature_module.0.1.transformations_1.global_feature_2.0.weight",
"net_feature.glo.global_feature_1.0.bias": "global_feature_module.0.1.transformations_1.global_feature_2.0.bias",
"net_feature.convglo.0.conv.0.weight": "global_feature_module.0.1.global_features_1.model.0.model.0.weight",
"net_feature.convglo.0.conv.0.bias": "global_feature_module.0.1.global_features_1.model.0.model.0.bias",
"net_feature.convglo.1.conv.0.weight": "global_feature_module.0.1.global_features_1.model.1.model.0.weight",
"net_feature.convglo.1.conv.0.bias": "global_feature_module.0.1.global_features_1.model.1.model.0.bias",
"net_feature.convglo1.0.conv.0.weight": "global_feature_module.0.1.global_features_2.model.0.model.0.weight",
"net_feature.convglo1.0.conv.0.bias": "global_feature_module.0.1.global_features_2.model.0.model.0.bias",
"net_feature.convglo1.1.conv.0.weight": "global_feature_module.0.1.global_features_2.model.1.model.0.weight",
"net_feature.convglo1.1.conv.0.bias": "global_feature_module.0.1.global_features_2.model.1.model.0.bias",
"net_feature.glo1.global_feature.0.weight": "global_feature_module.0.1.transformations_2.global_feature_1.0.weight",
"net_feature.glo1.global_feature.0.bias": "global_feature_module.0.1.transformations_2.global_feature_1.0.bias",
"net_feature.glo1.global_feature_1.0.weight": "global_feature_module.0.1.transformations_2.global_feature_2.0.weight",
"net_feature.glo1.global_feature_1.0.bias": "global_feature_module.0.1.transformations_2.global_feature_2.0.bias",
"net_feature.up1.up.weight": "decoder.0.0.up.weight",
"net_feature.up1.up.bias": "decoder.0.0.up.bias",
"net_feature.up2.up.weight": "decoder.1.0.up.weight",
"net_feature.up2.up.bias": "decoder.1.0.up.bias",
"net_feature.up3.up.weight": "decoder.2.0.up.weight",
"net_feature.up3.up.bias": "decoder.2.0.up.bias",
"net_feature.up4.up.weight": "decoder.3.0.up.weight",
"net_feature.up4.up.bias": "decoder.3.0.up.bias",
"net_feature.convup1.0.conv.0.weight": "decoder.0.1.model.0.model.0.weight",
"net_feature.convup1.0.conv.0.bias": "decoder.0.1.model.0.model.0.bias",
"net_feature.convup1.1.conv.0.weight": "decoder.0.1.model.1.model.0.weight",
"net_feature.convup1.1.conv.0.bias": "decoder.0.1.model.1.model.0.bias",
"net_feature.convup2.0.conv.0.weight": "decoder.1.1.model.0.model.0.weight",
"net_feature.convup2.0.conv.0.bias": "decoder.1.1.model.0.model.0.bias",
"net_feature.convup2.1.conv.0.weight": "decoder.1.1.model.1.model.0.weight",
"net_feature.convup2.1.conv.0.bias": "decoder.1.1.model.1.model.0.bias",
"net_feature.convup3.0.conv.0.weight": "decoder.2.1.model.0.model.0.weight",
"net_feature.convup3.0.conv.0.bias": "decoder.2.1.model.0.model.0.bias",
"net_feature.convup3.1.conv.0.weight": "decoder.2.1.model.1.model.0.weight",
"net_feature.convup3.1.conv.0.bias": "decoder.2.1.model.1.model.0.bias",
"net_feature.convup4.0.conv.0.weight": "decoder.3.1.0.model.0.weight",
"net_feature.convup4.0.conv.0.bias": "decoder.3.1.0.model.0.bias",
"net_feature.outc.conv.weight": "decoder.3.1.1.model.0.weight",
"net_feature.outc.conv.bias": "decoder.3.1.1.model.0.bias"
}
}
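
As a sketch of how these mappings are consumed (mirroring load_weights above), each legacy key is looked up and its tensor is stored under the new parameter name before load_state_dict is called; the checkpoint path is the sample file added in this commit:

import json
import torch

old_weights = torch.load('test/data/focal_surface_sample_model.pt', map_location = 'cpu')
with open('test/data/key_mappings.json', 'r') as json_file:
    key_mappings = json.load(json_file)

# Remap the legacy parameter names of the light propagation network.
light_prop_mapping = key_mappings['light_prop_key_mapping']
light_prop_state_dict = {
    light_prop_mapping[old_key]: value
    for old_key, value in old_weights.items()
    if old_key in light_prop_mapping
}
print('Remapped {} of {} tensors.'.format(len(light_prop_state_dict), len(old_weights)))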
Binary file added test/data/sample_0343_focal_surface.png
Binary file added test/data/sample_0343_hologram.png