Skip to content

Commit

Permalink
Revisions.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Oct 24, 2024
1 parent 5a27ec1 commit 726372e
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 29 deletions.
3 changes: 1 addition & 2 deletions odak/learn/wave/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class focal_surface_light_propagation(torch.nn.Module):
References
----------
Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak{\c{s}}it}. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24),December,2024.
Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
"""
def __init__(
self,
Expand Down Expand Up @@ -291,4 +291,3 @@ def load_weights(self, weight_filename, key_mapping_filename):
light_prop_new_state_dict[new_key] = value
self.light_propagation.to(self.device)
self.light_propagation.load_state_dict(light_prop_new_state_dict)
print("Weights loaded successfully into the new model.")
Binary file added test/data/focal_surface_sample_model.pt
Binary file not shown.
156 changes: 156 additions & 0 deletions test/data/key_mappings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
{
"sv_kernel_generation_key_mapping": {
"net_kernel.inc.0.conv.0.weight": "inc.model.0.weight",
"net_kernel.inc.0.conv.0.bias": "inc.model.0.bias",
"net_kernel.conv1.0.conv.0.weight": "encoder.1.model.0.model.0.weight",
"net_kernel.conv1.0.conv.0.bias": "encoder.1.model.0.model.0.bias",
"net_kernel.conv1.1.conv.0.weight": "encoder.1.model.1.model.0.weight",
"net_kernel.conv1.1.conv.0.bias": "encoder.1.model.1.model.0.bias",
"net_kernel.conv2.0.conv.0.weight": "encoder.3.model.0.model.0.weight",
"net_kernel.conv2.0.conv.0.bias": "encoder.3.model.0.model.0.bias",
"net_kernel.conv2.1.conv.0.weight": "encoder.3.model.1.model.0.weight",
"net_kernel.conv2.1.conv.0.bias": "encoder.3.model.1.model.0.bias",
"net_kernel.conv3.0.conv.0.weight": "encoder.5.model.0.model.0.weight",
"net_kernel.conv3.0.conv.0.bias": "encoder.5.model.0.model.0.bias",
"net_kernel.conv3.1.conv.0.weight": "encoder.5.model.1.model.0.weight",
"net_kernel.conv3.1.conv.0.bias": "encoder.5.model.1.model.0.bias",
"net_kernel.conv4.0.conv.0.weight": "encoder.7.model.0.model.0.weight",
"net_kernel.conv4.0.conv.0.bias": "encoder.7.model.0.model.0.bias",
"net_kernel.conv4.1.conv.0.weight": "encoder.7.model.1.model.0.weight",
"net_kernel.conv4.1.conv.0.bias": "encoder.7.model.1.model.0.bias",
"net_kernel.glo1_sv.model.0.weight": "spatially_varying_feature.0.4.0.weight",
"net_kernel.glo1_sv.model.0.bias": "spatially_varying_feature.0.4.0.bias",
"net_kernel.glo1_sv.model.2.weight": "spatially_varying_feature.0.4.2.weight",
"net_kernel.glo1_sv.model.2.bias": "spatially_varying_feature.0.4.2.bias",
"net_kernel.glo1_sv.model.4.weight": "spatially_varying_feature.0.4.4.weight",
"net_kernel.glo1_sv.model.4.bias": "spatially_varying_feature.0.4.4.bias",
"net_kernel.convup1_sv.model.0.weight": "spatially_varying_feature.1.3.0.weight",
"net_kernel.convup1_sv.model.0.bias": "spatially_varying_feature.1.3.0.bias",
"net_kernel.convup1_sv.model.2.weight": "spatially_varying_feature.1.3.2.weight",
"net_kernel.convup1_sv.model.2.bias": "spatially_varying_feature.1.3.2.bias",
"net_kernel.convup1_sv.model.4.weight": "spatially_varying_feature.1.3.4.weight",
"net_kernel.convup1_sv.model.4.bias": "spatially_varying_feature.1.3.4.bias",
"net_kernel.convup2_sv.model.0.weight": "spatially_varying_feature.2.2.0.weight",
"net_kernel.convup2_sv.model.0.bias": "spatially_varying_feature.2.2.0.bias",
"net_kernel.convup2_sv.model.2.weight": "spatially_varying_feature.2.2.2.weight",
"net_kernel.convup2_sv.model.2.bias": "spatially_varying_feature.2.2.2.bias",
"net_kernel.convup2_sv.model.4.weight": "spatially_varying_feature.2.2.4.weight",
"net_kernel.convup2_sv.model.4.bias": "spatially_varying_feature.2.2.4.bias",
"net_kernel.convup3_sv.model.0.weight": "spatially_varying_feature.3.1.0.weight",
"net_kernel.convup3_sv.model.0.bias": "spatially_varying_feature.3.1.0.bias",
"net_kernel.convup3_sv.model.2.weight": "spatially_varying_feature.3.1.2.weight",
"net_kernel.convup3_sv.model.2.bias": "spatially_varying_feature.3.1.2.bias",
"net_kernel.convup3_sv.model.4.weight": "spatially_varying_feature.3.1.4.weight",
"net_kernel.convup3_sv.model.4.bias": "spatially_varying_feature.3.1.4.bias",
"net_kernel.convup1.0.conv.0.weight": "decoder.1.1.model.0.model.0.weight",
"net_kernel.convup1.0.conv.0.bias": "decoder.1.1.model.0.model.0.bias",
"net_kernel.convup1.1.conv.0.weight": "decoder.1.1.model.1.model.0.weight",
"net_kernel.convup1.1.conv.0.bias": "decoder.1.1.model.1.model.0.bias",
"net_kernel.convup2.0.conv.0.weight": "decoder.2.1.model.0.model.0.weight",
"net_kernel.convup2.0.conv.0.bias": "decoder.2.1.model.0.model.0.bias",
"net_kernel.convup2.1.conv.0.weight": "decoder.2.1.model.1.model.0.weight",
"net_kernel.convup2.1.conv.0.bias": "decoder.2.1.model.1.model.0.bias",
"net_kernel.convup3.0.conv.0.weight": "decoder.3.1.model.0.model.0.weight",
"net_kernel.convup3.0.conv.0.bias": "decoder.3.1.model.0.model.0.bias",
"net_kernel.convup3.1.conv.0.weight": "decoder.3.1.model.1.model.0.weight",
"net_kernel.convup3.1.conv.0.bias": "decoder.3.1.model.1.model.0.bias",
"net_kernel.glo.global_feature.0.weight": "decoder.0.transformations_1.global_feature_1.0.weight",
"net_kernel.glo.global_feature.0.bias": "decoder.0.transformations_1.global_feature_1.0.bias",
"net_kernel.glo.global_feature_1.0.weight": "decoder.0.transformations_1.global_feature_2.0.weight",
"net_kernel.glo.global_feature_1.0.bias": "decoder.0.transformations_1.global_feature_2.0.bias",
"net_kernel.glo1.global_feature.0.weight": "decoder.0.transformations_2.global_feature_1.0.weight",
"net_kernel.glo1.global_feature.0.bias": "decoder.0.transformations_2.global_feature_1.0.bias",
"net_kernel.glo1.global_feature_1.0.weight": "decoder.0.transformations_2.global_feature_2.0.weight",
"net_kernel.glo1.global_feature_1.0.bias": "decoder.0.transformations_2.global_feature_2.0.bias",
"net_kernel.convglo.0.conv.0.weight": "decoder.0.global_features_1.model.0.model.0.weight",
"net_kernel.convglo.0.conv.0.bias": "decoder.0.global_features_1.model.0.model.0.bias",
"net_kernel.convglo.1.conv.0.weight": "decoder.0.global_features_1.model.1.model.0.weight",
"net_kernel.convglo.1.conv.0.bias": "decoder.0.global_features_1.model.1.model.0.bias",
"net_kernel.convglo1.0.conv.0.weight": "decoder.0.global_features_2.model.0.model.0.weight",
"net_kernel.convglo1.0.conv.0.bias": "decoder.0.global_features_2.model.0.model.0.bias",
"net_kernel.convglo1.1.conv.0.weight": "decoder.0.global_features_2.model.1.model.0.weight",
"net_kernel.convglo1.1.conv.0.bias": "decoder.0.global_features_2.model.1.model.0.bias",
"net_kernel.up1.up.weight": "decoder.1.0.up.weight",
"net_kernel.up1.up.bias": "decoder.1.0.up.bias",
"net_kernel.up2.up.weight": "decoder.2.0.up.weight",
"net_kernel.up2.up.bias": "decoder.2.0.up.bias",
"net_kernel.up3.up.weight": "decoder.3.0.up.weight",
"net_kernel.up3.up.bias": "decoder.3.0.up.bias"
},
"light_prop_key_mapping": {
"net_feature.inc.0.conv.0.weight": "inc.model.0.weight",
"net_feature.inc.0.conv.0.bias": "inc.model.0.bias",
"net_feature.conv1.0.conv.0.weight": "encoder.0.1.model.0.model.0.weight",
"net_feature.conv1.0.conv.0.bias": "encoder.0.1.model.0.model.0.bias",
"net_feature.conv1.1.conv.0.weight": "encoder.0.1.model.1.model.0.weight",
"net_feature.conv1.1.conv.0.bias": "encoder.0.1.model.1.model.0.bias",
"net_feature.conv2.0.conv.0.weight": "encoder.1.1.model.0.model.0.weight",
"net_feature.conv2.0.conv.0.bias": "encoder.1.1.model.0.model.0.bias",
"net_feature.conv2.1.conv.0.weight": "encoder.1.1.model.1.model.0.weight",
"net_feature.conv2.1.conv.0.bias": "encoder.1.1.model.1.model.0.bias",
"net_feature.conv3.0.conv.0.weight": "encoder.2.1.model.0.model.0.weight",
"net_feature.conv3.0.conv.0.bias": "encoder.2.1.model.0.model.0.bias",
"net_feature.conv3.1.conv.0.weight": "encoder.2.1.model.1.model.0.weight",
"net_feature.conv3.1.conv.0.bias": "encoder.2.1.model.1.model.0.bias",
"net_feature.conv4.0.conv.0.weight": "encoder.3.1.model.0.model.0.weight",
"net_feature.conv4.0.conv.0.bias": "encoder.3.1.model.0.model.0.bias",
"net_feature.conv4.1.conv.0.weight": "encoder.3.1.model.1.model.0.weight",
"net_feature.conv4.1.conv.0.bias": "encoder.3.1.model.1.model.0.bias",
"net_feature.sv1.weight": "encoder.0.2.weight",
"net_feature.sv1.net.weight": "encoder.0.2.standard_convolution.weight",
"net_feature.sv1.net.bias": "encoder.0.2.standard_convolution.bias",
"net_feature.sv2.weight": "encoder.1.2.weight",
"net_feature.sv2.net.weight": "encoder.1.2.standard_convolution.weight",
"net_feature.sv2.net.bias": "encoder.1.2.standard_convolution.bias",
"net_feature.sv3.weight": "encoder.2.2.weight",
"net_feature.sv3.net.weight": "encoder.2.2.standard_convolution.weight",
"net_feature.sv3.net.bias": "encoder.2.2.standard_convolution.bias",
"net_feature.sv4.weight": "encoder.3.2.weight",
"net_feature.sv4.net.weight": "encoder.3.2.standard_convolution.weight",
"net_feature.sv4.net.bias": "encoder.3.2.standard_convolution.bias",
"net_feature.convglo0.0.conv.0.weight": "global_feature_module.0.0.model.0.model.0.weight",
"net_feature.convglo0.0.conv.0.bias": "global_feature_module.0.0.model.0.model.0.bias",
"net_feature.convglo0.1.conv.0.weight": "global_feature_module.0.0.model.1.model.0.weight",
"net_feature.convglo0.1.conv.0.bias": "global_feature_module.0.0.model.1.model.0.bias",
"net_feature.glo.global_feature.0.weight": "global_feature_module.0.1.transformations_1.global_feature_1.0.weight",
"net_feature.glo.global_feature.0.bias": "global_feature_module.0.1.transformations_1.global_feature_1.0.bias",
"net_feature.glo.global_feature_1.0.weight": "global_feature_module.0.1.transformations_1.global_feature_2.0.weight",
"net_feature.glo.global_feature_1.0.bias": "global_feature_module.0.1.transformations_1.global_feature_2.0.bias",
"net_feature.convglo.0.conv.0.weight": "global_feature_module.0.1.global_features_1.model.0.model.0.weight",
"net_feature.convglo.0.conv.0.bias": "global_feature_module.0.1.global_features_1.model.0.model.0.bias",
"net_feature.convglo.1.conv.0.weight": "global_feature_module.0.1.global_features_1.model.1.model.0.weight",
"net_feature.convglo.1.conv.0.bias": "global_feature_module.0.1.global_features_1.model.1.model.0.bias",
"net_feature.convglo1.0.conv.0.weight": "global_feature_module.0.1.global_features_2.model.0.model.0.weight",
"net_feature.convglo1.0.conv.0.bias": "global_feature_module.0.1.global_features_2.model.0.model.0.bias",
"net_feature.convglo1.1.conv.0.weight": "global_feature_module.0.1.global_features_2.model.1.model.0.weight",
"net_feature.convglo1.1.conv.0.bias": "global_feature_module.0.1.global_features_2.model.1.model.0.bias",
"net_feature.glo1.global_feature.0.weight": "global_feature_module.0.1.transformations_2.global_feature_1.0.weight",
"net_feature.glo1.global_feature.0.bias": "global_feature_module.0.1.transformations_2.global_feature_1.0.bias",
"net_feature.glo1.global_feature_1.0.weight": "global_feature_module.0.1.transformations_2.global_feature_2.0.weight",
"net_feature.glo1.global_feature_1.0.bias": "global_feature_module.0.1.transformations_2.global_feature_2.0.bias",
"net_feature.up1.up.weight": "decoder.0.0.up.weight",
"net_feature.up1.up.bias": "decoder.0.0.up.bias",
"net_feature.up2.up.weight": "decoder.1.0.up.weight",
"net_feature.up2.up.bias": "decoder.1.0.up.bias",
"net_feature.up3.up.weight": "decoder.2.0.up.weight",
"net_feature.up3.up.bias": "decoder.2.0.up.bias",
"net_feature.up4.up.weight": "decoder.3.0.up.weight",
"net_feature.up4.up.bias": "decoder.3.0.up.bias",
"net_feature.convup1.0.conv.0.weight": "decoder.0.1.model.0.model.0.weight",
"net_feature.convup1.0.conv.0.bias": "decoder.0.1.model.0.model.0.bias",
"net_feature.convup1.1.conv.0.weight": "decoder.0.1.model.1.model.0.weight",
"net_feature.convup1.1.conv.0.bias": "decoder.0.1.model.1.model.0.bias",
"net_feature.convup2.0.conv.0.weight": "decoder.1.1.model.0.model.0.weight",
"net_feature.convup2.0.conv.0.bias": "decoder.1.1.model.0.model.0.bias",
"net_feature.convup2.1.conv.0.weight": "decoder.1.1.model.1.model.0.weight",
"net_feature.convup2.1.conv.0.bias": "decoder.1.1.model.1.model.0.bias",
"net_feature.convup3.0.conv.0.weight": "decoder.2.1.model.0.model.0.weight",
"net_feature.convup3.0.conv.0.bias": "decoder.2.1.model.0.model.0.bias",
"net_feature.convup3.1.conv.0.weight": "decoder.2.1.model.1.model.0.weight",
"net_feature.convup3.1.conv.0.bias": "decoder.2.1.model.1.model.0.bias",
"net_feature.convup4.0.conv.0.weight": "decoder.3.1.0.model.0.weight",
"net_feature.convup4.0.conv.0.bias": "decoder.3.1.0.model.0.bias",
"net_feature.outc.conv.weight": "decoder.3.1.1.model.0.weight",
"net_feature.outc.conv.bias": "decoder.3.1.1.model.0.bias"
}
}
33 changes: 6 additions & 27 deletions test/test_learn_wave_focal_surface_light_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,17 @@


def test(output_directory = 'test_output'):
odak.tools.check_directory(output_directory)
number_of_planes = 6
location_offset = 0.
volume_depth = 5e-3
device = torch.device('cpu')

# Download the weight and key mapping files from GitHub

weight_url = 'https://raw.githubusercontent.com/complight/focal_surface_holographic_light_transport/main/weight/model_0mm.pt'
key_mapping_url = 'https://raw.githubusercontent.com/complight/focal_surface_holographic_light_transport/main/weight/key_mappings.json'
weight_filename = os.path.join(output_directory, 'model_0mm.pt')
key_mapping_filename = os.path.join(output_directory, 'key_mappings.json')
download_file(weight_url, weight_filename)
download_file(key_mapping_url, key_mapping_filename)
weight_filename = 'test/data/focal_surface_sample_model.pt'
key_mapping_filename = 'test/data/key_mappings.json'

# Preparing focal surface
focal_surface_filename = os.path.join(output_directory, 'sample_0343_focal_surface.png')
focal_surface_filename = 'test/data/sample_0343_focal_surface.png'
focal_surface = odak.learn.tools.load_image(
focal_surface_filename,
normalizeby = 255.,
Expand All @@ -37,7 +32,7 @@ def test(output_directory = 'test_output'):
focal_surface = focal_surface.unsqueeze(0).unsqueeze(0)

# Preparing hologram
hologram_phases_filename = os.path.join(output_directory, 'sample_0343_hologram.png')
hologram_phases_filename = 'test/data/sample_0343_hologram.png'
hologram_phases = odak.learn.tools.load_image(
hologram_phases_filename,
normalizeby = 255.,
Expand All @@ -54,31 +49,15 @@ def test(output_directory = 'test_output'):

# Perform the focal surface light propagation model
result = focal_surface_light_propagation_model(focal_surface, hologram_phases)

odak.learn.tools.save_image(
'{}/reconstruction_image.png'.format(output_directory),
result,
cmin = 0.,
cmax = 1.
)
print("Reconstruction complete.")
return True


# Function to download a file from GitHub
def download_file(url, filename):
    """
    Download a file from a URL and save it to a local path.

    Exits the process with status 1 when the download fails.

    Parameters
    ----------
    url          : str
                   Source URL to download from.
    filename     : str
                   Destination path for the downloaded file.
    """
    try:
        print(f"Starting download: {url}")
        response = requests.get(url, stream = True)
        response.raise_for_status()
        # Make sure the destination directory exists before writing.
        os.makedirs(os.path.dirname(filename), exist_ok = True)
        with open(filename, 'wb') as file:
            # Stream in chunks so large files are never held fully in memory.
            for chunk in response.iter_content(chunk_size = 8192):
                file.write(chunk)
        # Fix: report the actual saved path instead of a literal placeholder.
        print(f"Downloaded: {filename}")
    except requests.exceptions.RequestException as e:
        print(f"Failed to download {url}. Error: {e}")
        sys.exit(1)


# Run the test when executed as a script; the process exit code
# reflects the return value of test().
if __name__ == '__main__':
    sys.exit(test())

0 comments on commit 726372e

Please sign in to comment.