Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change setup to setup_method in nucal test #930

Merged
merged 3 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions hera_cal/frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,7 +1493,8 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
avg_lsts = avg_lsts[:ntimes]

# write data
output_data_name = output_data.replace('.uvh5', f'.interleave_{inum}.uvh5')
output_data_name = os.path.join(os.path.dirname(output_data),
os.path.basename(output_data).replace('.uvh5', f'.interleave_{inum}.uvh5'))
fr.write_data(data=avg_data, filename=output_data_name, flags=avg_flags, nsamples=avg_nsamples,
times=avg_times, lsts=avg_lsts, filetype=filetype, overwrite=clobber)
if flag_output is not None:
Expand All @@ -1502,7 +1503,8 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
uv_avg.use_future_array_shapes()
uvf = UVFlag(uv_avg, mode='flag', copy_flags=True)
uvf.to_waterfall(keep_pol=False, method='and')
uvf.write(flag_output.replace('h5', f'.interleave_{inum}.h5'), clobber=clobber)
uvf.write(os.path.join(os.path.dirname(flag_output),
os.path.basename(flag_output).replace('.h5', f'.interleave_{inum}.h5')), clobber=clobber)
else:
fr.timeavg_data(fr.data, fr.times, fr.lsts, t_avg, flags=fr.flags, nsamples=fr.nsamples,
wgt_by_nsample=wgt_by_nsample, wgt_by_favg_nsample=wgt_by_favg_nsample, rephase=rephase)
Expand Down Expand Up @@ -1840,13 +1842,13 @@ def load_tophat_frfilter_and_write(
for key in keys:
ai, aj = key[:2]
if (ai, aj) in filter_antpairs:
frate_centers[key] = filter_info["filter_centers"][(ai,aj)]
frate_half_widths[key] = filter_info["filter_half_widths"][(ai,aj)]
frate_centers[key] = filter_info["filter_centers"][(ai, aj)]
frate_half_widths[key] = filter_info["filter_half_widths"][(ai, aj)]
else:
# We've already enforced that all the data baselines
# are in the filter file, so we should be safe here.
frate_centers[key] = -filter_info["filter_centers"][(aj,ai)]
frate_half_widths[key] = filter_info["filter_half_widths"][(aj,ai)]
frate_centers[key] = -filter_info["filter_centers"][(aj, ai)]
frate_half_widths[key] = filter_info["filter_half_widths"][(aj, ai)]
else:
# Otherwise, we need to compute the filter parameters.
if case not in ("sky", "max_frate_coeffs", "uvbeam"):
Expand All @@ -1867,7 +1869,7 @@ def load_tophat_frfilter_and_write(
fr_freq_skip=fr_freq_skip,
verbose=verbose, nfr=nfr,
)
# Lists of names of datacontainers that will hold each interleaved
# Lists of names of datacontainers that will hold each interleaved
# data set until they are recombined.
filtered_data_names = [
f'clean_data_interleave_{inum}' for inum in range(ninterleave)
Expand Down
30 changes: 11 additions & 19 deletions hera_cal/tests/test_frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def test_timeavg_data(self):

# exceptions
pytest.raises(AssertionError, self.F.timeavg_data, self.F.data, self.F.times, self.F.lsts, 1.0)


def test_filter_data(self):
# construct high-pass filter
Expand Down Expand Up @@ -206,7 +205,7 @@ def test_write_data(self):

pytest.raises(AssertionError, self.F.write_data, self.F.avg_data, "./out.uv", times=self.F.avg_times)
pytest.raises(ValueError, self.F.write_data, self.F.data, "hi", filetype='foo')

@pytest.mark.parametrize("equalize_times", [True, False])
def test_time_avg_data_and_write(self, tmpdir, equalize_times):
# time-averaged data written too file will be compared to this.
Expand All @@ -224,12 +223,11 @@ def test_time_avg_data_and_write(self, tmpdir, equalize_times):
assert np.allclose(data_out.flags[k], self.F.avg_flags[k])
assert np.allclose(data_out.nsamples[k], self.F.avg_nsamples[k])


@pytest.mark.parametrize(
"ninterleave, equalize_times", [(2, True), (2, False), (3, True), (3, False),
(4, True), (4, False), (5, True),
(5, False), (6, True), (6, False)])
def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_times):
def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_times):
tmp_path = tmpdir.strpath
input_name = os.path.join(tmp_path, 'test_input.uvh5')
uvd = UVData()
Expand All @@ -244,7 +242,7 @@ def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_
# check that the correct number of files exist.
interleaved_data = {}
for inum in range(ninterleave):
iname = output_name.replace('.uvh5', f'.interleave_{inum}.uvh5')
iname = os.path.join(os.path.dirname(output_name), os.path.basename(output_name).replace('.uvh5', f'.interleave_{inum}.uvh5'))
assert os.path.exists(iname)
hd = io.HERAData(iname)
hd.read()
Expand All @@ -256,11 +254,10 @@ def test_time_avg_data_and_write_interleave(self, tmpdir, ninterleave, equalize_
if inum > 0:
for tn in range(interleaved_data[inum].Ntimes):
if not equalize_times:
assert interleaved_data[inum].times[tn] > interleaved_data[inum-1].times[tn]
assert interleaved_data[inum].times[tn] > interleaved_data[inum - 1].times[tn]
else:
assert interleaved_data[inum].times[tn] == interleaved_data[inum-1].times[tn]
assert interleaved_data[inum].lsts[tn] == interleaved_data[inum-1].lsts[tn]

assert interleaved_data[inum].times[tn] == interleaved_data[inum - 1].times[tn]
assert interleaved_data[inum].lsts[tn] == interleaved_data[inum - 1].lsts[tn]

def test_time_avg_data_and_write_baseline_list(self, tmpdir):
# compare time averaging over baseline list versus time averaging
Expand Down Expand Up @@ -617,7 +614,6 @@ def test_load_tophat_frfilter_and_write_all_baselines(self, tmpdir):
np.testing.assert_almost_equal(d[(53, 54, 'ee')], frfil.clean_resid[(53, 54, 'ee')], decimal=5)
np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')])


def test_load_tophat_frfilter_and_write_cal(self, tmpdir):
tmp_path = tmpdir.strpath
uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
Expand Down Expand Up @@ -696,7 +692,7 @@ def test_load_tophat_frfilter_and_write_skip_autos(self, tmpdir):
CLEAN_outfilename = os.path.join(tmp_path, 'temp_clean.h5')
filled_outfilename = os.path.join(tmp_path, 'temp_filled.h5')
# test skip_autos

frf.load_tophat_frfilter_and_write(uvh5, calfile_list=None, tol=1e-4, res_outfilename=outfilename,
filled_outfilename=filled_outfilename, CLEAN_outfilename=CLEAN_outfilename,
Nbls_per_load=2, clobber=True, skip_autos=True, case='sky')
Expand All @@ -720,15 +716,14 @@ def test_load_tophat_frfilter_and_write_skip_autos(self, tmpdir):
assert not np.allclose(do[bl], d[bl])
assert np.allclose(no[bl], n[bl])


def test_load_tophat_frfilter_and_write_broadcast_flags(self, tmpdir):
# test flag broadcasting with frf.
tmp_path = tmpdir.strpath
uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
outfilename = os.path.join(tmp_path, 'temp.h5')
CLEAN_outfilename = os.path.join(tmp_path, 'temp_clean.h5')
filled_outfilename = os.path.join(tmp_path, 'temp_filled.h5')

# prepare an input file for broadcasting flags
input_file = os.path.join(tmp_path, 'temp_special_flags.h5')
shutil.copy(uvh5, input_file)
Expand Down Expand Up @@ -785,7 +780,7 @@ def test_load_tophat_frfilter_and_write_partial_io(self, tmpdir):
time_thresh = 2. / hd.Ntimes
# test delay filtering and writing with factorized flags and partial i/o
time_thresh = 2. / hd.Ntimes

frf.load_tophat_frfilter_and_write(input_file, res_outfilename=outfilename, tol=1e-4, case='sky',
factorize_flags=True, time_thresh=time_thresh, clobber=True)
hd = io.HERAData(outfilename)
Expand All @@ -806,7 +801,6 @@ def test_load_tophat_frfilter_and_write_partial_io(self, tmpdir):
assert np.all(f[bl][:, -1])
assert not np.all(np.isclose(d[bl], 0.))


def test_load_tophat_frfilter_and_write_yaml(self, tmpdir):
# test apriori flags and flag_yaml
tmp_path = tmpdir.strpath
Expand Down Expand Up @@ -953,8 +947,8 @@ def test_load_tophat_frfilter_and_write_with_filter_yaml(self, tmpdir, flip):
filter_centers={str(k): v for k, v in frate_centers.items()},
filter_half_widths={str(k): v for k, v in frate_half_widths.items()},
)
frate_centers = {k+("ee",): v for k, v in frate_centers.items()}
frate_half_widths = {k+("ee",): v for k, v in frate_half_widths.items()}
frate_centers = {k + ("ee",): v for k, v in frate_centers.items()}
frate_half_widths = {k + ("ee",): v for k, v in frate_half_widths.items()}

with open(tmpdir / "filter_info.yaml", "w") as f:
yaml.dump(filter_info, f)
Expand Down Expand Up @@ -993,7 +987,6 @@ def test_load_tophat_frfilter_and_write_with_filter_yaml(self, tmpdir, flip):
# Check that the flags match.
np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')])


def test_load_tophat_frfilter_and_write_with_bad_filter_yaml(self, tmpdir):
tmp_path = tmpdir.strpath
uvh5 = os.path.join(
Expand All @@ -1013,7 +1006,6 @@ def test_load_tophat_frfilter_and_write_with_bad_filter_yaml(self, tmpdir):
case="param_file",
)


def test_tophat_clean_argparser(self):
sys.argv = [sys.argv[0], 'a', '--clobber', '--window', 'blackmanharris', '--max_frate_coeffs', '0.024', '-0.229']
parser = frf.tophat_frfilter_argparser()
Expand Down
2 changes: 1 addition & 1 deletion hera_cal/tests/test_nucal.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_get_unique_orientations():
assert len(group) >= 5

class TestRadialRedundancy:
def setup(self):
def setup_method(self):
self.antpos = hex_array(4, outriggers=0, split_core=False)
self.radial_reds = nucal.RadialRedundancy(self.antpos)

Expand Down
Loading