From c9bb2bcd55f165cb08ff3c0c4b979e234695b818 Mon Sep 17 00:00:00 2001 From: Geoffrey Woollard Date: Wed, 22 Jan 2025 22:45:25 -0500 Subject: [PATCH] nonuniform global --- .../gromov_wasserstein/gw_weighted_voxels.py | 54 +++++++++++-------- .../_map_to_map/gromov_wasserstein/qp.py | 2 + 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/cryo_challenge/_map_to_map/gromov_wasserstein/gw_weighted_voxels.py b/src/cryo_challenge/_map_to_map/gromov_wasserstein/gw_weighted_voxels.py index 5ce1167..69b9bfb 100644 --- a/src/cryo_challenge/_map_to_map/gromov_wasserstein/gw_weighted_voxels.py +++ b/src/cryo_challenge/_map_to_map/gromov_wasserstein/gw_weighted_voxels.py @@ -36,7 +36,7 @@ def make_sparse_cost(idx_above_thresh, dtype): coordinates = coordinates.reshape(-1, 3) sparse_coordiantes = coordinates[idx_above_thresh.flatten()] pairwise_distances = torch.cdist(sparse_coordiantes, sparse_coordiantes) - return pairwise_distances + return pairwise_distances, sparse_coordiantes def normalize_mass_to_one(p): @@ -52,15 +52,15 @@ def prepare_volume_and_distance( volume = downsample_volume(volume, n_downsample_pix).numpy().astype(numpy_dtype) idx_above_thresh = return_top_k_voxel_idxs(volume, top_k) marginal = normalize_mass_to_one(volume[idx_above_thresh].flatten()) - pairwise_distance = ( - make_sparse_cost(idx_above_thresh, dtype=torch_dtype) - .numpy() - .astype(numpy_dtype) - ) + pairwise_distance, sparse_coordiantes = [ + x.numpy().astype(numpy_dtype) + for x in make_sparse_cost(idx_above_thresh, dtype=torch_dtype) + ] + if normalize: pairwise_distance /= pairwise_distance.max() pairwise_distance = (cost_scale_factor * pairwise_distance) ** exponent - return marginal, pairwise_distance + return marginal, sparse_coordiantes, pairwise_distance def gw_distance_wrapper_element_wise( @@ -371,35 +371,45 @@ def setup_volume_and_distance( marginals_i = np.empty((len(volumes_i), top_k)) marginals_j = np.empty((len(volumes_j), top_k)) + sparse_coordinates_sets_i = np.empty((len(volumes_j), top_k, 3)) + sparse_coordinates_sets_j = np.empty((len(volumes_j), top_k, 3)) pairwise_distances_i = np.empty((len(volumes_i), top_k, top_k)) pairwise_distances_j = np.empty((len(volumes_j), top_k, top_k)) for i in range(len(volumes_i)): - volume_i, pairwise_distance_i = prepare_volume_and_distance( - volumes_i[i], - top_k, - n_downsample_pix, - exponent, - cost_scale_factor, - normalize, + volume_i, sparse_coordinates_i, pairwise_distance_i = ( + prepare_volume_and_distance( + volumes_i[i], + top_k, + n_downsample_pix, + exponent, + cost_scale_factor, + normalize, + ) ) marginals_i[i] = volume_i + sparse_coordinates_sets_i[i] = sparse_coordinates_i pairwise_distances_i[i] = pairwise_distance_i for j in range(len(volumes_j)): - volume_j, pairwise_distance_j = prepare_volume_and_distance( - volumes_j[j], - top_k, - n_downsample_pix, - exponent, - cost_scale_factor, - normalize, + volume_j, sparse_coordinates_j, pairwise_distance_j = ( + prepare_volume_and_distance( + volumes_j[j], + top_k, + n_downsample_pix, + exponent, + cost_scale_factor, + normalize, + ) ) marginals_j[j] = volume_j + sparse_coordinates_sets_j[j] = sparse_coordinates_j pairwise_distances_j[j] = pairwise_distance_j return ( marginals_i, marginals_j, + sparse_coordinates_sets_i, + sparse_coordinates_sets_j, pairwise_distances_i, pairwise_distances_j, volumes_i, @@ -419,6 +429,8 @@ def main(args): ( marginals_i, marginals_j, + _, + _, pairwise_distances_i, pairwise_distances_j, volumes_i, diff --git a/src/cryo_challenge/_map_to_map/gromov_wasserstein/qp.py b/src/cryo_challenge/_map_to_map/gromov_wasserstein/qp.py index 21309df..6292a09 100644 --- a/src/cryo_challenge/_map_to_map/gromov_wasserstein/qp.py +++ b/src/cryo_challenge/_map_to_map/gromov_wasserstein/qp.py @@ -141,6 +141,8 @@ def main(args): marginals_j, pairwise_distances_i, pairwise_distances_j, + _, + _, volumes_i, volumes_j, ) = setup_volume_and_distance(