Skip to content

Commit

Permalink
Merge pull request #930 from HERA-Team/jsdillon-patch-1
Browse files Browse the repository at this point in the history
Change setup to setup_method in nucal test
  • Loading branch information
tyler-a-cox authored Jan 29, 2024
2 parents d45df27 + ff5d892 commit 45ed23c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 27 deletions.
16 changes: 9 additions & 7 deletions hera_cal/frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,7 +1493,8 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
avg_lsts = avg_lsts[:ntimes]

# write data
output_data_name = output_data.replace('.uvh5', f'.interleave_{inum}.uvh5')
output_data_name = os.path.join(os.path.dirname(output_data),
os.path.basename(output_data).replace('.uvh5', f'.interleave_{inum}.uvh5'))
fr.write_data(data=avg_data, filename=output_data_name, flags=avg_flags, nsamples=avg_nsamples,
times=avg_times, lsts=avg_lsts, filetype=filetype, overwrite=clobber)
if flag_output is not None:
Expand All @@ -1502,7 +1503,8 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
uv_avg.use_future_array_shapes()
uvf = UVFlag(uv_avg, mode='flag', copy_flags=True)
uvf.to_waterfall(keep_pol=False, method='and')
uvf.write(flag_output.replace('h5', f'.interleave_{inum}.h5'), clobber=clobber)
uvf.write(os.path.join(os.path.dirname(flag_output),
os.path.basename(flag_output).replace('.h5', f'.interleave_{inum}.h5')), clobber=clobber)
else:
fr.timeavg_data(fr.data, fr.times, fr.lsts, t_avg, flags=fr.flags, nsamples=fr.nsamples,
wgt_by_nsample=wgt_by_nsample, wgt_by_favg_nsample=wgt_by_favg_nsample, rephase=rephase)
Expand Down Expand Up @@ -1840,13 +1842,13 @@ def load_tophat_frfilter_and_write(
for key in keys:
ai, aj = key[:2]
if (ai, aj) in filter_antpairs:
frate_centers[key] = filter_info["filter_centers"][(ai,aj)]
frate_half_widths[key] = filter_info["filter_half_widths"][(ai,aj)]
frate_centers[key] = filter_info["filter_centers"][(ai, aj)]
frate_half_widths[key] = filter_info["filter_half_widths"][(ai, aj)]
else:
# We've already enforced that all the data baselines
# are in the filter file, so we should be safe here.
frate_centers[key] = -filter_info["filter_centers"][(aj,ai)]
frate_half_widths[key] = filter_info["filter_half_widths"][(aj,ai)]
frate_centers[key] = -filter_info["filter_centers"][(aj, ai)]
frate_half_widths[key] = filter_info["filter_half_widths"][(aj, ai)]
else:
# Otherwise, we need to compute the filter parameters.
if case not in ("sky", "max_frate_coeffs", "uvbeam"):
Expand All @@ -1867,7 +1869,7 @@ def load_tophat_frfilter_and_write(
fr_freq_skip=fr_freq_skip,
verbose=verbose, nfr=nfr,
)
# Lists of names of datacontainers that will hold each interleaved
# Lists of names of datacontainers that will hold each interleaved
# data set until they are recombined.
filtered_data_names = [
f'clean_data_interleave_{inum}' for inum in range(ninterleave)
Expand Down
30 changes: 11 additions & 19 deletions hera_cal/tests/test_frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def test_timeavg_data(self):

# exceptions
pytest.raises(AssertionError, self.F.timeavg_data, self.F.data, self.F.times, self.F.lsts, 1.0)


def test_filter_data(self):
# construct high-pass filter
Expand Down Expand Up @@ -206,7 +205,7 @@ def test_write_data(self):

pytest.raises(AssertionError, self.F.write_data, self.F.avg_data, "./out.uv", times=self.F.avg_times)
pytest.raises(ValueError, self.F.write_data, self.F.data, "hi", filetype='foo')

@pytest.mark.parametrize("equalize_times", [True, False])
def test_time_avg_data_and_write(self, tmpdir, equalize_times):
# time-averaged data written too file will be compared to this.
Expand All @@ -224,12 +223,11 @@ def test_time_avg_data_and_write(self, tmpdir, equalize_times):
assert np.allclose(data_out.flags[k], self.F.avg_flags[k])
assert np.allclose(data_out.nsamples[k], self.F.avg_nsamples[k])


@pytest.mark.parametrize(
"ninterleave, equalize_times", [(2, True), (2, False), (3, True), (3, False),
(4, True), (4, False), (5, True),
(5, False), (6, True), (6, False)])
def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_times):
def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_times):
tmp_path = tmpdir.strpath
input_name = os.path.join(tmp_path, 'test_input.uvh5')
uvd = UVData()
Expand All @@ -244,7 +242,7 @@ def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_
# check that the correct number of files exist.
interleaved_data = {}
for inum in range(ninterleave):
iname = output_name.replace('.uvh5', f'.interleave_{inum}.uvh5')
iname = os.path.join(os.path.dirname(output_name), os.path.basename(output_name).replace('.uvh5', f'.interleave_{inum}.uvh5'))
assert os.path.exists(iname)
hd = io.HERAData(iname)
hd.read()
Expand All @@ -256,11 +254,10 @@ def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_
if inum > 0:
for tn in range(interleaved_data[inum].Ntimes):
if not equalize_times:
assert interleaved_data[inum].times[tn] > interleaved_data[inum-1].times[tn]
assert interleaved_data[inum].times[tn] > interleaved_data[inum - 1].times[tn]
else:
assert interleaved_data[inum].times[tn] == interleaved_data[inum-1].times[tn]
assert interleaved_data[inum].lsts[tn] == interleaved_data[inum-1].lsts[tn]

assert interleaved_data[inum].times[tn] == interleaved_data[inum - 1].times[tn]
assert interleaved_data[inum].lsts[tn] == interleaved_data[inum - 1].lsts[tn]

def test_time_avg_data_and_write_baseline_list(self, tmpdir):
# compare time averaging over baseline list versus time averaging
Expand Down Expand Up @@ -617,7 +614,6 @@ def test_load_tophat_frfilter_and_write_all_baselines(self, tmpdir):
np.testing.assert_almost_equal(d[(53, 54, 'ee')], frfil.clean_resid[(53, 54, 'ee')], decimal=5)
np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')])


def test_load_tophat_frfilter_and_write_cal(self, tmpdir):
tmp_path = tmpdir.strpath
uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
Expand Down Expand Up @@ -696,7 +692,7 @@ def test_load_tophat_frfilter_and_write_skip_autos(self, tmpdir):
CLEAN_outfilename = os.path.join(tmp_path, 'temp_clean.h5')
filled_outfilename = os.path.join(tmp_path, 'temp_filled.h5')
# test skip_autos

frf.load_tophat_frfilter_and_write(uvh5, calfile_list=None, tol=1e-4, res_outfilename=outfilename,
filled_outfilename=filled_outfilename, CLEAN_outfilename=CLEAN_outfilename,
Nbls_per_load=2, clobber=True, skip_autos=True, case='sky')
Expand All @@ -720,15 +716,14 @@ def test_load_tophat_frfilter_and_write_skip_autos(self, tmpdir):
assert not np.allclose(do[bl], d[bl])
assert np.allclose(no[bl], n[bl])


def test_load_tophat_frfilter_and_write_broadcast_flags(self, tmpdir):
# test flag broadcasting with frf.
tmp_path = tmpdir.strpath
uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
outfilename = os.path.join(tmp_path, 'temp.h5')
CLEAN_outfilename = os.path.join(tmp_path, 'temp_clean.h5')
filled_outfilename = os.path.join(tmp_path, 'temp_filled.h5')

# prepare an input file for broadcasting flags
input_file = os.path.join(tmp_path, 'temp_special_flags.h5')
shutil.copy(uvh5, input_file)
Expand Down Expand Up @@ -785,7 +780,7 @@ def test_load_tophat_frfilter_and_write_partial_io(self, tmpdir):
time_thresh = 2. / hd.Ntimes
# test delay filtering and writing with factorized flags and partial i/o
time_thresh = 2. / hd.Ntimes

frf.load_tophat_frfilter_and_write(input_file, res_outfilename=outfilename, tol=1e-4, case='sky',
factorize_flags=True, time_thresh=time_thresh, clobber=True)
hd = io.HERAData(outfilename)
Expand All @@ -806,7 +801,6 @@ def test_load_tophat_frfilter_and_write_partial_io(self, tmpdir):
assert np.all(f[bl][:, -1])
assert not np.all(np.isclose(d[bl], 0.))


def test_load_tophat_frfilter_and_write_yaml(self, tmpdir):
# test apriori flags and flag_yaml
tmp_path = tmpdir.strpath
Expand Down Expand Up @@ -953,8 +947,8 @@ def test_load_tophat_frfilter_and_write_with_filter_yaml(self, tmpdir, flip):
filter_centers={str(k): v for k, v in frate_centers.items()},
filter_half_widths={str(k): v for k, v in frate_half_widths.items()},
)
frate_centers = {k+("ee",): v for k, v in frate_centers.items()}
frate_half_widths = {k+("ee",): v for k, v in frate_half_widths.items()}
frate_centers = {k + ("ee",): v for k, v in frate_centers.items()}
frate_half_widths = {k + ("ee",): v for k, v in frate_half_widths.items()}

with open(tmpdir / "filter_info.yaml", "w") as f:
yaml.dump(filter_info, f)
Expand Down Expand Up @@ -993,7 +987,6 @@ def test_load_tophat_frfilter_and_write_with_filter_yaml(self, tmpdir, flip):
# Check that the flags match.
np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')])


def test_load_tophat_frfilter_and_write_with_bad_filter_yaml(self, tmpdir):
tmp_path = tmpdir.strpath
uvh5 = os.path.join(
Expand All @@ -1013,7 +1006,6 @@ def test_load_tophat_frfilter_and_write_with_bad_filter_yaml(self, tmpdir):
case="param_file",
)


def test_tophat_clean_argparser(self):
sys.argv = [sys.argv[0], 'a', '--clobber', '--window', 'blackmanharris', '--max_frate_coeffs', '0.024', '-0.229']
parser = frf.tophat_frfilter_argparser()
Expand Down
2 changes: 1 addition & 1 deletion hera_cal/tests/test_nucal.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_get_unique_orientations():
assert len(group) >= 5

class TestRadialRedundancy:
def setup(self):
def setup_method(self):
self.antpos = hex_array(4, outriggers=0, split_core=False)
self.radial_reds = nucal.RadialRedundancy(self.antpos)

Expand Down

0 comments on commit 45ed23c

Please sign in to comment.