Skip to content

Commit

Permalink
Paralelize aolume align
Browse files Browse the repository at this point in the history
  • Loading branch information
larsborm committed Apr 26, 2023
1 parent 70c4758 commit 440f234
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 30 deletions.
2 changes: 1 addition & 1 deletion FISHscale/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ def reset_working_selection(self):
"""Reset the working selection to include all datapoints.
"""
for d in self.datasets:
self.set_working_selection(level = None)
d.set_working_selection(level = None)

def visualize(
self,
Expand Down
85 changes: 56 additions & 29 deletions FISHscale/utils/volume_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,38 +184,33 @@ def find_warp(self, images: list, attachment: int=20,
mixing_factor: float=0.3, second_order=True):

#Mixing_factor = 0 #Completely use synthetic made by img0 and img2
#Mixing_facotro = 0.3 # All three images weigh equal.
#Mixing_factor = 0.3 # All three images weigh equal.
#Mixing_factor = 0.5 #img1 weighs 50% and img0 and img2 weigh 25% each
#Mixing_factore = 1 #Img1 weighs 100%
#Mixing_factor = 1 #Img1 weighs 100%

z_ordered = self._check_z_sort()
z_loc = [self.datasets[i].z for i in z_ordered]
n_datasets = len(self.datasets)
dataset_names = [self.datasets[i].dataset_name for i in z_ordered]

synthetic_images = []
warped_images = []
vs = []
us = []

def reg_warp(img0, img1, factor):
v, u = self._register_worker(img0, img1,
attachment=attachment,
tightness=tightness,
num_warp=num_warp,
num_iter=num_iter,
tol=tol,
prefilter=prefilter)
#Make synthetic image that would be the image inbetween img0 and img1
warped = self._warp(img1, v, u, factor=factor)

return warped, v, u

self.vp(f'Warping datasets. Z levels should be consecutive')
for i, dn in enumerate(dataset_names):
self.vp(f'Finding warp for dataset {i+1}/{n_datasets} with Z: {self.datasets[i].z}')

#Prepare imput

def worker(i, images, attachment, tightness, num_warp, num_iter, tol, prefilter, mixing_factor, second_order):

def reg_warp(img0, img1, factor):
v, u = self._register_worker(img0, img1,
attachment=attachment,
tightness=tightness,
num_warp=num_warp,
num_iter=num_iter,
tol=tol,
prefilter=prefilter)
#Make synthetic image that would be the image inbetween img0 and img1
warped = self._warp(img1, v, u, factor=factor)

return warped, v, u

#Prepare input
#First and second section
if i <= 1:
#img0
Expand Down Expand Up @@ -274,14 +269,45 @@ def reg_warp(img0, img1, factor):
synt_image, _, _ = reg_warp(img1, synt_image, factor=mixing_factor)
#Warp image
warped_image, v, u = reg_warp(synt_image, img1, factor=factor)

return v, u, synt_image, warped_image

#Output handling
synthetic_images.append(synt_image)
warped_images.append(warped_image)
vs.append(v)
us.append(u)
lazy_results = []
images_delayed = delayed(images)

for i in range(n_datasets):
r = delayed(worker)(i, images_delayed, attachment, tightness, num_warp, num_iter, tol, prefilter, mixing_factor, second_order)
lazy_results.append(r)

with ProgressBar():
result = dask.compute(*lazy_results)#, scheduler='processes', n_workers=self.cpu_count)

vs = []
us = []
synthetic_images = []
warped_images = []
for r in result:
vs.append(r[0])
us.append(r[1])
synthetic_images.append(r[2])
warped_images.append(r[3])

return vs, us, synthetic_images, warped_images


#self.vp(f'Warping datasets. Z levels should be consecutive')
#for i, dn in enumerate(dataset_names):
# self.vp(f'Finding warp for dataset {i+1}/{n_datasets} with Z: {self.datasets[i].z}')



#Output handling
#synthetic_images.append(synt_image)
#warped_images.append(warped_image)
#vs.append(v)
#us.append(u)

#return vs, us, synthetic_images, warped_images


def warp_all(self, squarebin:list, v:list, u:list):
Expand Down Expand Up @@ -362,6 +388,7 @@ def warped_per_gene(self, warped, bin_size: int=100, return_dict=False, z_locati
def interpolate_genes(self, warped_genes, bin_size: int=100):



gene_data = self.warped_per_gene(warped_genes, bin_size=bin_size, return_dict=False)
n_genes = len(gene_data)
#warped_sum = np.rollaxis(np.sum(gene_data, axis=3), 0, 3)
Expand Down

0 comments on commit 440f234

Please sign in to comment.