Performance issue in bispectrum method and gradient #293
Hi @Hackasteroid142. Thanks for the very detailed issue. I think it may be an easy fix. Try changing the following in your blockwise call:
Hi @sjperkins, thank you for replying. I followed your advice, but unfortunately, it didn't resolve the issue. However, another error has become more frequent than the one I mentioned before, and it is related to adjust_chunks. Here are the details of that error:
This error appears when I set a chunk size other than -1 when reading the dataset. Do you think there is an error in my implementation, or is something missing?
Hi @Hackasteroid142. The following code should work. Note that I am using np.nan chunk sizes in adjust_chunks, since the number of unique times per chunk is not known when the graph is constructed:

```python
import dask.array as da
import numpy as np
from numba import njit
from daskms import xds_from_ms
import itertools
@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
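    # For each row, test which antenna triangles (comb) the row's baseline
    # (a1, a2) closes, and scatter the visibility into the matching leg
    # (0, 1 or 2) of that triangle. cnt records how many legs of each
    # (triangle, time) slot were filled.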
cnt = np.zeros(data_shape, dtype=np.int8)
data_bis = np.ones(data_shape, dtype=data.dtype)
for row in range(n_row):
ut = utime_inv[row]
a1 = ant1[row]
a2 = ant2[row]
for ic, c in enumerate(comb):
if (a1 == c[0]) and (a2 == c[1]):
data_bis[ic, ut, 0] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[1]) and (a2 == c[-1]):
data_bis[ic, ut, 1] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[0]) and (a2 == c[-1]):
data_bis[ic, ut, 2] = data[row].conjugate()
cnt[ic, ut, :] += 1
return data_bis, cnt
def get_bispectrum(data, ant1, ant2, comb, time, type):
utime, utime_inv = np.unique(time, return_inverse=True)
n_utime = utime.size
n_comb = len(comb)
n_row, n_chan, n_corr = data.shape
shape_data = (n_comb, n_utime, 3, n_chan, n_corr)
bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)
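    # Zero any (triangle, time) slot whose three legs were not all found.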
bis[cnt != 3] = 0
return bis
if __name__ == "__main__":
input_name = "path/to/ms"
xdsl = xds_from_ms(
input_name,
group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
chunks={'row': 10000}
)
n_ant = 28
antenna_reference = 0
comb = itertools.combinations(range(n_ant), 3)
filter_comb = np.array([i for i in comb if antenna_reference in i])
results = []
for xds in xdsl:
bis = da.blockwise(
get_bispectrum,
('triangle', 't', 'p', 'f', 'c'),
xds.DATA.data,
('t', 'f', 'c'),
xds.ANTENNA1.data, ('t'),
xds.ANTENNA2.data, ('t'),
filter_comb,
None,
xds.TIME.data, ('t'),
type,
None,
align_arrays=False,
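            # np.nan marks the output time-chunk sizes as unknown at graph
            # construction time; each input row chunk yields one output chunk.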
adjust_chunks={'t': (np.nan,)*xds.DATA.data.numblocks[0]},
new_axes={
'triangle': len(filter_comb),
'p': 3
},
dtype=xds.DATA.data.dtype
)
results.append(bis)
    results = da.compute(results)
```
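As an aside: if a downstream operation needs concrete chunk sizes, the unknown (nan) chunks can be resolved explicitly, at the cost of an extra pass over the data. A minimal sketch, assuming bis is one of the arrays produced above:

```python
# Triggers a small compute to measure the actual chunk sizes, after which
# shape-dependent operations (e.g. reshape) become possible.
bis.compute_chunk_sizes()
```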
Hi @sjperkins, I apologize for the delay in my response. I attempted the solution you provided but, as you mentioned, I need to be aware of the chunk sizes. In later operations, when I use [...]. Also, I attempted to modify your solution slightly by using [...].
The following will make sure that the chunks on the bispectrum are correct. However, this will not guarantee that all rows associated with a single time will be in the same chunk. That is also possible, but requires a fair amount of additional code. You can see an example of the approach here.

```python
import dask.array as da
import numpy as np
from numba import njit
from daskms import xds_from_ms
import itertools
@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
cnt = np.zeros(data_shape, dtype=np.int8)
data_bis = np.ones(data_shape, dtype=data.dtype)
for row in range(n_row):
ut = utime_inv[row]
a1 = ant1[row]
a2 = ant2[row]
for ic, c in enumerate(comb):
if (a1 == c[0]) and (a2 == c[1]):
data_bis[ic, ut, 0] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[1]) and (a2 == c[-1]):
data_bis[ic, ut, 1] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[0]) and (a2 == c[-1]):
data_bis[ic, ut, 2] = data[row].conjugate()
cnt[ic, ut, :] += 1
return data_bis, cnt
def get_bispectrum(data, ant1, ant2, comb, time, type):
utime, utime_inv = np.unique(time, return_inverse=True)
n_utime = utime.size
n_comb = len(comb)
n_row, n_chan, n_corr = data.shape
shape_data = (n_comb, n_utime, 3, n_chan, n_corr)
bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)
bis[cnt != 3] = 0
return bis
def compute_utime_chunks(xdsl):
utime_chunks = []
for xds in xdsl:
time = xds.TIME.data
utimes_per_chunk = da.blockwise(
lambda t: np.array(len(np.unique(t))), "t",
time, "t",
adjust_chunks={'t': (np.nan,)*time.numblocks[0]}
)
utime_chunks.append(utimes_per_chunk)
return [tuple(utc) for utc in da.compute(utime_chunks)[0]]
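# For example (hypothetical sizes): with two datasets of four and three row
# chunks respectively, this might return [(30, 30, 30, 12), (30, 30, 7)] --
# one unique-time count per row chunk.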
if __name__ == "__main__":
input_name = "path/to/ms"
xdsl = xds_from_ms(
input_name,
group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
chunks={'row': 10000}
)
n_ant = 28
antenna_reference = 0
comb = itertools.combinations(range(n_ant), 3)
filter_comb = np.array([i for i in comb if antenna_reference in i])
utime_chunks = compute_utime_chunks(xdsl)
results = []
for xds, utc in zip(xdsl, utime_chunks):
bis = da.blockwise(
get_bispectrum,
('triangle', 't', 'p', 'f', 'c'),
xds.DATA.data,
('t', 'f', 'c'),
xds.ANTENNA1.data, ('t'),
xds.ANTENNA2.data, ('t'),
filter_comb,
None,
xds.TIME.data, ('t'),
type,
None,
align_arrays=False,
adjust_chunks={'t': utc},
new_axes={
'triangle': len(filter_comb),
'p': 3
},
dtype=xds.DATA.data.dtype
)
results.append(bis)
    results = da.compute(results)
```
@Hackasteroid142 I took the liberty of adding the chunking functionality, as it may be useful to other users. Here is the example, which now allows chunking by unique time.

```python
import dask.array as da
import numpy as np
from numba import njit
from daskms import xds_from_ms
import itertools
@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
cnt = np.zeros(data_shape, dtype=np.int8)
data_bis = np.ones(data_shape, dtype=data.dtype)
for row in range(n_row):
ut = utime_inv[row]
a1 = ant1[row]
a2 = ant2[row]
for ic, c in enumerate(comb):
if (a1 == c[0]) and (a2 == c[1]):
data_bis[ic, ut, 0] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[1]) and (a2 == c[-1]):
data_bis[ic, ut, 1] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[0]) and (a2 == c[-1]):
data_bis[ic, ut, 2] = data[row].conjugate()
cnt[ic, ut, :] += 1
return data_bis, cnt
def get_bispectrum(data, ant1, ant2, comb, time, type):
utime, utime_inv = np.unique(time, return_inverse=True)
n_utime = utime.size
n_comb = len(comb)
n_row, n_chan, n_corr = data.shape
shape_data = (n_comb, n_utime, 3, n_chan, n_corr)
bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)
bis[cnt != 3] = 0
return bis
def utime_and_row_chunks(time, req_utime=1):
"""Internals of compute_utime_and_row_chunks."""
utime, utime_counts = np.unique(time, return_counts=True)
n_utime = utime.size
req_utime = req_utime or n_utime # Catch zero.
chunk_starts = np.arange(0, n_utime, req_utime)
utime_chunks = np.array(
[
req_utime if i + req_utime < n_utime else n_utime - i
for i in chunk_starts
]
)
row_chunks = np.add.reduceat(utime_counts, chunk_starts)
return np.stack([utime_chunks, row_chunks], axis=0)
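# Worked example (hypothetical values): for time = [0, 0, 1, 1, 1, 2, 2] and
# req_utime=2, utime_counts is [2, 3, 2] and chunk_starts is [0, 2], giving
# utime_chunks = [2, 1] and row_chunks = [5, 2] -- the first chunk spans the
# five rows of the first two unique times.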
def compute_utime_and_row_chunks(indexing_xdsl, req_utime=1):
"""Figure out the chunking in unique time and row. Triggers compute."""
chunking = []
for xds in indexing_xdsl:
chunking.append(
xds.TIME.data.map_blocks(
utime_and_row_chunks,
req_utime,
chunks=((2,), (np.nan,)),
new_axis=0,
dtype=int
)
)
result = da.compute(chunking)[0]
utime_chunks = [tuple(arr[0]) for arr in result]
row_chunks = [tuple(arr[1]) for arr in result]
return utime_chunks, row_chunks
if __name__ == "__main__":
input_name = "~/reductions/3C147/msdir/C147_unflagged.MS"
# Set up TIME only datasets which we can use to establish chunking. Note
# that we use only a single chunk per dataset.
indexing_xdsl = xds_from_ms(
input_name,
group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
columns=("TIME",),
chunks={'row': -1}
)
# req_utime controls the number of unique times that you want in a
# single chunk. This triggers some early but very lightweight compute
# to figure out the required chunking - much cheaper than
# compute_chunk_sizes.
utime_chunks_list, row_chunks_list = compute_utime_and_row_chunks(
indexing_xdsl, req_utime=30
)
# Now that we know the desired chunking, load the data we want to
# manipulate with the required chunking.
xdsl = xds_from_ms(
input_name,
group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
columns=("TIME", "ANTENNA1", "ANTENNA2", "DATA"),
chunks=[{'row': rcs} for rcs in row_chunks_list]
)
n_ant = 28
antenna_reference = 0
comb = itertools.combinations(range(n_ant), 3)
filter_comb = np.array([i for i in comb if antenna_reference in i])
results = []
for xds, utc in zip(xdsl, utime_chunks_list):
bis = da.blockwise(
get_bispectrum,
('triangle', 't', 'p', 'f', 'c'),
xds.DATA.data,
('t', 'f', 'c'),
xds.ANTENNA1.data, ('t'),
xds.ANTENNA2.data, ('t'),
filter_comb,
None,
xds.TIME.data, ('t'),
type,
None,
align_arrays=False,
adjust_chunks={'t': utc},
new_axes={
'triangle': len(filter_comb),
'p': 3
},
dtype=xds.DATA.data.dtype
)
results.append(bis)
    results = da.compute(results)
```
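A quick sanity check along these lines (a sketch; results and utime_chunks_list are the names from the script above) confirms that the time axis of each computed bispectrum matches the requested unique-time chunking:

```python
# da.compute(results) returns a one-element tuple wrapping the list.
for bis_arr, utc in zip(results[0], utime_chunks_list):
    assert bis_arr.shape[1] == sum(utc)  # axis 1 is the unique-time axis
```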
Thank you for the previous code; it was very helpful for working with different chunk sizes in the bispectrum. However, I've been trying to use this to improve the performance of my bispectrum method and the optimization method I'm working on. Despite implementing the code you provided, I haven't seen any improvement in time or memory usage. I'm working with an extract of the HD163296 dataset and a dirty image of this extract, which are available at this link: https://we.tl/t-R4nVGgd2vN. The image is used to calculate the model visibilities for the dataset. In the optimization method, I use the gradient of the chi-square bispectrum as shown in this article (Equation 3 of the appendix) by Andrew A. Chael et al. Based on this code and the dataset I mentioned earlier, is there something I should take into consideration? Is there any way to further optimize my code? As background information, I conducted a small experiment to measure the time it takes to compute the gradient: building the gradient graph takes only a few minutes thanks to Dask's laziness, but actually computing the result can take up to 2 days. Also, in the code there is a class named "Mask" that restricts the calculation to a portion of the image to improve execution time; it is essentially a matrix with zeros around the border, marking the areas where calculations are not needed. I hope I have explained myself clearly and that the code is understandable. If you have any questions, I'd be happy to answer them.

```python
import astropy.units as un
import dask.array as da
import numpy as np
from numba import njit, prange
def gradient(dataset, mask):
if mask is None:
dchi2_1d = __gradient_no_mask(dataset)
else:
dchi2_1d = __gradient_mask(dataset, mask=mask)
_grad_value = gradient_image_reconstruction(dchi2_1d, mask)
return _grad_value
def gradient_image_reconstruction(image, mask):
# Array full with zeros to assign only valid indices with the values of the gradient
dchi2_1d = da.zeros((np.prod(image.data.shape), ), dtype=np.float32)
# Get ravel indices of the mask where its values are True
ravel_mask_idx = da.where(mask.data.data.ravel())[0]
ravel_mask_idx.compute_chunk_sizes()
# In the array, assign only valid indices (ravel_mask_idx) with the gradient
dchi2_1d[(ravel_mask_idx, )] = image
# Reshape gradient (transposed because indices are in F order)
    image_2d = dchi2_1d.reshape(_grad_value.shape).T
# Flip needed from fft
flipped = da.flip(image_2d, axis=[0, 1])
return flipped.rechunk(_grad_value.chunksize)
def __gradient_no_mask(dataset):
delta_x, delta_y = image.cellsize.to(un.rad).value
x_ind_2d, y_ind_2d, *z_ind_2d = da.indices(
image.data.data.shape, dtype=np.int32, chunks=image.data.data.chunksize
)
x_ind = x_ind_2d.ravel()
y_ind = y_ind_2d.ravel()
x_cell = __cell_index_delta(x_ind, delta_x, np.float32)
y_cell = __cell_index_delta(y_ind, delta_y, np.float32)
dchi2_1d = da.zeros_like(x_cell, dtype=np.float32, chunks=x_cell.chunks)
for i, ms in enumerate(dataset.ms_list):
dchi2_1d += __ms_gradient(ms, x_cell, y_cell, delta_x, delta_y)
return dchi2_1d
def __gradient_mask(dataset, mask):
delta_x, delta_y = image.cellsize.to(un.rad).value
x_ind, y_ind, *z_ind = mask.indices
x_cell = __cell_index_delta(x_ind, delta_x, np.float32)
y_cell = __cell_index_delta(y_ind, delta_y, np.float32)
dchi2_1d_masked = da.zeros_like(x_cell, dtype=np.float32, chunks=x_cell.chunks)
for i, ms in enumerate(dataset.ms_list):
dchi2_1d_masked += __ms_gradient(ms, x_cell, y_cell, delta_x, delta_y)
return dchi2_1d_masked
def __ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, *, mask):
bis_obs = ms.visibilities.cal_data # (ncomb, ntime, nchans, ncorrs)
bis_model = ms.visibilities.cal_model
bis_weight = ms.visibilities.cal_weight
bis_r = bis_obs - bis_model # Calculando visibilidad residuo
vis = ms.visibilities.bis_data.data # (ncomb, ntime, nant, nchans, ncorrs)
uvw = ms.visibilities.bis_uvw.data.astype(np.float32) * un.m # (ncomb, ntime, nant, uvw)
pol_id = ms.polarization_id
corr_names = dataset.polarization.corrs_string[pol_id]
corr_idx = [x in _corr for x in corr_names]
bis_weight = bis_weight[:, :, :, corr_idx] # Filter by correlation
bis_r = bis_r[:, :, :, corr_idx] # Filter by correlation
spw_id = ms.spw_id
nchans = dataset.spws.nchans[spw_id]
chans = (dataset.spws.dataset[spw_id].CHAN_FREQ.data.squeeze(axis=0) *
un.Hz).rechunk(bis_obs.chunksize[-2])
uvw_lambdas = _uvw_lambdas(uvw, chans, nchans)
uv_lambdas = uvw_lambdas[:, :, :, :, :2]
uv_lambdas = uv_lambdas.map_blocks(
lambda x: x.value if isinstance(x, un.Quantity) else x, dtype=np.float32
)
phase_dirs_x = (image.data.shape[0] // 2) * delta_x
phase_dirs_y = (image.data.shape[1] // 2) * delta_y
x = x_cell - np.float32(phase_dirs_x)
y = y_cell - np.float32(phase_dirs_y)
chans = chans.map_blocks(lambda x: x.value if isinstance(x, un.Quantity) else x)
beam = __primary_beam(
chans,
image.data.shape,
chunks=mask.data.chunks if mask is not None else image.data.data.chunksize
)
return _array_gradient(
x, y, uv_lambdas, bis_weight, vis, bis_r, beam, bis_model, mask=mask
)
def _uvw_lambdas(uvw, chans, nchans):
chans_broadcast = chans[np.newaxis, :, np.newaxis]
uvw_broadcast = da.repeat(uvw[:, :, :, np.newaxis, :], nchans, axis=3)
uvw_lambdas = array_unit_conversion(
array=uvw_broadcast,
unit=un.lambdas,
equivalencies=lambdas_equivalencies(restfreq=chans_broadcast),
)
return uvw_lambdas
def _array_gradient(x, y, uv, w, vis, vr, pb, bm, *, mask):
data_br = da.blockwise(
_block_gradient,
("ncomb", "chan", "corr", "idx"),
x,
("idx", ),
y,
("idx", ),
uv,
("ncomb", "nutime", "nant", "chan", "corr"),
w,
("ncomb", "nutime", "chan", "corr"),
vis,
("ncomb", "nutime", "nant", "chan", "corr"),
vr,
("ncomb", "nutime", "chan", "corr"),
bm,
("ncomb", "nutime", "chan", "corr"),
adjust_chunks={
"ncomb": 1,
"corr": 1
},
dtype=np.float64,
)
data = data_br.sum(axis=(0, 2))
if mask is not None:
# Broadcast mask to match the shape of the PB, and obtain the masked primary beam values
# via bool indexing
broadcasted_mask = da.broadcast_to(mask.data.data, pb.shape)
# Get the Primary Beam masked values
pb_fitted = pb[broadcasted_mask]
pb_fitted.compute_chunk_sizes()
else:
pb_fitted = pb # If there is no mask, there is no need to filter the primary beam
# Reshape to 2-d where the second dim is the cell idx and the first dim is the channel
# (frequency) dim
pb_2d = pb_fitted.reshape((-1, x.shape[0]))
data = da.einsum('in,in->n', data, pb_2d)
return data
def _block_gradient(x, y, uv, w, vis, vr, bm):
uv = uv[0][0].astype(np.float64)
dchi2_broad = _ms_memory_gradient(x, y, uv, w[0], vis[0][0], vr[0], bm[0])
return dchi2_broad[None, :, None, :]
@njit(nogil=True, cache=True, parallel=True)
def _ms_memory_gradient(x, y, uv, w, vis, vr, bm):
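    # Direct per-pixel, per-visibility evaluation -- effectively a DFT, so
    # the cost scales with n_pixels * n_visibilities.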
nidx = x.shape[0]
out_dtype = vis.real.dtype
ncomb, utime, nant, nchan, ncorr = vis.shape
ms_gradient = np.zeros((nchan, nidx), dtype=out_dtype)
for i in prange(nidx):
for c in range(ncomb):
for t in range(utime):
for a in range(nant):
for f in range(nchan):
u, v = uv[c, t, a, f]
uv_r = u * x[i] + v * y[i]
uv_r *= 2 * np.pi * 1j
a_ij = np.exp(uv_r)
for r in range(ncorr):
vm = vis[c, t, a, f, r].conjugate()
ed = a_ij / vm if vm != 0 else 0
i_sum = vr[c, t, f, r] * bm[c, t, f, r].conjugate() * ed
s_sum = -w[c, t, f, r] * i_sum.real
data = s_sum / ncomb
ms_gradient[f, i] += data
return ms_gradient
def __cell_index_delta(index_array, delta, dtype=np.float32):
"""
Scale an array by a delta, so that each cell/pixel has the increment in radians specified
by delta.
"""
return (index_array * delta).astype(dtype)
def __primary_beam(chans, shape, antenna=np.array([0]), chunks="auto"):
"""
Get the primary beam image for every frequency.
"""
beam = dataset.antenna.primary_beam.beam(
chans, shape, antenna=antenna, imchunks=chunks
)
beam = beam[0].astype(np.float32) # temporal indexing
return beam
input_name = "/home/datasets/ms_name.ms"
ms = xds_from_ms(
input_name,
group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
chunks={'row': 1500}
)
mask = Mask()
res = gradient(ms, mask)
```
Hi @Hackasteroid142! This problem is now beginning to enter the more general territory of optimising dask code. Doing so requires an in-depth understanding of dask graphs (https://docs.dask.org/en/latest/graphs.html). A good starting point is to inspect the graph (https://docs.dask.org/en/stable/graphviz.html). This can show you whether your implementation contains problematic many-to-many mappings, which dask is known to struggle with. Regarding performance, be sure to remember that dask does few/no computations until compute is called. In general, while it is useful to see your code, I would encourage you to post code which can be run. This makes it much easier for us to provide useful feedback. As it stands, I cannot give you more precise advice.
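For illustration, a minimal self-contained sketch of that kind of inspection (the array, the rechunk and the filename are purely illustrative):

```python
import dask.array as da

x = da.ones((4000, 4000), chunks=(1000, 1000))
# Rechunking introduces many-to-many block dependencies -- the pattern
# that tends to produce expensive graphs.
y = x.rechunk((2000, 500)).sum()
# Requires the graphviz package; writes the task graph to disk.
y.visualize(filename="graph.png")
```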
Hello, I've been reviewing what you advised me, but I haven't achieved good results; the runtime is still very high. As you suggested, here is runnable code that includes the chi-squared bispectrum gradient, together with the components needed to calculate the bispectrum and the chunking code you gave me. Any advice or help to improve my code would be greatly appreciated. If there's any issue with the code, I'll be attentive and happy to help.

```python
from daskms import xds_from_ms
from numba import njit, prange
import astropy.units as un
import dask.array as da
import numpy as np
import itertools
def gradient(dataset, imageshape):
delta_x, delta_y = [-2.42406841e-08, 2.42406841e-08]
x_ind_2d, y_ind_2d, *z_ind_2d = da.indices(
imageshape, dtype=np.int32
)
x_ind = x_ind_2d.ravel()
y_ind = y_ind_2d.ravel()
x_cell = x_ind * delta_x
y_cell = y_ind * delta_y
grad = da.zeros_like(x_cell, dtype=np.float32)
for i, ms in enumerate(dataset):
grad += ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, imageshape)
grad_2d = grad.reshape(imageshape).T
image = da.flip(grad_2d, axis=[0, 1])
return image
def ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, imageshape):
bis_obs = ms.CAL_DATA.data # (ncomb, ntime, nchans, ncorrs)
bis_model = ms.CAL_MODEL.data
bis_weight = ms.CAL_WEIGHT.data
bis_r = bis_obs - bis_model
vis = ms.BIS_DATA.data # (ncomb, ntime, nant, nchans, ncorrs)
uvw = ms.BIS_UVW.data.astype(np.float32) * un.m # (ncomb, ntime, nant, uvw)
nchans = ms.DATA.shape[1]
uvw_lambdas = da.repeat(uvw[:, :, :, np.newaxis, :], nchans, axis=3)
uv_lambdas = uvw_lambdas[:, :, :, :, :2]
uv_lambdas = uv_lambdas.map_blocks(
lambda x: x.value if isinstance(x, un.Quantity) else x, dtype=np.float32
)
phase_dirs_x = (imageshape[0] // 2) * delta_x
phase_dirs_y = (imageshape[1] // 2) * delta_y
x = x_cell - np.float32(phase_dirs_x)
y = y_cell - np.float32(phase_dirs_y)
data_br = da.blockwise(
_block_gradient,
("ncomb", "chan", "corr", "idx"),
x,
("idx", ),
y,
("idx", ),
uv_lambdas,
("ncomb", "nutime", "nant", "chan", "corr"),
bis_weight,
("ncomb", "nutime", "chan", "corr"),
vis,
("ncomb", "nutime", "nant", "chan", "corr"),
bis_r,
("ncomb", "nutime", "chan", "corr"),
bis_model,
("ncomb", "nutime", "chan", "corr"),
dtype=np.float64,
)
data = data_br.sum(axis=(0, 2))
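    # Placeholder primary beam of ones; the einsum below then simply sums
    # the per-channel gradients into a single value per pixel.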
pb = da.ones_like(data)
data = da.einsum('in,in->n', data, pb)
return data
def _block_gradient(x, y, uv, w, vis, vr, bm):
uv = uv[0][0].astype(np.float64)
dchi2_broad = _ms_memory_gradient(x, y, uv, w[0], vis[0][0], vr[0], bm[0])
return dchi2_broad[None, :, None, :]
@njit(nogil=True, cache=True, parallel=True)
def _ms_memory_gradient(x, y, uv, w, vis, vr, bm):
nidx = x.shape[0]
out_dtype = vis.real.dtype
ncomb, utime, nant, nchan, ncorr = vis.shape
ms_gradient = np.zeros((nchan, nidx), dtype=out_dtype)
for i in prange(nidx):
for c in range(ncomb):
for t in range(utime):
for a in range(nant):
for f in range(nchan):
u, v = uv[c, t, a, f]
uv_r = u * x[i] + v * y[i]
uv_r *= 2 * np.pi * 1j
a_ij = np.exp(uv_r)
for r in range(ncorr):
vm = vis[c, t, a, f, r].conjugate()
ed = a_ij / vm if vm != 0 else 0
i_sum = vr[c, t, f, r] * bm[c, t, f, r].conjugate() * ed
s_sum = -w[c, t, f, r] * i_sum.real
data = s_sum / ncomb
ms_gradient[f, i] += data
return ms_gradient
def utime_and_row_chunks(time, req_utime=1):
"""Internals of compute_utime_and_row_chunks."""
utime, utime_counts = np.unique(time, return_counts=True)
n_utime = utime.size
req_utime = req_utime or n_utime # Catch zero.
chunk_starts = np.arange(0, n_utime, req_utime)
utime_chunks = np.array(
[
req_utime if i + req_utime < n_utime else n_utime - i
for i in chunk_starts
]
)
row_chunks = np.add.reduceat(utime_counts, chunk_starts)
return np.stack([utime_chunks, row_chunks], axis=0)
def compute_utime_and_row_chunks(indexing_xdsl, req_utime=1):
"""Figure out the chunking in unique time and row. Triggers compute."""
chunking = []
for xds in indexing_xdsl:
chunking.append(
xds.TIME.data.map_blocks(
utime_and_row_chunks,
req_utime,
chunks=((2,), (np.nan,)),
new_axis=0,
dtype=int
)
)
result = da.compute(chunking)[0]
utime_chunks = [tuple(arr[0]) for arr in result]
row_chunks = [tuple(arr[1]) for arr in result]
return utime_chunks, row_chunks
@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
cnt = np.zeros(data_shape, dtype=np.int8)
data_bis = np.ones(data_shape, dtype=data.dtype)
for row in range(n_row):
ut = utime_inv[row]
a1 = ant1[row]
a2 = ant2[row]
for ic, c in enumerate(comb):
if (a1 == c[0]) and (a2 == c[1]):
data_bis[ic, ut, 0] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[1]) and (a2 == c[-1]):
data_bis[ic, ut, 1] = data[row]
cnt[ic, ut, :] += 1
elif (a1 == c[0]) and (a2 == c[-1]):
data_bis[ic, ut, 2] = data[row].conjugate()
cnt[ic, ut, :] += 1
return data_bis, cnt
def get_bispectrum(data, ant1, ant2, comb, time, type):
utime, utime_inv = np.unique(time, return_inverse=True)
n_utime = utime.size
n_comb = len(comb)
if type == 'UVW':
n_row, n_id = data.shape
shape_data = (n_comb, n_utime, 3, n_id)
else:
n_row, n_chan, n_corr = data.shape
shape_data = (n_comb, n_utime, 3, n_chan, n_corr)
bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)
bis[cnt != 3] = 0
return bis
def bispectrum_data(data, antenna1, antenna2, filter_comb, time, type, utime_chunks_list):
utime_size = np.unique(time).size
if type == 'UVW':
bis_shape = ('triangle', 't', 'p', 'id')
data_shape = ('t', 'id')
else:
bis_shape = ('triangle', 't', 'p', 'f', 'c')
data_shape = ('t', 'f', 'c')
bis = da.blockwise(
get_bispectrum,
bis_shape,
data,
data_shape,
antenna1, ('t'),
antenna2, ('t'),
filter_comb,
None,
time.data, ('t'),
type,
None,
adjust_chunks={'t': utime_chunks_list},
new_axes={
'triangle': len(filter_comb),
'p': 3
},
dtype=data.dtype
)
return bis
def bispectrum(dataset, utime_chunks_list):
n_ant = 28
antenna_reference = 0
comb = itertools.combinations(range(n_ant), 3)
filter_comb = np.array([i for i in comb if antenna_reference in i])
for i, xds in enumerate(dataset):
data = xds.DATA.data
model = xds.MODEL.data
weight = xds.WEIGHT.data
time = xds.TIME
antenna1 = xds.ANTENNA1.data
antenna2 = xds.ANTENNA2.data
uvw = xds.UVW.data
flags = xds.FLAG.data
data = data * ~flags
model = model * ~flags
weight = weight * ~flags
bm = bispectrum_data(model, antenna1, antenna2, filter_comb, time, 'DATA', utime_chunks_list[i])
uvw_bis = bispectrum_data(uvw, antenna1, antenna2, filter_comb, time, 'UVW', utime_chunks_list[i])
bw = bispectrum_data(weight, antenna1, antenna2, filter_comb, time, 'DATA', utime_chunks_list[i])
bo = bispectrum_data(data, antenna1, antenna2, filter_comb, time, 'DATA', utime_chunks_list[i])
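        # The bispectrum is the product of the three visibilities around
        # each triangle; axis 2 holds the three legs.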
cal_data = da.prod(bo, axis=2)
cal_model = da.prod(bm, axis=2)
weighted_vis = (da.absolute(bm)**2) * bw
aux = da.divide(
1, weighted_vis, out=da.zeros_like(weighted_vis), where=weighted_vis != 0
)
bis_sigma_squared = (da.absolute(cal_data)**2) * da.sum(aux, axis=2)
bis_weight = da.divide(
1,
bis_sigma_squared,
out=da.zeros_like(bis_sigma_squared),
where=bis_sigma_squared != 0
)
uvw_bis[:, :, 2] *= -1
dataset[i] = xds.assign({
"BIS_MODEL": (('triangle', 't', 'p', 'f', 'c'), bm),
"BIS_DATA": (('triangle', 't', 'p', 'f', 'c'), bo),
"BIS_WEIGHT": (('triangle', 't', 'p', 'f', 'c'), bw),
"BIS_UVW" : (('triangle', 't', 'p', 'id'), uvw_bis),
"CAL_DATA": (('triangle', 't', 'f', 'c'), cal_data),
"CAL_MODEL": (('triangle', 't', 'f', 'c'), cal_model),
"CAL_WEIGHT": (('triangle', 't', 'f', 'c'), bis_weight),
})
return dataset
if __name__ == "__main__":
input_name = "/home/datasets/ms_name.ms"
dataset = xds_from_ms(
input_name,
group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
chunks={'row': -1}
)
utime_chunks_list, row_chunks_list = compute_utime_and_row_chunks(
dataset, req_utime=30
)
xdsl = xds_from_ms(
input_name,
group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
chunks=[{'row': rcs} for rcs in row_chunks_list]
)
xdsl = [
xds.assign(
{
"MODEL": (
xds.DATA.dims,
da.ones(
xds.DATA.data.shape,
dtype=np.complex64,
chunks=xds.DATA.data.chunksize
) / 2
),
"WEIGHT": (
xds.DATA.dims,
da.tile(xds.WEIGHT.data, xds.DATA.shape[1]).reshape(len(xds.WEIGHT.data),xds.DATA.shape[1], xds.DATA.shape[2])
)
}
) for xds in xdsl
]
dataset_bis = bispectrum(xdsl, utime_chunks_list)
    grad = gradient(dataset_bis, (512, 512))
```
Hi @Hackasteroid142. I had an offline chat with @JSKenyon regarding the above code. Running it seems to take all available cores, so without fully understanding the optimality of the code for the individual chunks, the dask part seems to be doing its job. Additionally, it seems to be some variant of the DFT, which is always going to be much slower than an FFT. For example, https://github.com/caracal-pipeline/crystalball does a DFT predict of a WSClean sky model using dask. This can take up to a week for large Measurement Sets and sources. DFTs tend to be embarrassingly parallel, so your algorithm will probably benefit from distributing the problem on a compute cluster. Unfortunately we don't have the resources to help set this up and debug -- I suggest you consult some of the dask resources available on the internet.
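For completeness, a minimal sketch of pointing the script above at a distributed scheduler (the scheduler address is a placeholder; the API is standard dask.distributed):

```python
from dask.distributed import Client

# Connect to an existing scheduler, or use Client() for a local test cluster.
client = Client("tcp://scheduler-address:8786")
# Any da.compute(...) issued after this point runs on the cluster's workers.
```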
Description
Hello, some time ago I opened an issue where I asked for help with the implementation of the bispectrum. Taking into account the responses I received, I started developing a gradient and an optimizer for obtaining images using this method. However, I've encountered some issues in terms of memory and time. I've tried to address this by changing the default chunk size when reading the dataset, but I've run into problems with shape values.

Currently, I am reading the dataset using the xds_from_ms function from dask-ms. However, if I change the chunk size to another value, such as 1500, I get the following error. Based on the tests I've conducted, I believe the issue may be related to the adjust_chunks parameter in the blockwise function that I'm using for the bispectrum. I've tried editing this parameter to solve the problem, but none of my attempts have been successful. Is there a way to resolve this? Additionally, do you think working with the dataset in this manner and performing the bispectrum calculation like this could lead to performance problems?

Here's an excerpt from the code I'm using. If anything is unclear, I'd be happy to provide additional information, and any help is greatly appreciated. Thank you in advance!