Skip to content

Commit

Permalink
Merge pull request #364 from dirac-institute/eigen/raw_image
Browse files Browse the repository at this point in the history
Replace current array backend in RawImage with Eigen Matrix.
  • Loading branch information
DinoBektesevic authored Nov 2, 2023
2 parents 0bce446 + c443edb commit fc69675
Show file tree
Hide file tree
Showing 44 changed files with 2,852 additions and 1,755 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ docs/source/examples/_notebooks
# emacs work in progress files (lock and autosave)
.#*
\#*#
# Exclude files starting with exclude (local test and development)
exclude*
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "lib/pybind11"]
path = lib/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "include/eigen"]
path = include/eigen
url = https://gitlab.com/libeigen/eigen.git
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ find_library(CFITSIO_LIBRARY

add_subdirectory(lib/pybind11)

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)

include_directories(
include/
include/eigen
)


Expand Down
1 change: 1 addition & 0 deletions include/eigen
Submodule eigen added at 6d829e
20 changes: 17 additions & 3 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,14 @@ def get_all_stamps(self, result_list, search, stamp_radius):
stamp_edge = stamp_radius * 2 + 1
for row in result_list.results:
stamps = kb.StampCreator.get_stamps(search.get_imagestack(), row.trajectory, stamp_radius)
row.all_stamps = np.array([np.array(stamp).reshape(stamp_edge, stamp_edge) for stamp in stamps])
# TODO: a way to avoid a copy here would be to do
# np.array([s.image for s in stamps], dtype=np.single, copy=False)
# but that could cause a problem with reference counting at the m
# moment. The real fix is to make the stamps return Image not
# RawImage, return the Image and avoid a reference to a private
# attribute. This risks collecting RawImage but leaving a dangling
# ref to its private field. That's a fix for another time.
row.all_stamps = np.array([stamp.image for stamp in stamps])

def apply_clipped_sigmaG(self, result_list):
"""This function applies a clipped median filter to the results of a KBMOD
Expand Down Expand Up @@ -324,9 +331,16 @@ def apply_stamp_filter(
params,
kb.HAS_GPU and len(trj_slice) > 100,
)
# TODO: a way to avoid a copy here would be to do
# np.array([s.image for s in stamps], dtype=np.single, copy=False)
# but that could cause a problem with reference counting at the m
# moment. The real fix is to make the stamps return Image not
# RawImage and avoid reference to an private attribute and risking
# collecting RawImage but leaving a dangling ref to the attribute.
# That's a fix for another time so I'm leaving it as a copy here
for ind, stamp in enumerate(stamps_slice):
if stamp.get_width() > 1:
result_list.results[ind + start_idx].stamp = np.array(stamp)
if stamp.width > 1:
result_list.results[ind + start_idx].stamp = np.array(stamp.image)
all_valid_inds.append(ind + start_idx)

# Move to the next chunk.
Expand Down
5 changes: 2 additions & 3 deletions src/kbmod/fake_data_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ def add_fake_object(img, x, y, flux, psf=None):
sci = img

if psf is None:
sci.add_pixel_interp(x, y, flux)
sci.interpolated_add(x, y, flux)
else:
dim = psf.get_dim()
initial_x = x - psf.get_radius()
initial_y = y - psf.get_radius()

for i in range(dim):
for j in range(dim):
sci.add_pixel_interp(initial_x + i, initial_y + j, flux * psf.get_value(i, j))
sci.interpolated_add(float(initial_x + i), float(initial_y + j), flux * psf.get_value(i, j))


class FakeDataSet:
Expand Down
4 changes: 2 additions & 2 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def keep_row(self, row: ResultRow):
stamp = row.stamp.reshape([self.width, self.width])
peak_pos = RawImage(stamp).find_peak(True)
return (
abs(peak_pos.x - self.stamp_radius) < self.x_thresh
and abs(peak_pos.y - self.stamp_radius) < self.y_thresh
abs(peak_pos.i - self.stamp_radius) < self.x_thresh
and abs(peak_pos.j - self.stamp_radius) < self.y_thresh
)


Expand Down
17 changes: 12 additions & 5 deletions src/kbmod/search/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include <pybind11/numpy.h> // still required for PSF.h
#include <pybind11/eigen.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

#include "common.h"
#include "geom.h"

#include "psf.cpp"
#include "raw_image.cpp"
Expand All @@ -9,10 +16,6 @@
#include "stack_search.cpp"
#include "stamp_creator.cpp"
#include "filtering.cpp"
#include "common.h"

using pp = search::PixelPos;
using std::to_string;

PYBIND11_MODULE(search, m) {
m.attr("KB_NO_DATA") = pybind11::float_(search::NO_DATA);
Expand All @@ -22,6 +25,10 @@ PYBIND11_MODULE(search, m) {
.value("STAMP_MEAN", search::StampType::STAMP_MEAN)
.value("STAMP_MEDIAN", search::StampType::STAMP_MEDIAN)
.export_values();
indexing::index_bindings(m);
indexing::point_bindings(m);
indexing::rectangle_bindings(m);
indexing::geom_functions(m);
search::psf_bindings(m);
search::raw_image_bindings(m);
search::layered_image_bindings(m);
Expand Down
7 changes: 5 additions & 2 deletions src/kbmod/search/common.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
#ifndef COMMON_H_
#define COMMON_H_

#include <assert.h>
#include <string>

#include "pydocs/common_docs.h"

// assert(condition, message if !condition)
#define assertm(exp, msg) assert(((void)msg, exp))

namespace search {
#ifdef HAVE_CUDA
constexpr bool HAVE_GPU = true;
Expand Down Expand Up @@ -143,8 +148,6 @@ struct ImageMoments {
};

#ifdef Py_PYTHON_H
namespace py = pybind11;

static void trajectory_bindings(py::module &m) {
using tj = Trajectory;

Expand Down
Loading

0 comments on commit fc69675

Please sign in to comment.