Skip to content

Commit

Permalink
Merge branch 'main' into generator_config_cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Sep 17, 2024
2 parents 592e139 + e01f775 commit e2d523e
Show file tree
Hide file tree
Showing 12 changed files with 19 additions and 202 deletions.
3 changes: 3 additions & 0 deletions src/kbmod/filters/sigma_g_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.special import erfinv

from kbmod.results import Results
from kbmod.search import DebugTimer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -177,7 +178,9 @@ def apply_clipped_sigma_g(clipper, result_data):
logger.info("SigmaG Clipping : skipping, nothing to filter.")
return

filter_timer = DebugTimer("sigma-g filtering", logger)
lh = result_data.compute_likelihood_curves(filter_obs=True, mask_value=np.nan)
obs_valid = clipper.compute_clipped_sigma_g_matrix(lh)
result_data.update_obs_valid(obs_valid)
filter_timer.stop()
return
4 changes: 4 additions & 0 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,16 @@ def append_coadds(result_data, im_stack, coadd_types, radius, chunk_size=100_000
"""
if radius <= 0:
raise ValueError(f"Invalid stamp radius {radius}")
stamp_timer = DebugTimer("computing extra coadds", logger)

params = StampParameters()
params.radius = radius
params.do_filtering = False

# Loop through all the coadd types in the list, generating a corresponding stamp.
for coadd_type in coadd_types:
logger.info(f"Adding coadd={coadd_type} for all results.")

if coadd_type == "median":
params.stamp_type = StampType.STAMP_MEDIAN
elif coadd_type == "mean":
Expand All @@ -212,6 +215,7 @@ def append_coadds(result_data, im_stack, coadd_types, radius, chunk_size=100_000
chunk_size=chunk_size,
colname=f"coadd_{coadd_type}",
)
stamp_timer.stop()


def append_all_stamps(result_data, im_stack, stamp_radius):
Expand Down
8 changes: 2 additions & 6 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def from_trajectories(cls, trajectories, track_filtered=False):
input_d = {}
for col in cls.required_cols:
input_d[col[0]] = []
valid_mask = []

# Add the valid trajectories to the table.
# Add the trajectories to the table.
for trj in trajectories:
input_d["x"].append(trj.x)
input_d["y"].append(trj.y)
Expand All @@ -150,17 +149,14 @@ def from_trajectories(cls, trajectories, track_filtered=False):
input_d["likelihood"].append(trj.lh)
input_d["flux"].append(trj.flux)
input_d["obs_count"].append(trj.obs_count)
valid_mask.append(trj.valid)

# Check for any missing columns and fill in the default value.
for col in cls.required_cols:
if col[0] not in input_d:
input_d[col[0]] = [col[2]] * num_valid
invalid_d[col[0]] = [col[2]] * num_invalid
input_d[col[0]] = [col[2]] * len(trajectories)

# Create the table and add the unfiltered (and filtered) results.
results = Results(input_d, track_filtered=track_filtered)
results.filter_rows(np.array(valid_mask, dtype=bool), "invalid_trajectory")
return results

@classmethod
Expand Down
16 changes: 5 additions & 11 deletions src/kbmod/search/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ struct Trajectory {
int y = 0;
// Number of images summed
int obs_count;
// Whether the trajectory is valid. Used for on-GPU filtering.
bool valid = true;

// Get pixel positions from a zero-shifted time. Centered indicates whether
// the prediction starts from the center of the pixel (which it does in the search)
Expand All @@ -79,14 +77,13 @@ struct Trajectory {
const std::string to_string() const {
return "lh: " + std::to_string(lh) + " flux: " + std::to_string(flux) + " x: " + std::to_string(x) +
" y: " + std::to_string(y) + " vx: " + std::to_string(vx) + " vy: " + std::to_string(vy) +
" obs_count: " + std::to_string(obs_count) + " valid: " + std::to_string(valid);
" obs_count: " + std::to_string(obs_count);
}

// This is a hack to provide a constructor with non-default arguments in Python. If we include
// the constructor as a method in the Trajectory struct CUDA will complain when creating new objects
// because it cannot call out to a host function.
static Trajectory make_trajectory(int x, int y, float vx, float vy, float flux, float lh, int obs_count,
bool valid) {
static Trajectory make_trajectory(int x, int y, float vx, float vy, float flux, float lh, int obs_count) {
Trajectory trj;
trj.x = x;
trj.y = y;
Expand All @@ -95,7 +92,6 @@ struct Trajectory {
trj.flux = flux;
trj.lh = lh;
trj.obs_count = obs_count;
trj.valid = valid;
return trj;
}
};
Expand Down Expand Up @@ -199,16 +195,14 @@ static void trajectory_bindings(py::module &m) {

py::class_<tj>(m, "Trajectory", pydocs::DOC_Trajectory)
.def(py::init(&tj::make_trajectory), py::arg("x") = 0, py::arg("y") = 0, py::arg("vx") = 0.0f,
py::arg("vy") = 0.0f, py::arg("flux") = 0.0f, py::arg("lh") = 0.0f, py::arg("obs_count") = 0,
py::arg("valid") = true)
py::arg("vy") = 0.0f, py::arg("flux") = 0.0f, py::arg("lh") = 0.0f, py::arg("obs_count") = 0)
.def_readwrite("vx", &tj::vx)
.def_readwrite("vy", &tj::vy)
.def_readwrite("lh", &tj::lh)
.def_readwrite("flux", &tj::flux)
.def_readwrite("x", &tj::x)
.def_readwrite("y", &tj::y)
.def_readwrite("obs_count", &tj::obs_count)
.def_readwrite("valid", &tj::valid)
.def("get_x_pos", &tj::get_x_pos, py::arg("time"), py::arg("centered") = true,
pydocs::DOC_Trajectory_get_x_pos)
.def("get_y_pos", &tj::get_y_pos, py::arg("time"), py::arg("centered") = true,
Expand All @@ -220,13 +214,13 @@ static void trajectory_bindings(py::module &m) {
.def("__str__", &tj::to_string)
.def(py::pickle(
[](const tj &p) { // __getstate__
return py::make_tuple(p.vx, p.vy, p.lh, p.flux, p.x, p.y, p.obs_count, p.valid);
return py::make_tuple(p.vx, p.vy, p.lh, p.flux, p.x, p.y, p.obs_count);
},
[](py::tuple t) { // __setstate__
if (t.size() != 8) throw std::runtime_error("Invalid state!");
tj trj = {t[0].cast<float>(), t[1].cast<float>(), t[2].cast<float>(),
t[3].cast<float>(), t[4].cast<int>(), t[5].cast<int>(),
t[6].cast<int>(), t[7].cast<bool>()};
t[6].cast<int>()};
return trj;
}));
}
Expand Down
3 changes: 0 additions & 3 deletions src/kbmod/search/pydocs/common_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ static const auto DOC_Trajectory = R"doc(
flux : `float`
Flux (accumulated?)
obs_count : `int`
Number of observations trajectory was seen in.
valid : `bool`
Whether the trajectory is valid. Used for filtering.
)doc";

static const auto DOC_Trajectory_get_x_pos = R"doc(
Expand Down
44 changes: 0 additions & 44 deletions src/kbmod/search/pydocs/trajectory_list_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,50 +143,6 @@ static const auto DOC_TrajectoryList_sort_by_likelihood = R"doc(
Raises a ``RuntimeError`` if the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_sort_by_obs_count = R"doc(
Sort the data in order of decreasing obs_count. The data must reside on the CPU.
Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_filter_by_likelihood = R"doc(
Sort the data in order of decreasing likelihood and drop everything less than
a given threshold. The data must reside on the CPU.
Parameters
----------
min_likelihood : `float`
The threshold on minimum likelihood.
Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_filter_by_obs_count = R"doc(
Sort the data in order of decreasing obs_count and drop everything less than
a given threshold. The data must reside on the CPU.
Parameters
----------
min_obs_count : `int`
The threshold on minimum number of observations.
Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_filter_by_valid = R"doc(
Filter out all trajectories with the ``valid`` attribute set to ``False``.
Ordering is not preserved. The data must reside on the CPU.
Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

} // namespace pydocs

Expand Down
47 changes: 0 additions & 47 deletions src/kbmod/search/trajectory_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,47 +73,6 @@ void TrajectoryList::sort_by_likelihood() {
[](const Trajectory& a, const Trajectory& b) { return b.lh < a.lh; });
}

// Sort cpu_list in order of decreasing obs_count.
// Throws std::runtime_error when the data is flagged as being on the GPU.
void TrajectoryList::sort_by_obs_count() {
    if (data_on_gpu) {
        throw std::runtime_error("Data on GPU");
    }

    // Strict "greater than" on obs_count yields a descending ordering.
    const auto more_observations = [](const Trajectory& lhs, const Trajectory& rhs) {
        return lhs.obs_count > rhs.obs_count;
    };
    __gnu_parallel::sort(cpu_list.begin(), cpu_list.end(), more_observations);
}

// Keep only the trajectories whose likelihood is at least min_likelihood.
// The list is sorted with sort_by_likelihood() first, so the survivors form a
// prefix of cpu_list and everything from the first failing entry on is dropped.
void TrajectoryList::filter_by_likelihood(float min_likelihood) {
    sort_by_likelihood();

    // Count the leading entries that satisfy the threshold.
    uint64_t num_kept = 0;
    for (; num_kept < max_size; ++num_kept) {
        // Negated >= (rather than <) so that a NaN likelihood also ends the scan.
        if (!(cpu_list[num_kept].lh >= min_likelihood)) break;
    }

    // Truncate the list to the entries that passed.
    resize(num_kept);
}

// Keep only the trajectories with at least min_obs_count observations.
// The list is sorted with sort_by_obs_count() first, so the survivors form a
// prefix of cpu_list and everything from the first failing entry on is dropped.
void TrajectoryList::filter_by_obs_count(int min_obs_count) {
    sort_by_obs_count();

    // Count the leading entries that satisfy the threshold.
    uint64_t num_kept = 0;
    for (; num_kept < max_size; ++num_kept) {
        if (cpu_list[num_kept].obs_count < min_obs_count) break;
    }

    // Truncate the list to the entries that passed.
    resize(num_kept);
}

// Drop every trajectory whose `valid` flag is false. The relative order of the
// remaining trajectories is not preserved (std::partition is not stable).
// Throws std::runtime_error when the data is flagged as being on the GPU.
void TrajectoryList::filter_by_valid() {
    if (data_on_gpu) {
        throw std::runtime_error("Data on GPU");
    }

    // Move all valid trajectories to the front of the list.
    const auto is_valid = [](const Trajectory& trj) { return trj.valid; };
    const auto first_invalid = std::partition(cpu_list.begin(), cpu_list.end(), is_valid);

    // Shrink the list so that only the valid prefix remains.
    const uint64_t num_valid = std::distance(cpu_list.begin(), first_invalid);
    resize(num_valid);
}

void TrajectoryList::move_to_gpu() {
if (data_on_gpu) return; // Nothing to do.

Expand Down Expand Up @@ -160,12 +119,6 @@ static void trajectory_list_binding(py::module& m) {
.def("get_batch", &trjl::get_batch, pydocs::DOC_TrajectoryList_get_batch)
.def("sort_by_likelihood", &trjl::sort_by_likelihood,
pydocs::DOC_TrajectoryList_sort_by_likelihood)
.def("sort_by_obs_count", &trjl::sort_by_obs_count, pydocs::DOC_TrajectoryList_sort_by_obs_count)
.def("filter_by_likelihood", &trjl::filter_by_likelihood,
pydocs::DOC_TrajectoryList_filter_by_likelihood)
.def("filter_by_obs_count", &trjl::filter_by_obs_count,
pydocs::DOC_TrajectoryList_filter_by_obs_count)
.def("filter_by_valid", &trjl::filter_by_valid, pydocs::DOC_TrajectoryList_filter_by_valid)
.def("move_to_cpu", &trjl::move_to_cpu, pydocs::DOC_TrajectoryList_move_to_cpu)
.def("move_to_gpu", &trjl::move_to_gpu, pydocs::DOC_TrajectoryList_move_to_gpu);
}
Expand Down
6 changes: 1 addition & 5 deletions src/kbmod/search/trajectory_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ class TrajectoryList {
// Get a batch of results.
std::vector<Trajectory> get_batch(uint64_t start, uint64_t count);

// Processing functions for sorting or filtering.
// Processing functions for sorting.
void sort_by_likelihood();
void sort_by_obs_count();
void filter_by_likelihood(float min_likelihood);
void filter_by_obs_count(int min_obs_count);
void filter_by_valid();

// Data allocation functions.
inline bool on_gpu() const { return data_on_gpu; }
Expand Down
9 changes: 0 additions & 9 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ def trajectory_from_np_object(result):
trj.flux = float(result["flux"][0])
trj.lh = float(result["lh"][0])
trj.obs_count = int(result["num_obs"][0])
if "valid" in result.dtype.names:
trj.valid = bool(result["valid"][0])
else:
trj.valid = True
return trj


Expand All @@ -135,10 +131,6 @@ def trajectory_from_dict(trj_dict):
trj.flux = float(trj_dict["flux"])
trj.lh = float(trj_dict["lh"])
trj.obs_count = int(trj_dict["obs_count"])
if "valid" in trj_dict:
trj.valid = bool(trj_dict["valid"])
else:
trj.valid = True
return trj


Expand Down Expand Up @@ -181,6 +173,5 @@ def trajectory_to_yaml(trj):
"flux": trj.flux,
"lh": trj.lh,
"obs_count": trj.obs_count,
"valid": trj.valid,
}
return dump(yaml_dict)
4 changes: 0 additions & 4 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def test_trajectory_create(self):
self.assertEqual(trj1.flux, 0.0)
self.assertEqual(trj1.lh, 0.0)
self.assertEqual(trj1.obs_count, 0)
self.assertEqual(trj1.valid, True)

# All specified
trj2 = Trajectory(x=1, y=2, vx=3.0, vy=4.0, flux=5.0, lh=6.0, obs_count=7)
Expand All @@ -39,7 +38,6 @@ def test_trajectory_create(self):
self.assertEqual(trj2.flux, 5.0)
self.assertEqual(trj2.lh, 6.0)
self.assertEqual(trj2.obs_count, 7)
self.assertEqual(trj2.valid, True)

# Some specified, some defaults
trj3 = Trajectory(y=2, vx=3.0, vy=-4.0, obs_count=7)
Expand All @@ -50,7 +48,6 @@ def test_trajectory_create(self):
self.assertEqual(trj3.flux, 0.0)
self.assertEqual(trj3.lh, 0.0)
self.assertEqual(trj3.obs_count, 7)
self.assertEqual(trj3.valid, True)

# Four specified by order
trj4 = Trajectory(4, 3, 2.0, 1.0)
Expand All @@ -61,7 +58,6 @@ def test_trajectory_create(self):
self.assertEqual(trj4.flux, 0.0)
self.assertEqual(trj4.lh, 0.0)
self.assertEqual(trj4.obs_count, 0)
self.assertEqual(trj4.valid, True)

def test_trajectory_predict(self):
trj = Trajectory(x=5, y=10, vx=2.0, vy=-1.0)
Expand Down
23 changes: 4 additions & 19 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,6 @@ def test_from_trajectories(self):
self.assertEqual(len(table.colnames), 7)
self._assert_results_match_dict(table, self.input_dict)

# Test that we ignore invalid results, but track them in the filtered table.
self.trj_list[2].valid = False
self.trj_list[7].valid = False
table2 = Results.from_trajectories(self.trj_list, track_filtered=True)
self.assertEqual(len(table2), self.num_entries - 2)
for i in range(self.num_entries - 2):
self.assertFalse(table2["x"][i] == 2 or table2["x"][i] == 7)

filtered = table2.get_filtered()
self.assertEqual(len(filtered), 2)
self.assertEqual(filtered["x"][0], 2)
self.assertEqual(filtered["x"][1], 7)

def test_from_dict(self):
self.input_dict["something_added"] = [i for i in range(self.num_entries)]

Expand Down Expand Up @@ -507,12 +494,10 @@ def test_write_filter_stats(self):
data = FileUtils.load_csv_to_list(file_path)
self.assertEqual(data[0][0], "unfiltered")
self.assertEqual(data[0][1], "5")
self.assertEqual(data[1][0], "invalid_trajectory")
self.assertEqual(data[1][1], "0")
self.assertEqual(data[2][0], "filter1")
self.assertEqual(data[2][1], "2")
self.assertEqual(data[3][0], "filter2")
self.assertEqual(data[3][1], "3")
self.assertEqual(data[1][0], "filter1")
self.assertEqual(data[1][1], "2")
self.assertEqual(data[2][0], "filter2")
self.assertEqual(data[2][1], "3")

def test_mask_based_on_invalid_obs(self):
num_times = 5
Expand Down
Loading

0 comments on commit e2d523e

Please sign in to comment.