diff --git a/hera_cal/data/example_filter_params.yaml b/hera_cal/data/example_filter_params.yaml new file mode 100644 index 000000000..af52605e0 --- /dev/null +++ b/hera_cal/data/example_filter_params.yaml @@ -0,0 +1,6 @@ +filter_centers: + (0, 1): 0.1234 + (0, 2): 0.173 +filter_half_widths: + (0, 1): 0.05 + (0, 2): 0.08 diff --git a/hera_cal/frf.py b/hera_cal/frf.py index e157134e7..fd34940d6 100644 --- a/hera_cal/frf.py +++ b/hera_cal/frf.py @@ -27,6 +27,9 @@ import astropy.constants as const from . import redcal from .utils import echo +import os +import yaml +import re SPEED_OF_LIGHT = const.c.si.value SDAY_SEC = units.sday.to("s") @@ -1560,36 +1563,50 @@ def tophat_frfilter_argparser(mode='clean'): filt_options.add_argument("--blacklist_wgt", type=float, default=0.0, help="Relative weight to assign to blacklisted lsts compared to 1.0. Default 0.0 \ means no weight. Note that 0.0 will create problems for DPSS at edge times and frequencies.") - desc = ("Filtering case ['max_frate_coeffs', 'uvbeam', 'sky']", + desc = ("Filtering case ['max_frate_coeffs', 'uvbeam', 'sky', 'param_file']", "If case == 'max_frate_coeffs', then determine fringe rate centers", "and half-widths based on the max_frate_coeffs arg (see below).", "If case == 'uvbeam', then determine fringe rate centers and half widths", "from histogram of main-beam wrt instantaneous sky fringe rates.", "If case == 'sky': then use fringe-rates corresponding to range of ", - "instantanous fringe-rates that include sky emission.") + "instantanous fringe-rates that include sky emission.", + "If case == 'param_file': then use a provided parameter file to determine", + "filter centers and filter half-widths. See param_file help for more ", + "information regarding parameter file structure.") filt_options.add_argument("--case", default="sky", help=' '.join(desc), type=str) desc = ("Number interleaved time subsets to split the data into ", "and apply independent fringe-rate filters. 
def _read_frfilter_param_file(param_file):
    """Read fringe-rate filter centers and half-widths from a yaml file.

    Parameters
    ----------
    param_file : str
        Path to a yaml file with a ``filter_centers`` entry and a
        ``filter_half_widths`` entry. Each entry must be a mapping whose keys
        are antenna-pair strings (e.g. ``"(0, 1)"``).

    Returns
    -------
    filter_centers : dict
        Maps ``(ant1, ant2)`` integer tuples to the values found in the file.
    filter_half_widths : dict
        Maps ``(ant1, ant2)`` integer tuples to the values found in the file.

    Raises
    ------
    ValueError
        If either entry is missing or is not a mapping, or if a key does not
        contain exactly two integers.
    """
    with open(param_file, "r") as stream:
        # safe_load only builds plain python containers and scalars.
        raw_params = yaml.safe_load(stream)

    if not isinstance(raw_params, dict):
        raise ValueError(
            f"Filter parameter file {param_file} must contain 'filter_centers' "
            "and 'filter_half_widths' entries."
        )

    parsed = {}
    for entry in ("filter_centers", "filter_half_widths"):
        if not isinstance(raw_params.get(entry), dict):
            raise ValueError(
                f"Filter parameter file {param_file} must contain a '{entry}' "
                "entry mapping antenna pairs to filter values."
            )
        parsed[entry] = {}
        for antpair_str, value in raw_params[entry].items():
            # Keys are stored as strings such as "(0, 1)"; pull out the integers.
            antpair = tuple(
                int(ant) for ant in re.findall(r"[0-9]+", str(antpair_str))
            )
            if len(antpair) != 2:
                raise ValueError(
                    f"Could not interpret key {antpair_str!r} in entry '{entry}' "
                    f"of {param_file} as an antenna pair."
                )
            parsed[entry][antpair] = value

    return parsed["filter_centers"], parsed["filter_half_widths"]


def _frfilter_params_for_antpair(antpair, filter_centers, filter_half_widths):
    """Look up the (center, half-width) of an antenna pair, allowing conjugation.

    If only the reversed pair is present, the center is negated and the
    half-width is reused. Returns None when neither orientation has both a
    center and a half-width.
    """
    ai, aj = antpair
    if (ai, aj) in filter_centers and (ai, aj) in filter_half_widths:
        return filter_centers[(ai, aj)], filter_half_widths[(ai, aj)]
    if (aj, ai) in filter_centers and (aj, ai) in filter_half_widths:
        return -filter_centers[(aj, ai)], filter_half_widths[(aj, ai)]
    return None


def load_tophat_frfilter_and_write(
    datafile_list, case, baseline_list=None, calfile_list=None,
    Nbls_per_load=None, spw_range=None, external_flags=None,
    factorize_flags=False, time_thresh=0.05, wgt_by_nsample=False,
    lst_blacklists=None, blacklist_wgt=0.0,
    res_outfilename=None, CLEAN_outfilename=None, filled_outfilename=None,
    clobber=False, add_to_history='', avg_red_bllens=False, polarizations=None,
    overwrite_flags=False,
    flag_yaml=None, skip_autos=False, beamfitsfile=None, verbose=False,
    read_axis=None,
    percentile_low=5., percentile_high=95.,
    frate_standoff=0.0, frate_width_multiplier=1.0,
    min_frate_half_width=0.025, max_frate_half_width=np.inf,
    max_frate_coeffs=None, fr_freq_skip=1, ninterleave=1, param_file="",
    **filter_kwargs
):
    '''
    A tophat fr-filtering method that only simultaneously loads and writes
    a user-provided list of baselines. This is to support parallelization over
    baseline (rather than time) if baseline_list is specified.

    Arguments
    ---------
        datafile_list: str or list of str
            Data file(s) to filter. Handed to io.HERAData with filetype='uvh5'.
        case: str
            How to determine fringe-rate centers and half-widths. One of
            'sky', 'max_frate_coeffs', 'uvbeam' (all three are computed by
            select_tophat_frates) or 'param_file' (values are read from
            param_file).
        baseline_list: list of tuples, optional
            Baselines to load; only the first two elements (the antenna
            numbers) of each entry are used for antenna selection and filter
            lookup. If None, the antpairs reported by the data are used (the
            first entry of hd.antpairs when several files are given). An empty
            list triggers a RuntimeWarning and an early return without output.
            Cannot be combined with Nbls_per_load.
        calfile_list: str or list of str, optional
            Calibration file(s) handed to io.HERACal. Only the antennas in
            baseline_list and the selected frequencies are read.
        Nbls_per_load: int, optional
            Number of baselines to read, filter and write per iteration.
            Default (None) processes all baselines at once.
        spw_range: 2-element sequence of int, optional
            Start (inclusive) and stop (exclusive) channel indices to load.
            Default is all channels.
        external_flags: optional
            Passed to FRFilter.apply_flags() if not None.
        factorize_flags: bool, optional
            If True, call FRFilter.factorize_flags(time_thresh=time_thresh,
            inplace=True) before filtering.
        time_thresh: float, optional
            Only used when factorize_flags is True (see above).
        wgt_by_nsample: bool, optional
            If True, multiply the (unflagged = 1, flagged = 0) weights by nsamples.
        lst_blacklists: list of 2-element sequences, optional
            LST ranges whose weights are multiplied by blacklist_wgt. Bounds
            are multiplied by pi / 12 before comparison with the LSTs
            (presumably hours vs. radians -- confirm against callers). A range
            whose first bound is not smaller than its second is treated as
            wrapping around.
        blacklist_wgt: float, optional
            Relative weight to assign to blacklisted lsts compared to 1.0.
        res_outfilename, CLEAN_outfilename, filled_outfilename: str, optional
            Output paths forwarded to FRFilter.write_filtered_data().
        clobber: bool, optional
            Forwarded to FRFilter.write_filtered_data().
        add_to_history: str, optional
            Forwarded to FRFilter.write_filtered_data().
        avg_red_bllens: bool, optional
            If True, call FRFilter.avg_red_baseline_vectors() after reading.
        polarizations: list, optional
            Polarizations to read. Default is the polarizations in the data.
        overwrite_flags: bool, optional
            Forwarded to FRFilter.apply_flags().
        flag_yaml: str, optional
            Passed to FRFilter.apply_flags(..., filetype='yaml') if not None.
        skip_autos: bool, optional
            If True, auto-correlations are not filtered; they are copied
            unfiltered into the output containers (with a zero model) so that
            they are still written out.
        beamfitsfile: str, optional
            Beam file read with UVBeam.read_beamfits() and handed to
            select_tophat_frates().
        verbose: bool, optional
            Forwarded to select_tophat_frates() and FRFilter.tophat_frfilter().
        read_axis: str, optional
            Forwarded as `axis` to FRFilter.read().
        percentile_low, percentile_high, frate_standoff, frate_width_multiplier,
        min_frate_half_width, max_frate_half_width, max_frate_coeffs, fr_freq_skip:
            Forwarded unchanged to select_tophat_frates(). Not used when
            case == 'param_file'.
        ninterleave: int, optional
            Number of interleaved sets to run time filtering on.
        param_file: str, optional
            File containing filter parameters (e.g., centers and half-widths).
            The file must be readable by yaml, with a "filter_centers" entry and
            a "filter_half_widths" entry. Each of these entries should correspond
            to dictionaries with antenna pair strings as keys and floating point
            numbers as values. The filter centers and filter half widths are assumed
            to be provided in mHz. When writing filter parameters to a file, ensure
            that the antenna pair keys have been converted to strings prior to
            dumping the contents to the file. See the file example_filter_params.yaml
            for an example. A baseline may be listed in either orientation; the
            center is negated when the reversed pair is used. Filters are shared
            between polarizations. Only used if case == 'param_file'.
        filter_kwargs: additional keyword arguments to be passed to FRFilter.tophat_frfilter()

    Raises
    ------
        NotImplementedError
            If both baseline_list and Nbls_per_load are provided.
        ValueError
            If case is not one of the supported values, or if case is
            'param_file' and the file is missing, malformed, or lacks an entry
            for one of the baselines to be filtered.
    '''
    if baseline_list is not None and Nbls_per_load is not None:
        raise NotImplementedError("baseline loading and partial i/o not yet implemented.")

    if baseline_list is not None and len(baseline_list) == 0:
        warnings.warn(
            "Length of baseline list is zero. This can happen under normal "
            "circumstances when there are more files in datafile_list than "
            "baselines in your dataset. Exiting without writing any output.",
            RuntimeWarning
        )
        return

    # Fail fast on a bad case before doing any expensive I/O.
    if case not in ("sky", "max_frate_coeffs", "uvbeam", "param_file"):
        raise ValueError(f"case={case} is not valid")

    use_param_file = case == "param_file"
    if use_param_file:
        if not os.path.isfile(param_file):
            raise ValueError(
                "When using a filter parameter file, a valid parameter file must "
                f"be provided. The provided file {param_file} could not be found."
            )
        filter_centers, filter_half_widths = _read_frfilter_param_file(param_file)

    hd = io.HERAData(datafile_list, filetype='uvh5')

    # Figure out which baselines to load if not provided.
    if baseline_list is None:
        if len(hd.filepaths) > 1:
            baseline_list = list(hd.antpairs.values())[0]
        else:
            baseline_list = hd.antpairs

    # If a filter parameter file is used, check up front that every baseline
    # we are going to filter has both a center and a half-width.
    if use_param_file:
        missing_bls = set()
        for bl in baseline_list:
            if skip_autos and bl[0] == bl[1]:
                continue
            found = _frfilter_params_for_antpair(
                bl[:2], filter_centers, filter_half_widths
            )
            if found is None:
                # Cast to int so the message reads "(53, 54)" for numpy ints too.
                missing_bls.add(tuple(int(ant) for ant in bl[:2]))

        if missing_bls:
            raise ValueError(
                "Provided filter file doesn't have every baseline. The following "
                "baselines could not be found: "
                + " ".join(str(bl) for bl in sorted(missing_bls))
            )

    # Figure out which frequencies to load.
    if spw_range is None:
        spw_range = [0, hd.Nfreqs]
    freqs = hd.freq_array.flatten()[spw_range[0]:spw_range[1]]

    # Figure out which antennas to load if calfiles are provided.
    baseline_antennas = []
    for blpolpair in baseline_list:
        baseline_antennas += list(blpolpair[:2])
    baseline_antennas = np.unique(baseline_antennas).astype(int)

    # Read calibration solutions if provided.
    if calfile_list is not None:
        cals = io.HERACal(calfile_list)
        cals.read(antenna_nums=baseline_antennas, frequencies=freqs)
    else:
        cals = None

    # Figure out which polarizations to use.
    if polarizations is None:
        if len(hd.filepaths) > 1:
            polarizations = list(hd.pols.values())[0]
        else:
            polarizations = hd.pols

    # Load all baselines at once if a chunk size is not provided.
    if Nbls_per_load is None:
        Nbls_per_load = len(baseline_list)

    # Read the data in chunks and perform filtering.
    for i in range(0, len(baseline_list), Nbls_per_load):
        # Read data from this chunk of baselines.
        frfil = FRFilter(hd, input_cal=cals)
        frfil.read(
            bls=baseline_list[i:i + Nbls_per_load],
            frequencies=freqs,
            polarizations=polarizations,
            axis=read_axis,
        )

        # Some extra handling if requested.
        if avg_red_bllens:
            frfil.avg_red_baseline_vectors()
        if external_flags is not None:
            frfil.apply_flags(external_flags, overwrite_flags=overwrite_flags)
        if flag_yaml is not None:
            frfil.apply_flags(flag_yaml, overwrite_flags=overwrite_flags, filetype='yaml')
        if factorize_flags:
            frfil.factorize_flags(time_thresh=time_thresh, inplace=True)

        # Figure out which baselines we'll need for filtering.
        keys = frfil.data.keys()
        if skip_autos:
            keys = [bl for bl in keys if bl[0] != bl[1]]

        # Read in the beam file if provided.
        if beamfitsfile is not None:
            uvb = UVBeam()
            uvb.read_beamfits(beamfitsfile)
            uvb.use_future_array_shapes()
        else:
            uvb = None

        # Filter the data if there is data to filter.
        if len(keys) > 0:
            # Deal with interleaved sets
            frfil._deinterleave_data_in_time('data', ninterleave=ninterleave)
            frfil._deinterleave_data_in_time(
                'flags', ninterleave=ninterleave, set_time_sets=False
            )
            frfil._deinterleave_data_in_time(
                'nsamples', ninterleave=ninterleave, set_time_sets=False
            )

            # Figure out fringe-rate centers and half-widths.
            if use_param_file:
                frate_centers = {}
                frate_half_widths = {}
                # The same filter is used for all polarizations of a baseline.
                for key in keys:
                    found = _frfilter_params_for_antpair(
                        key[:2], filter_centers, filter_half_widths
                    )
                    if found is None:
                        raise ValueError(
                            f"Provided filter file has no entry for baseline {key}."
                        )
                    frate_centers[key], frate_half_widths[key] = found
            else:
                # Otherwise, we need to compute the filter parameters.
                # Use conservative nfr (lowest resolution set).
                nfr = int(np.min([len(tset) for tset in frfil.time_sets]))
                frate_centers, frate_half_widths = select_tophat_frates(
                    uvd=frfil.hd, blvecs=frfil.blvecs,
                    case=case, keys=keys, uvb=uvb,
                    frate_standoff=frate_standoff,
                    frate_width_multiplier=frate_width_multiplier,
                    min_frate_half_width=min_frate_half_width,
                    max_frate_half_width=max_frate_half_width,
                    max_frate_coeffs=max_frate_coeffs,
                    percentile_low=percentile_low,
                    percentile_high=percentile_high,
                    fr_freq_skip=fr_freq_skip,
                    verbose=verbose, nfr=nfr,
                )

            # Lists of names of datacontainers that will hold each interleaved
            # data set until they are recombined.
            filtered_data_names = [
                f'clean_data_interleave_{inum}' for inum in range(ninterleave)
            ]
            filtered_flag_names = [
                fstr.replace('data', 'flags') for fstr in filtered_data_names
            ]
            filtered_resid_names = [
                fstr.replace('data', 'resid') for fstr in filtered_data_names
            ]
            filtered_model_names = [
                fstr.replace('data', 'model') for fstr in filtered_data_names
            ]
            filtered_resid_flag_names = [
                fstr.replace('data', 'resid_flags') for fstr in filtered_data_names
            ]

            for inum in range(ninterleave):
                # Build weights using flags, nsamples, and excluded lsts
                flags = getattr(frfil, f'flags_interleave_{inum}')
                nsamples = getattr(frfil, f'nsamples_interleave_{inum}')
                wgts = io.DataContainer({k: (~flags[k]).astype(float) for k in flags})

                # The blacklist masks only depend on the LSTs of this
                # interleave, so build them once rather than once per baseline.
                lsts = frfil.lst_sets[inum]
                blacklist_masks = []
                if lst_blacklists is not None:
                    for lb in lst_blacklists:
                        if lb[0] < lb[1]:
                            is_blacklisted = (
                                lsts >= lb[0] * np.pi / 12
                            ) & (lsts <= lb[1] * np.pi / 12)
                        else:
                            # Range wraps around.
                            is_blacklisted = (
                                lsts >= lb[0] * np.pi / 12
                            ) | (lsts <= lb[1] * np.pi / 12)
                        blacklist_masks.append(is_blacklisted)

                for k in wgts:
                    if wgt_by_nsample:
                        wgts[k] *= nsamples[k]
                    for is_blacklisted in blacklist_masks:
                        wgts[k][is_blacklisted, :] = (
                            wgts[k][is_blacklisted, :] * blacklist_wgt
                        )

                # run tophat filter
                frfil.tophat_frfilter(
                    frate_centers=frate_centers, frate_half_widths=frate_half_widths,
                    keys=keys, verbose=verbose, wgts=wgts,
                    flags=getattr(frfil, f'flags_interleave_{inum}'),
                    data=getattr(frfil, f'data_interleave_{inum}'),
                    output_postfix=f'interleave_{inum}',
                    times=frfil.time_sets[inum] * SDAY_SEC * 1e-3,
                    **filter_kwargs
                )

            frfil._interleave_data_in_time(filtered_data_names, 'clean_data')
            frfil._interleave_data_in_time(filtered_flag_names, 'clean_flags')
            frfil._interleave_data_in_time(filtered_resid_names, 'clean_resid')
            frfil._interleave_data_in_time(filtered_resid_flag_names, 'clean_resid_flags')
            frfil._interleave_data_in_time(filtered_model_names, 'clean_model')

        else:
            frfil.clean_data = DataContainer({})
            frfil.clean_flags = DataContainer({})
            frfil.clean_resid = DataContainer({})
            frfil.clean_resid_flags = DataContainer({})
            frfil.clean_model = DataContainer({})

        # put autocorr data into filtered data containers if skip_autos = True.
        # so that it can be written out into the filtered files.
        if skip_autos:
            for bl in frfil.data.keys():
                if bl[0] == bl[1]:
                    frfil.clean_data[bl] = frfil.data[bl]
                    frfil.clean_flags[bl] = frfil.flags[bl]
                    frfil.clean_resid[bl] = frfil.data[bl]
                    frfil.clean_model[bl] = np.zeros_like(frfil.data[bl])
                    frfil.clean_resid_flags[bl] = frfil.flags[bl]

        frfil.write_filtered_data(
            res_outfilename=res_outfilename, CLEAN_outfilename=CLEAN_outfilename,
            filled_outfilename=filled_outfilename,
            partial_write=Nbls_per_load < len(baseline_list),
            clobber=clobber, add_to_history=add_to_history,
            extra_attrs={
                'Nfreqs': frfil.hd.Nfreqs,
                'freq_array': frfil.hd.freq_array,
                'channel_width': frfil.hd.channel_width,
                'flex_spw_id_array': frfil.hd.flex_spw_id_array
            },
        )
        frfil.hd.data_array = None  # this forces a reload in the next loop
tmpdir, flip): + tmp_path = tmpdir.strpath + uvh5 = os.path.join( + DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5" + ) + frate_centers = {(53, 54): 0.1} + frate_half_widths = {(53, 54): 0.05} + if flip: + filter_info = dict( + filter_centers={ + str(k[::-1]): -v for k, v in frate_centers.items() + }, + filter_half_widths={ + str(k[::-1]): v for k, v in frate_half_widths.items() + }, + ) + else: + filter_info = dict( + filter_centers={str(k): v for k, v in frate_centers.items()}, + filter_half_widths={str(k): v for k, v in frate_half_widths.items()}, + ) + frate_centers = {k+("ee",): v for k, v in frate_centers.items()} + frate_half_widths = {k+("ee",): v for k, v in frate_half_widths.items()} + + with open(tmpdir / "filter_info.yaml", "w") as f: + yaml.dump(filter_info, f) + + outfilename = os.path.join(tmp_path, 'temp.h5') + frf.load_tophat_frfilter_and_write( + uvh5, + res_outfilename=outfilename, + tol=1e-4, + clobber=True, + Nbls_per_load=1, + param_file=str(tmpdir / "filter_info.yaml"), + case="param_file", + skip_autos=True, + ) + hd = io.HERAData(outfilename) + d, f, n = hd.read(bls=[(53, 54, 'ee')]) + for bl in d: + assert not np.allclose(d[bl], 0) + + frfil = frf.FRFilter(uvh5, filetype='uvh5') + frfil.read(bls=[(53, 54, 'ee')]) + frfil.tophat_frfilter( + keys=[(53, 54, 'ee')], + tol=1e-4, + verbose=True, + frate_centers=frate_centers, + frate_half_widths=frate_half_widths, + ) + + # Check that the filtered data both ways matches. + np.testing.assert_almost_equal( + d[(53, 54, 'ee')], frfil.clean_resid[(53, 54, 'ee')], decimal=5 + ) + + # Check that the flags match. 
+ np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')]) + + + def test_load_tophat_frfilter_and_write_with_bad_filter_yaml(self, tmpdir): + tmp_path = tmpdir.strpath + uvh5 = os.path.join( + DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5" + ) + param_file = os.path.join(DATA_PATH, "example_filter_params.yaml") + + outfilename = os.path.join(tmp_path, 'temp.h5') + with pytest.raises(ValueError, match="(53, 54)"): + frf.load_tophat_frfilter_and_write( + uvh5, + res_outfilename=outfilename, + tol=1e-4, + clobber=True, + Nbls_per_load=1, + param_file=param_file, + case="param_file", + ) + + def test_tophat_clean_argparser(self): sys.argv = [sys.argv[0], 'a', '--clobber', '--window', 'blackmanharris', '--max_frate_coeffs', '0.024', '-0.229'] parser = frf.tophat_frfilter_argparser() diff --git a/scripts/tophat_frfilter_run.py b/scripts/tophat_frfilter_run.py index 4f9fb98fc..6178524e1 100644 --- a/scripts/tophat_frfilter_run.py +++ b/scripts/tophat_frfilter_run.py @@ -43,20 +43,26 @@ clobber=ap.clobber, write_cache=ap.write_cache, CLEAN_outfilename=ap.CLEAN_outfilename, read_cache=ap.read_cache, mode=ap.mode, res_outfilename=ap.res_outfilename, factorize_flags=ap.factorize_flags, time_thresh=ap.time_thresh, - wgt_by_nsample=ap.wgt_by_nsample, lst_blacklists=ap.lst_blacklists, blacklist_wgt=ap.blacklist_wgt, + wgt_by_nsample=ap.wgt_by_nsample, lst_blacklists=ap.lst_blacklists, + blacklist_wgt=ap.blacklist_wgt, add_to_history=' '.join(sys.argv), verbose=ap.verbose, flag_yaml=ap.flag_yaml, Nbls_per_load=ap.Nbls_per_load, case=ap.case, external_flags=ap.external_flags, filter_spw_ranges=ap.filter_spw_ranges, overwrite_flags=ap.overwrite_flags, skip_autos=ap.skip_autos, skip_if_flag_within_edge_distance=ap.skip_if_flag_within_edge_distance, zeropad=ap.zeropad, tol=ap.tol, skip_wgt=ap.skip_wgt, max_frate_coeffs=ap.max_frate_coeffs, - frate_width_multiplier=ap.frate_width_multiplier, frate_standoff=ap.frate_standoff, 
fr_freq_skip=ap.fr_freq_skip, + frate_width_multiplier=ap.frate_width_multiplier, frate_standoff=ap.frate_standoff, + fr_freq_skip=ap.fr_freq_skip, min_frate_half_width=ap.min_frate_half_width, max_frate_half_width=ap.max_frate_half_width, - beamfitsfile=ap.beamfitsfile, percentile_low=ap.percentile_low, percentile_high=ap.percentile_high, - skip_contiguous_flags=not(ap.dont_skip_contiguous_flags), max_contiguous_flag=ap.max_contiguous_flag, + beamfitsfile=ap.beamfitsfile, percentile_low=ap.percentile_low, + percentile_high=ap.percentile_high, + skip_contiguous_flags=not(ap.dont_skip_contiguous_flags), + max_contiguous_flag=ap.max_contiguous_flag, skip_flagged_edges=not(ap.dont_skip_flagged_edges), - flag_model_rms_outliers=not(ap.dont_flag_model_rms_outliers), model_rms_threshold=ap.model_rms_threshold, + flag_model_rms_outliers=not(ap.dont_flag_model_rms_outliers), + model_rms_threshold=ap.model_rms_threshold, clean_flags_in_resid_flags=not(ap.clean_flags_not_in_resid_flags), pre_filter_modes_between_lobe_minimum_and_zero=ap.pre_filter_modes_between_lobe_minimum_and_zero, + param_file=ap.param_file, **filter_kwargs )