Skip to content

Commit

Permalink
Merge branch 'main' into nucal_modeling
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-a-cox authored Nov 9, 2023
2 parents 7bb7408 + 0f1b175 commit b33cf41
Show file tree
Hide file tree
Showing 10 changed files with 518 additions and 198 deletions.
6 changes: 6 additions & 0 deletions hera_cal/data/example_filter_params.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
filter_centers:
(0, 1): 0.1234
(0, 2): 0.173
filter_half_widths:
(0, 1): 0.05
(0, 2): 0.08
434 changes: 291 additions & 143 deletions hera_cal/frf.py

Large diffs are not rendered by default.

25 changes: 12 additions & 13 deletions hera_cal/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import copy
import warnings
from packaging import version
import inspect
from functools import reduce
from collections.abc import Iterable
from pyuvdata import UVCal, UVData
Expand Down Expand Up @@ -1032,19 +1032,18 @@ def partial_write(self, output_path, data=None, flags=None, nsamples=None,
# else: # make a copy of this object and then update the relevant arrays using DataContainers
# this = copy.deepcopy(self)

if version.parse(pyuvdata.__version__) < version.parse("3.0"):
hd_writer.write_uvh5_part(output_path, d, f, n,
run_check_acceptability=(output_path in self._writers),
**self.last_read_kwargs)
write_kwargs = {
"data_array": d,
"nsample_array": n,
"run_check_acceptability": (output_path in self._writers),
**self.last_read_kwargs,
}
# before pyuvdata 3.0, the "flag_array" parameter was called "flags_array"
if "flag_array" in inspect.signature(UVData.write_uvh5_part).parameters:
write_kwargs["flag_array"] = f
else:
hd_writer.write_uvh5_part(
output_path,
data_array=d,
flag_array=f,
nsample_array=n,
run_check_acceptability=(output_path in self._writers),
**self.last_read_kwargs
)
write_kwargs["flags_array"] = f
hd_writer.write_uvh5_part(output_path, **write_kwargs)

def iterate_over_bls(self, Nbls=1, bls=None, chunk_by_redundant_group=False, reds=None,
bl_error_tol=1.0, include_autos=True, frequencies=None):
Expand Down
31 changes: 28 additions & 3 deletions hera_cal/lstbin_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ def lst_bin_files_for_baselines(
)
if inpfile is not None:
# This returns a DataContainer (unless something went wrong) since it should
# always be a 'baseline' type of UVFlag.s
# always be a 'baseline' type of UVFlag.
inpainted = io.load_flags(inpfile)
if not isinstance(inpainted, DataContainer):
raise ValueError(f"Expected {inpfile} to be a DataContainer")
Expand Down Expand Up @@ -881,14 +881,29 @@ def lst_bin_files_for_baselines(
for i, bl in enumerate(antpairs):
if redundantly_averaged:
bl = keyed.get_ubl_key(bl)

for j, pol in enumerate(pols):
blpol = bl + (pol,)

if blpol in _data: # DataContainer takes care of conjugates.
data[slc, i, :, j] = _data[blpol]
flags[slc, i, :, j] = _flags[blpol]
nsamples[slc, i, :, j] = _nsamples[blpol]

if inpainted is not None:
# Get the representative baseline key from this bl group that
# exists in the where_inpainted data.
if redundantly_averaged:
for inpbl in reds[bl]:
if inpbl + (pol,) in inpainted:
blpol = inpbl + (pol,)
break
else:
raise ValueError(
f"Could not find any baseline from group {bl} in "
"inpainted file"
)

where_inpainted[slc, i, :, j] = inpainted[blpol]
else:
# This baseline+pol doesn't exist in this file. That's
Expand Down Expand Up @@ -1309,6 +1324,7 @@ def lst_bin_files_single_outfile(
where_inpainted_files = _get_where_inpainted_files(
data_files, where_inpainted_file_rules
)

output_flagged, output_inpainted = _configure_inpainted_mode(
output_flagged, output_inpainted, where_inpainted_files
)
Expand All @@ -1317,6 +1333,8 @@ def lst_bin_files_single_outfile(
# they have no associated calibration)
data_files = [df for df in data_files if df]
input_cals = [cf for cf in input_cals if cf]
if where_inpainted_files is not None:
where_inpainted_files = [wif for wif in where_inpainted_files if wif]

logger.info("Got the following numbers of data files per night:")
for dflist in data_files:
Expand Down Expand Up @@ -1419,6 +1437,7 @@ def lst_bin_files_single_outfile(
input_cals,
where_inpainted_files,
)

# If we have no times at all for this file, just return
if len(all_lsts) == 0:
return {}
Expand Down Expand Up @@ -1579,6 +1598,7 @@ def lst_bin_files_single_outfile(
flags=rdc["flags"],
nsamples=rdc["nsamples"],
)

write_baseline_slc_to_file(
fl=out_files[("STD", inpainted)],
slc=slc,
Expand All @@ -1596,7 +1616,7 @@ def lst_bin_files_single_outfile(
nsamples=rdc["nsamples"],
)
write_baseline_slc_to_file(
fl=out_files[("STD", inpainted)],
fl=out_files[("MAD", inpainted)],
slc=slc,
data=rdc["mad"],
flags=rdc["flags"],
Expand Down Expand Up @@ -1858,14 +1878,19 @@ def create_lstbin_output_file(
if lst < lst_branch_cut:
lst += 2 * np.pi

fname = outdir / fname_format.format(
fname = fname_format.format(
kind=kind,
lst=lst,
pol="".join(pols),
inpaint_mode="inpaint"
if inpaint_mode
else ("flagged" if inpaint_mode is False else ""),
)
# There's a weird gotcha with pathlib where if you do path / "/file.name"
# You get just "/file.name" which is in root.
if fname.startswith('/'):
fname = fname[1:]
fname = outdir / fname

logger.info(f"Initializing {fname}")

Expand Down
10 changes: 4 additions & 6 deletions hera_cal/tests/mock_uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
with open(f"{DATA_PATH}/hera_antpos.yaml", "r") as fl:
HERA_ANTPOS = yaml.safe_load(fl)

start: 46920776.3671875
end: 234298706.0546875
delta: 122070.3125
PHASEII_FREQS = np.arange(
46920776.3671875, 234298706.0546875 + 10.0, 122070.3125
)


def create_mock_hera_obs(
Expand All @@ -35,9 +35,7 @@ def create_mock_hera_obs(
lst_start=0.1,
jd_start: float | None = None,
ntimes: int = 2,
freqs: np.ndarray = np.arange(
46920776.3671875, 234298706.0546875 + 10.0, 122070.3125
),
freqs: np.ndarray = PHASEII_FREQS,
pols: list[str] = ["xx", "yy", "xy", "yx"],
ants: list[int] | None = None,
antpairs: list[tuple[int, int]] | None = None,
Expand Down
84 changes: 84 additions & 0 deletions hera_cal/tests/test_frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pyuvdata import UVData
from pyuvdata import utils as uvutils
import unittest
import yaml
from scipy import stats
from scipy import constants
from pyuvdata import UVFlag, UVBeam
Expand Down Expand Up @@ -930,6 +931,89 @@ def test_load_dayenu_filter_and_write(self, tmpdir):
os.remove(outfilename)
shutil.rmtree(cdir)

@pytest.mark.parametrize("flip", [True, False])
def test_load_tophat_frfilter_and_write_with_filter_yaml(self, tmpdir, flip):
tmp_path = tmpdir.strpath
uvh5 = os.path.join(
DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5"
)
frate_centers = {(53, 54): 0.1}
frate_half_widths = {(53, 54): 0.05}
if flip:
filter_info = dict(
filter_centers={
str(k[::-1]): -v for k, v in frate_centers.items()
},
filter_half_widths={
str(k[::-1]): v for k, v in frate_half_widths.items()
},
)
else:
filter_info = dict(
filter_centers={str(k): v for k, v in frate_centers.items()},
filter_half_widths={str(k): v for k, v in frate_half_widths.items()},
)
frate_centers = {k+("ee",): v for k, v in frate_centers.items()}
frate_half_widths = {k+("ee",): v for k, v in frate_half_widths.items()}

with open(tmpdir / "filter_info.yaml", "w") as f:
yaml.dump(filter_info, f)

outfilename = os.path.join(tmp_path, 'temp.h5')
frf.load_tophat_frfilter_and_write(
uvh5,
res_outfilename=outfilename,
tol=1e-4,
clobber=True,
Nbls_per_load=1,
param_file=str(tmpdir / "filter_info.yaml"),
case="param_file",
skip_autos=True,
)
hd = io.HERAData(outfilename)
d, f, n = hd.read(bls=[(53, 54, 'ee')])
for bl in d:
assert not np.allclose(d[bl], 0)

frfil = frf.FRFilter(uvh5, filetype='uvh5')
frfil.read(bls=[(53, 54, 'ee')])
frfil.tophat_frfilter(
keys=[(53, 54, 'ee')],
tol=1e-4,
verbose=True,
frate_centers=frate_centers,
frate_half_widths=frate_half_widths,
)

# Check that the filtered data both ways matches.
np.testing.assert_almost_equal(
d[(53, 54, 'ee')], frfil.clean_resid[(53, 54, 'ee')], decimal=5
)

# Check that the flags match.
np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')])


def test_load_tophat_frfilter_and_write_with_bad_filter_yaml(self, tmpdir):
tmp_path = tmpdir.strpath
uvh5 = os.path.join(
DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5"
)
param_file = os.path.join(DATA_PATH, "example_filter_params.yaml")

outfilename = os.path.join(tmp_path, 'temp.h5')
with pytest.raises(ValueError, match="(53, 54)"):
frf.load_tophat_frfilter_and_write(
uvh5,
res_outfilename=outfilename,
tol=1e-4,
clobber=True,
Nbls_per_load=1,
param_file=param_file,
case="param_file",
)


def test_tophat_clean_argparser(self):
sys.argv = [sys.argv[0], 'a', '--clobber', '--window', 'blackmanharris', '--max_frate_coeffs', '0.024', '-0.229']
parser = frf.tophat_frfilter_argparser()
Expand Down
83 changes: 81 additions & 2 deletions hera_cal/tests/test_lstbin_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,6 @@ def test_flag_below_min_N(self):
assert np.all(norm_n[0] == 2)
assert not np.any(np.isinf(std_n[0]))

print(np.sum(flg_n[1:]), flg_n[1:].size)
assert not np.any(flg_n[1:])
assert np.all(norm_n[1:] == 7)
assert not np.any(np.isinf(std_n[1:]))
Expand Down Expand Up @@ -859,6 +858,83 @@ def test_simple_redundant_averaged_file(self, uvd_redavg, uvd_redavg_file):
uvd_redavg.Npols,
)

def test_redavg_with_where_inpainted(self, tmp_path):
uvds = mockuvd.make_dataset(
ndays=2,
nfiles=3,
ntimes=2,
ants=np.arange(7),
creator=mockuvd.create_uvd_identifiable,
freqs=mockuvd.PHASEII_FREQS[:25],
pols=['xx', 'xy'],
redundantly_averaged=True,
)

uvd_files = mockuvd.write_files_in_hera_format(
uvds, tmp_path, add_where_inpainted_files=True
)

ap = uvds[0][0].get_antpairs()
reds = RedundantGroups.from_antpos(
dict(zip(uvds[0][0].antenna_numbers, uvds[0][0].antenna_positions)),
)
lstbins, d0, f0, n0, inpflg, times0 = lstbin_simple.lst_bin_files_for_baselines(
data_files=sum(uvd_files, []), # flatten the list-of-lists
lst_bin_edges=[0, 1.9 * np.pi],
redundantly_averaged=True,
rephase=False,
antpairs=ap,
reds=reds,
where_inpainted_files=[str(Path(f).with_suffix(".where_inpainted.h5")) for f in sum(uvd_files, [])],
)
assert len(lstbins) == 1

# Also test that if a where_inpainted file has missing baselines, an error is
# raised.
# This is kind of a dodgy way to test it: copy the original data files,
# write a whole new dataset in the same place but with fewer baselines, then
# copy the data files (but not the where_inpainted files) back, so they mismatch.
for flist in uvd_files:
for fl in flist:
fl = Path(fl)
fl.rename(fl.parent / f"{fl.with_suffix('.bk')}")

winp = fl.with_suffix(".where_inpainted.h5")
winp.unlink()

uvds = mockuvd.make_dataset(
ndays=2,
nfiles=3,
ntimes=2,
ants=np.arange(5), # less than the original
creator=mockuvd.create_uvd_identifiable,
freqs=mockuvd.PHASEII_FREQS[:25],
pols=['xx', 'xy'],
redundantly_averaged=True,
)

uvd_files = mockuvd.write_files_in_hera_format(
uvds, tmp_path, add_where_inpainted_files=True
)

# Move back the originals.
for flist in uvd_files:
for fl in flist:
fl = Path(fl)
fl.unlink()
(fl.parent / f"{fl.with_suffix('.bk')}").rename(fl)

with pytest.raises(ValueError, match="Could not find any baseline from group"):
lstbin_simple.lst_bin_files_for_baselines(
data_files=sum(uvd_files, []), # flatten the list-of-lists
lst_bin_edges=[0, 1.9 * np.pi],
redundantly_averaged=True,
rephase=False,
antpairs=ap,
reds=reds,
where_inpainted_files=[str(Path(f).with_suffix(".where_inpainted.h5")) for f in sum(uvd_files, [])],
)


def test_make_lst_grid():
lst_grid = lstbin_simple.make_lst_grid(0.01, begin_lst=None)
Expand Down Expand Up @@ -1463,9 +1539,12 @@ def test_inpaint_mode_no_flags(self, tmp_path_factory):
ntimes_per_file=2,
clobber=True,
)

# Additionally try fname format with leading / which should be removed
# automatically in the writing.
out_files = lstbin_simple.lst_bin_files(
config_file=cfl,
fname_format="zen.{kind}.{lst:7.5f}{inpaint_mode}.uvh5",
fname_format="/zen.{kind}.{lst:7.5f}.{inpaint_mode}.uvh5",
rephase=False,
sigma_clip_thresh=None,
sigma_clip_min_N=2,
Expand Down
Loading

0 comments on commit b33cf41

Please sign in to comment.