Skip to content

Commit

Permalink
nonuniform global
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jan 23, 2025
1 parent f6e1ceb commit c9bb2bc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def make_sparse_cost(idx_above_thresh, dtype):
coordinates = coordinates.reshape(-1, 3)
sparse_coordiantes = coordinates[idx_above_thresh.flatten()]
pairwise_distances = torch.cdist(sparse_coordiantes, sparse_coordiantes)
return pairwise_distances
return pairwise_distances, sparse_coordiantes


def normalize_mass_to_one(p):
Expand All @@ -52,15 +52,15 @@ def prepare_volume_and_distance(
volume = downsample_volume(volume, n_downsample_pix).numpy().astype(numpy_dtype)
idx_above_thresh = return_top_k_voxel_idxs(volume, top_k)
marginal = normalize_mass_to_one(volume[idx_above_thresh].flatten())
pairwise_distance = (
make_sparse_cost(idx_above_thresh, dtype=torch_dtype)
.numpy()
.astype(numpy_dtype)
)
pairwise_distance, sparse_coordiantes = [
x.numpy().astype(numpy_dtype)
for x in make_sparse_cost(idx_above_thresh, dtype=torch_dtype)
]

if normalize:
pairwise_distance /= pairwise_distance.max()
pairwise_distance = (cost_scale_factor * pairwise_distance) ** exponent
return marginal, pairwise_distance
return marginal, sparse_coordiantes, pairwise_distance


def gw_distance_wrapper_element_wise(
Expand Down Expand Up @@ -371,35 +371,45 @@ def setup_volume_and_distance(

marginals_i = np.empty((len(volumes_i), top_k))
marginals_j = np.empty((len(volumes_j), top_k))
sparse_coordinates_sets_i = np.empty((len(volumes_j), top_k, 3))
sparse_coordinates_sets_j = np.empty((len(volumes_j), top_k, 3))
pairwise_distances_i = np.empty((len(volumes_i), top_k, top_k))
pairwise_distances_j = np.empty((len(volumes_j), top_k, top_k))

for i in range(len(volumes_i)):
volume_i, pairwise_distance_i = prepare_volume_and_distance(
volumes_i[i],
top_k,
n_downsample_pix,
exponent,
cost_scale_factor,
normalize,
volume_i, sparse_coordinates_i, pairwise_distance_i = (
prepare_volume_and_distance(
volumes_i[i],
top_k,
n_downsample_pix,
exponent,
cost_scale_factor,
normalize,
)
)
marginals_i[i] = volume_i
sparse_coordinates_sets_i[i] = sparse_coordinates_i
pairwise_distances_i[i] = pairwise_distance_i
for j in range(len(volumes_j)):
volume_j, pairwise_distance_j = prepare_volume_and_distance(
volumes_j[j],
top_k,
n_downsample_pix,
exponent,
cost_scale_factor,
normalize,
volume_j, sparse_coordinates_j, pairwise_distance_j = (
prepare_volume_and_distance(
volumes_j[j],
top_k,
n_downsample_pix,
exponent,
cost_scale_factor,
normalize,
)
)
marginals_j[j] = volume_j
sparse_coordinates_sets_j[j] = sparse_coordinates_j
pairwise_distances_j[j] = pairwise_distance_j

return (
marginals_i,
marginals_j,
sparse_coordinates_sets_i,
sparse_coordinates_sets_j,
pairwise_distances_i,
pairwise_distances_j,
volumes_i,
Expand All @@ -419,6 +429,8 @@ def main(args):
(
marginals_i,
marginals_j,
_,
_,
pairwise_distances_i,
pairwise_distances_j,
volumes_i,
Expand Down
2 changes: 2 additions & 0 deletions src/cryo_challenge/_map_to_map/gromov_wasserstein/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def main(args):
marginals_j,
pairwise_distances_i,
pairwise_distances_j,
_,
_,
volumes_i,
volumes_j,
) = setup_volume_and_distance(
Expand Down

0 comments on commit c9bb2bc

Please sign in to comment.