Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the (unused) valid bit from trajectories. #707

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/kbmod/filters/sigma_g_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.special import erfinv

from kbmod.results import Results
from kbmod.search import DebugTimer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -177,7 +178,9 @@ def apply_clipped_sigma_g(clipper, result_data):
logger.info("SigmaG Clipping : skipping, nothing to filter.")
return

filter_timer = DebugTimer("sigma-g filtering", logger)
lh = result_data.compute_likelihood_curves(filter_obs=True, mask_value=np.nan)
obs_valid = clipper.compute_clipped_sigma_g_matrix(lh)
result_data.update_obs_valid(obs_valid)
filter_timer.stop()
return
4 changes: 4 additions & 0 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,16 @@ def append_coadds(result_data, im_stack, coadd_types, radius, chunk_size=100_000
"""
if radius <= 0:
raise ValueError(f"Invalid stamp radius {radius}")
stamp_timer = DebugTimer("computing extra coadds", logger)

params = StampParameters()
params.radius = radius
params.do_filtering = False

# Loop through all the coadd types in the list, generating a corresponding stamp.
for coadd_type in coadd_types:
logger.info(f"Adding coadd={coadd_type} for all results.")

if coadd_type == "median":
params.stamp_type = StampType.STAMP_MEDIAN
elif coadd_type == "mean":
Expand All @@ -212,6 +215,7 @@ def append_coadds(result_data, im_stack, coadd_types, radius, chunk_size=100_000
chunk_size=chunk_size,
colname=f"coadd_{coadd_type}",
)
stamp_timer.stop()


def append_all_stamps(result_data, im_stack, stamp_radius):
Expand Down
8 changes: 2 additions & 6 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def from_trajectories(cls, trajectories, track_filtered=False):
input_d = {}
for col in cls.required_cols:
input_d[col[0]] = []
valid_mask = []

# Add the valid trajectories to the table.
# Add the trajectories to the table.
for trj in trajectories:
input_d["x"].append(trj.x)
input_d["y"].append(trj.y)
Expand All @@ -150,17 +149,14 @@ def from_trajectories(cls, trajectories, track_filtered=False):
input_d["likelihood"].append(trj.lh)
input_d["flux"].append(trj.flux)
input_d["obs_count"].append(trj.obs_count)
valid_mask.append(trj.valid)

# Check for any missing columns and fill in the default value.
for col in cls.required_cols:
if col[0] not in input_d:
input_d[col[0]] = [col[2]] * num_valid
invalid_d[col[0]] = [col[2]] * num_invalid
input_d[col[0]] = [col[2]] * len(trajectories)

# Create the table and add the unfiltered (and filtered) results.
results = Results(input_d, track_filtered=track_filtered)
results.filter_rows(np.array(valid_mask, dtype=bool), "invalid_trajectory")
return results

@classmethod
Expand Down
16 changes: 5 additions & 11 deletions src/kbmod/search/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ struct Trajectory {
int y = 0;
// Number of images summed
int obs_count;
// Whether the trajectory is valid. Used for on-GPU filtering.
bool valid = true;

// Get pixel positions from a zero-shifted time. Centered indicates whether
// the prediction starts from the center of the pixel (which it does in the search)
Expand All @@ -79,14 +77,13 @@ struct Trajectory {
const std::string to_string() const {
return "lh: " + std::to_string(lh) + " flux: " + std::to_string(flux) + " x: " + std::to_string(x) +
" y: " + std::to_string(y) + " vx: " + std::to_string(vx) + " vy: " + std::to_string(vy) +
" obs_count: " + std::to_string(obs_count) + " valid: " + std::to_string(valid);
" obs_count: " + std::to_string(obs_count);
}

// This is a hack to provide a constructor with non-default arguments in Python. If we include
// the constructor as a method in the Trajectory struct CUDA will complain when creating new objects
// because it cannot call out to a host function.
static Trajectory make_trajectory(int x, int y, float vx, float vy, float flux, float lh, int obs_count,
bool valid) {
static Trajectory make_trajectory(int x, int y, float vx, float vy, float flux, float lh, int obs_count) {
Trajectory trj;
trj.x = x;
trj.y = y;
Expand All @@ -95,7 +92,6 @@ struct Trajectory {
trj.flux = flux;
trj.lh = lh;
trj.obs_count = obs_count;
trj.valid = valid;
return trj;
}
};
Expand Down Expand Up @@ -199,16 +195,14 @@ static void trajectory_bindings(py::module &m) {

py::class_<tj>(m, "Trajectory", pydocs::DOC_Trajectory)
.def(py::init(&tj::make_trajectory), py::arg("x") = 0, py::arg("y") = 0, py::arg("vx") = 0.0f,
py::arg("vy") = 0.0f, py::arg("flux") = 0.0f, py::arg("lh") = 0.0f, py::arg("obs_count") = 0,
py::arg("valid") = true)
py::arg("vy") = 0.0f, py::arg("flux") = 0.0f, py::arg("lh") = 0.0f, py::arg("obs_count") = 0)
.def_readwrite("vx", &tj::vx)
.def_readwrite("vy", &tj::vy)
.def_readwrite("lh", &tj::lh)
.def_readwrite("flux", &tj::flux)
.def_readwrite("x", &tj::x)
.def_readwrite("y", &tj::y)
.def_readwrite("obs_count", &tj::obs_count)
.def_readwrite("valid", &tj::valid)
.def("get_x_pos", &tj::get_x_pos, py::arg("time"), py::arg("centered") = true,
pydocs::DOC_Trajectory_get_x_pos)
.def("get_y_pos", &tj::get_y_pos, py::arg("time"), py::arg("centered") = true,
Expand All @@ -220,13 +214,13 @@ static void trajectory_bindings(py::module &m) {
.def("__str__", &tj::to_string)
.def(py::pickle(
[](const tj &p) { // __getstate__
return py::make_tuple(p.vx, p.vy, p.lh, p.flux, p.x, p.y, p.obs_count, p.valid);
return py::make_tuple(p.vx, p.vy, p.lh, p.flux, p.x, p.y, p.obs_count);
},
[](py::tuple t) { // __setstate__
if (t.size() != 8) throw std::runtime_error("Invalid state!");
tj trj = {t[0].cast<float>(), t[1].cast<float>(), t[2].cast<float>(),
t[3].cast<float>(), t[4].cast<int>(), t[5].cast<int>(),
t[6].cast<int>(), t[7].cast<bool>()};
t[6].cast<int>()};
return trj;
}));
}
Expand Down
3 changes: 0 additions & 3 deletions src/kbmod/search/pydocs/common_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ static const auto DOC_Trajectory = R"doc(
flux : `float`
Flux (accumulated?)
obs_count : `int`
Number of observations trajectory was seen in.
valid : `bool`
Whether the trajectory is valid. Used for filtering.
)doc";

static const auto DOC_Trajectory_get_x_pos = R"doc(
Expand Down
44 changes: 0 additions & 44 deletions src/kbmod/search/pydocs/trajectory_list_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,50 +143,6 @@ static const auto DOC_TrajectoryList_sort_by_likelihood = R"doc(
Raises a ``RuntimeError`` the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_sort_by_obs_count = R"doc(
Sort the data in order of decreasing obs_count. The data must reside on the CPU.

Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_filter_by_likelihood = R"doc(
Sort the data in order of decreasing likelihood and drop everything less than
a given threshold. The data must reside on the CPU.

Parameters
----------
min_likelihood : `float`
The threshold on minimum likelihood.

Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_filter_by_obs_count = R"doc(
Sort the data in order of decreasing obs_count and drop everything less than
a given threshold. The data must reside on the CPU.

Parameters
----------
min_obs_count : `int`
The threshold on minimum number of observations.

Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

static const auto DOC_TrajectoryList_filter_by_valid = R"doc(
Filter out all trajectories with the ``valid`` attribute set to ``False``.
Ordering is not preserved. The data must reside on the CPU.

Raises
------
Raises a ``RuntimeError`` the data is on GPU.
)doc";

} // namespace pydocs

Expand Down
47 changes: 0 additions & 47 deletions src/kbmod/search/trajectory_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,47 +73,6 @@ void TrajectoryList::sort_by_likelihood() {
[](const Trajectory& a, const Trajectory& b) { return b.lh < a.lh; });
}

void TrajectoryList::sort_by_obs_count() {
if (data_on_gpu) throw std::runtime_error("Data on GPU");
__gnu_parallel::sort(cpu_list.begin(), cpu_list.end(),
[](const Trajectory& a, const Trajectory& b) { return b.obs_count < a.obs_count; });
}

void TrajectoryList::filter_by_likelihood(float min_likelihood) {
sort_by_likelihood();

// Find the first index that does not meet the threshold.
uint64_t index = 0;
while ((index < max_size) && (cpu_list[index].lh >= min_likelihood)) {
++index;
}

// Drop the values below the threshold.
resize(index);
}

void TrajectoryList::filter_by_obs_count(int min_obs_count) {
sort_by_obs_count();

// Find the first index that does not meet the threshold.
uint64_t index = 0;
while ((index < max_size) && (cpu_list[index].obs_count >= min_obs_count)) {
++index;
}

// Drop the values below the threshold.
resize(index);
}

void TrajectoryList::filter_by_valid() {
if (data_on_gpu) throw std::runtime_error("Data on GPU");

auto new_end =
std::partition(cpu_list.begin(), cpu_list.end(), [](const Trajectory& x) { return x.valid; });
uint64_t new_size = std::distance(cpu_list.begin(), new_end);
resize(new_size);
}

void TrajectoryList::move_to_gpu() {
if (data_on_gpu) return; // Nothing to do.

Expand Down Expand Up @@ -160,12 +119,6 @@ static void trajectory_list_binding(py::module& m) {
.def("get_batch", &trjl::get_batch, pydocs::DOC_TrajectoryList_get_batch)
.def("sort_by_likelihood", &trjl::sort_by_likelihood,
pydocs::DOC_TrajectoryList_sort_by_likelihood)
.def("sort_by_obs_count", &trjl::sort_by_obs_count, pydocs::DOC_TrajectoryList_sort_by_obs_count)
.def("filter_by_likelihood", &trjl::filter_by_likelihood,
pydocs::DOC_TrajectoryList_filter_by_likelihood)
.def("filter_by_obs_count", &trjl::filter_by_obs_count,
pydocs::DOC_TrajectoryList_filter_by_obs_count)
.def("filter_by_valid", &trjl::filter_by_valid, pydocs::DOC_TrajectoryList_filter_by_valid)
.def("move_to_cpu", &trjl::move_to_cpu, pydocs::DOC_TrajectoryList_move_to_cpu)
.def("move_to_gpu", &trjl::move_to_gpu, pydocs::DOC_TrajectoryList_move_to_gpu);
}
Expand Down
6 changes: 1 addition & 5 deletions src/kbmod/search/trajectory_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ class TrajectoryList {
// Get a batch of results.
std::vector<Trajectory> get_batch(uint64_t start, uint64_t count);

// Processing functions for sorting or filtering.
// Processing functions for sorting.
void sort_by_likelihood();
void sort_by_obs_count();
void filter_by_likelihood(float min_likelihood);
void filter_by_obs_count(int min_obs_count);
void filter_by_valid();

// Data allocation functions.
inline bool on_gpu() const { return data_on_gpu; }
Expand Down
9 changes: 0 additions & 9 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ def trajectory_from_np_object(result):
trj.flux = float(result["flux"][0])
trj.lh = float(result["lh"][0])
trj.obs_count = int(result["num_obs"][0])
if "valid" in result.dtype.names:
trj.valid = bool(result["valid"][0])
else:
trj.valid = True
return trj


Expand All @@ -135,10 +131,6 @@ def trajectory_from_dict(trj_dict):
trj.flux = float(trj_dict["flux"])
trj.lh = float(trj_dict["lh"])
trj.obs_count = int(trj_dict["obs_count"])
if "valid" in trj_dict:
trj.valid = bool(trj_dict["valid"])
else:
trj.valid = True
return trj


Expand Down Expand Up @@ -181,6 +173,5 @@ def trajectory_to_yaml(trj):
"flux": trj.flux,
"lh": trj.lh,
"obs_count": trj.obs_count,
"valid": trj.valid,
}
return dump(yaml_dict)
4 changes: 0 additions & 4 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def test_trajectory_create(self):
self.assertEqual(trj1.flux, 0.0)
self.assertEqual(trj1.lh, 0.0)
self.assertEqual(trj1.obs_count, 0)
self.assertEqual(trj1.valid, True)

# All specified
trj2 = Trajectory(x=1, y=2, vx=3.0, vy=4.0, flux=5.0, lh=6.0, obs_count=7)
Expand All @@ -39,7 +38,6 @@ def test_trajectory_create(self):
self.assertEqual(trj2.flux, 5.0)
self.assertEqual(trj2.lh, 6.0)
self.assertEqual(trj2.obs_count, 7)
self.assertEqual(trj2.valid, True)

# Some specified, some defaults
trj3 = Trajectory(y=2, vx=3.0, vy=-4.0, obs_count=7)
Expand All @@ -50,7 +48,6 @@ def test_trajectory_create(self):
self.assertEqual(trj3.flux, 0.0)
self.assertEqual(trj3.lh, 0.0)
self.assertEqual(trj3.obs_count, 7)
self.assertEqual(trj3.valid, True)

# Four specified by order
trj4 = Trajectory(4, 3, 2.0, 1.0)
Expand All @@ -61,7 +58,6 @@ def test_trajectory_create(self):
self.assertEqual(trj4.flux, 0.0)
self.assertEqual(trj4.lh, 0.0)
self.assertEqual(trj4.obs_count, 0)
self.assertEqual(trj4.valid, True)

def test_trajectory_predict(self):
trj = Trajectory(x=5, y=10, vx=2.0, vy=-1.0)
Expand Down
23 changes: 4 additions & 19 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,6 @@ def test_from_trajectories(self):
self.assertEqual(len(table.colnames), 7)
self._assert_results_match_dict(table, self.input_dict)

# Test that we ignore invalid results, but track them in the filtered table.
self.trj_list[2].valid = False
self.trj_list[7].valid = False
table2 = Results.from_trajectories(self.trj_list, track_filtered=True)
self.assertEqual(len(table2), self.num_entries - 2)
for i in range(self.num_entries - 2):
self.assertFalse(table2["x"][i] == 2 or table2["x"][i] == 7)

filtered = table2.get_filtered()
self.assertEqual(len(filtered), 2)
self.assertEqual(filtered["x"][0], 2)
self.assertEqual(filtered["x"][1], 7)

def test_from_dict(self):
self.input_dict["something_added"] = [i for i in range(self.num_entries)]

Expand Down Expand Up @@ -507,12 +494,10 @@ def test_write_filter_stats(self):
data = FileUtils.load_csv_to_list(file_path)
self.assertEqual(data[0][0], "unfiltered")
self.assertEqual(data[0][1], "5")
self.assertEqual(data[1][0], "invalid_trajectory")
self.assertEqual(data[1][1], "0")
self.assertEqual(data[2][0], "filter1")
self.assertEqual(data[2][1], "2")
self.assertEqual(data[3][0], "filter2")
self.assertEqual(data[3][1], "3")
self.assertEqual(data[1][0], "filter1")
self.assertEqual(data[1][1], "2")
self.assertEqual(data[2][0], "filter2")
self.assertEqual(data[2][1], "3")

def test_mask_based_on_invalid_obs(self):
num_times = 5
Expand Down
Loading
Loading