diff --git a/src/kbmod/filters/stamp_filters.py b/src/kbmod/filters/stamp_filters.py
index f56666cff..b295f9de5 100644
--- a/src/kbmod/filters/stamp_filters.py
+++ b/src/kbmod/filters/stamp_filters.py
@@ -312,6 +312,7 @@ def extract_search_parameters_from_config(config):
     return params
 
 
+# TODO remove once we fully replace ResultList with Results
 def get_coadds_and_filter(result_list, im_stack, stamp_params, chunk_size=1000000, debug=False):
     """Create the co-added postage stamps and filter them based on their statistical
     properties. Results with stamps that are similar to a Gaussian are kept.
@@ -388,6 +389,90 @@ def get_coadds_and_filter(result_list, im_stack, stamp_params, chunk_size=100000
     logger.debug("{:.2f}s elapsed".format(time.time() - start_time))
 
 
+def get_coadds_and_filter_results(result_data, im_stack, stamp_params, chunk_size=1000000, debug=False):
+    """Create the co-added postage stamps and filter them based on their statistical
+    properties. Results with stamps that are similar to a Gaussian are kept.
+
+    Parameters
+    ----------
+    result_data : `Results`
+        The current set of results. Filtered in place; gains a "stamp" column.
+    im_stack : `ImageStack`
+        The images from which to build the co-added stamps.
+    stamp_params : `StampParameters` or `SearchConfiguration`
+        The filtering parameters for the stamps.
+    chunk_size : `int`
+        How many stamps to load and filter at a time. Used to control memory.
+    debug : `bool`
+        Currently unused by this function.
+    """
+    num_results = len(result_data)
+
+    if isinstance(stamp_params, SearchConfiguration):
+        stamp_params = extract_search_parameters_from_config(stamp_params)
+
+    if num_results <= 0:
+        logger.debug("Stamp Filtering : skipping, nothing to filter.")
+    else:
+        logger.debug(f"Stamp filtering {num_results} results.")
+        logger.debug(f"Using filtering params: {stamp_params}")
+        logger.debug(f"Using chunksize = {chunk_size}")
+
+    trj_list = result_data.make_trajectory_list()
+    keep_row = [False] * num_results
+    stamps_to_keep = []
+
+    # Run the stamp creation and filtering in batches of chunk_size.
+    start_time = time.time()
+    start_idx = 0
+    while start_idx < num_results:
+        end_idx = min(start_idx + chunk_size, num_results)
+        slice_size = end_idx - start_idx
+
+        # Create a subslice of the results and the Boolean indices.
+        # Note that the sum stamp type does not filter out lc_index.
+        trj_slice = trj_list[start_idx:end_idx]
+        if stamp_params.stamp_type != StampType.STAMP_SUM and "index_valid" in result_data.colnames:
+            bool_slice = result_data["index_valid"][start_idx:end_idx]
+        else:
+            # Use all the indices for each trajectory.
+            bool_slice = [[True] * im_stack.img_count() for _ in range(slice_size)]
+
+        # Create and filter the results, using the GPU if there is one and enough
+        # trajectories to make it worthwhile.
+        stamps_slice = StampCreator.get_coadded_stamps(
+            im_stack,
+            trj_slice,
+            bool_slice,
+            stamp_params,
+            HAS_GPU and len(trj_slice) > 100,
+        )
+        # TODO: a way to avoid a copy here would be to do
+        # np.array([s.image for s in stamps], dtype=np.single, copy=False)
+        # but that could cause a problem with reference counting at the
+        # moment. The real fix is to make the stamps return Image not
+        # RawImage and avoid reference to a private attribute and risking
+        # collecting RawImage but leaving a dangling ref to the attribute.
+        # That's a fix for another time so I'm leaving it as a copy here
+        for ind, stamp in enumerate(stamps_slice):
+            if stamp.width > 1:
+                stamps_to_keep.append(np.array(stamp.image))
+                keep_row[start_idx + ind] = True
+
+        # Move to the next chunk.
+        start_idx += chunk_size
+
+    # Do the actual filtering of results
+    result_data.filter_mask(keep_row, label="stamp_filter")
+
+    # Append the coadded stamps to the results. We do this after the filtering
+    # so we are not adding a jagged array.
+    result_data.table["stamp"] = np.array(stamps_to_keep)
+
+    logger.debug(f"Keeping {len(result_data)} results")
+    logger.debug("{:.2f}s elapsed".format(time.time() - start_time))
+
+
 def append_all_stamps(result_data, im_stack, stamp_radius):
     """Get the stamps for the final results from a kbmod search. These are appended
     onto the corresponding entries in a ResultList.
diff --git a/tests/test_stamp_filters.py b/tests/test_stamp_filters.py
index 9e39927f6..64b35d4fe 100644
--- a/tests/test_stamp_filters.py
+++ b/tests/test_stamp_filters.py
@@ -5,6 +5,7 @@
 from kbmod.fake_data.fake_data_creator import add_fake_object, create_fake_times, FakeDataSet
 from kbmod.filters.stamp_filters import *
 from kbmod.result_list import *
+from kbmod.results import Results
 from kbmod.search import *
 
 
@@ -235,6 +236,58 @@ def test_get_coadds_and_filter(self):
         self.assertIsNotNone(keep.results[0].stamp)
         self.assertIsNotNone(keep.results[1].stamp)
 
+    @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)")
+    def test_get_coadds_and_filter_results(self):
+        image_count = 10
+        fake_times = create_fake_times(image_count, 57130.2, 1, 0.01, 1)
+        ds = FakeDataSet(
+            25,  # width
+            35,  # height
+            fake_times,  # time stamps
+            1.0,  # noise level
+            0.5,  # psf value
+            True,  # Use a fixed seed for testing
+        )
+
+        # Insert a single fake object with known parameters.
+        trj = make_trajectory(8, 7, 2.0, 1.0, flux=250.0)
+        ds.insert_object(trj)
+
+        # Second Trajectory that isn't any good.
+        trj2 = make_trajectory(1, 1, 0.0, 0.0)
+
+        # Third Trajectory that is close to good, but offset.
+        trj3 = make_trajectory(trj.x + 2, trj.y + 2, trj.vx, trj.vy)
+
+        # Create a fourth Trajectory that is just close enough
+        trj4 = make_trajectory(trj.x + 1, trj.y + 1, trj.vx, trj.vy)
+
+        # Create the Results.
+        keep = Results([trj, trj2, trj3, trj4])
+        self.assertNotIn("stamp", keep.colnames)
+
+        # Create the stamp parameters we need.
+        config_dict = {
+            "center_thresh": 0.03,
+            "do_stamp_filter": True,
+            "mom_lims": [35.5, 35.5, 1.0, 1.0, 1.0],
+            "peak_offset": [1.5, 1.5],
+            "stamp_type": "cpp_mean",
+            "stamp_radius": 5,
+        }
+        config = SearchConfiguration.from_dict(config_dict)
+
+        # Do the filtering.
+        get_coadds_and_filter_results(keep, ds.stack, config, chunk_size=2, debug=False)
+
+        # Check that the correct indices and number of stamps are saved.
+        self.assertIn("stamp", keep.colnames)
+        self.assertEqual(len(keep), 2)
+        self.assertEqual(keep["x"][0], trj.x)
+        self.assertEqual(keep["x"][1], trj.x + 1)
+        self.assertEqual(keep["stamp"][0].shape, (11, 11))
+        self.assertEqual(keep["stamp"][1].shape, (11, 11))
+
     def test_append_all_stamps(self):
         image_count = 10
         fake_times = create_fake_times(image_count, 57130.2, 1, 0.01, 1)
@@ -284,7 +337,7 @@ def test_append_all_stamps_results(self):
         append_all_stamps(keep, ds.stack, 5)
         self.assertTrue("all_stamps" in keep.colnames)
 
-        for i in range(len(trj_list)):
+        for i in range(len(keep)):
             stamps_array = keep["all_stamps"][i]
             self.assertEqual(stamps_array.shape[0], image_count)
             self.assertEqual(stamps_array.shape[1], 11)