Skip to content

Commit

Permalink
Merge pull request #351 from dirac-institute/zeroed_time
Browse files Browse the repository at this point in the history
Revamp the interface for timestamps
  • Loading branch information
jeremykubica authored Sep 27, 2023
2 parents af3c35a + 4bbe76e commit 78ff3cb
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 84 deletions.
6 changes: 0 additions & 6 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,6 @@ def load_images(
print(f"Loaded {len(images)} images")
stack = kb.ImageStack(images)

# Create a list of visit times and visit times shifted to 0.0.
min_time = min(visit_times)
zero_shifted = [(t - min_time) for t in visit_times]
stack.set_times(zero_shifted)
print("Times set", flush=True)

return (stack, wcs_list, visit_times)


Expand Down
53 changes: 22 additions & 31 deletions src/kbmod/search/image_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@ namespace py = pybind11;
namespace search {
ImageStack::ImageStack(const std::vector<std::string>& filenames, const std::vector<PSF>& psfs) {
verbose = true;
reset_images();
images = std::vector<LayeredImage>();
load_images(filenames, psfs);
extract_image_times();
set_time_origin();

global_mask = RawImage(get_width(), get_height());
global_mask.set_all_pix(0.0);
}

ImageStack::ImageStack(const std::vector<LayeredImage>& imgs) {
verbose = true;
images = imgs;
extract_image_times();
set_time_origin();

global_mask = RawImage(get_width(), get_height());
global_mask.set_all_pix(0.0);
}
Expand All @@ -42,39 +40,32 @@ namespace search {
if (verbose) std::cout << "\n";
}

void ImageStack::extract_image_times() {
// Load image times
image_times = std::vector<float>();
for (auto& i : images) {
image_times.push_back(float(i.get_obstime()));
}
}

void ImageStack::set_time_origin() {
// Set beginning time to 0.0
double initial_time = image_times[0];
for (auto& t : image_times) t = t - initial_time;
}

LayeredImage& ImageStack::get_single_image(int index) {
if (index < 0 || index > images.size()) throw std::out_of_range("ImageStack index out of bounds.");
return images[index];
}

void ImageStack::set_single_image(int index, LayeredImage& img) {
float ImageStack::get_obstime(int index) const {
if (index < 0 || index > images.size()) throw std::out_of_range("ImageStack index out of bounds.");
images[index] = img;
return images[index].get_obstime();
}

void ImageStack::set_times(const std::vector<float>& times) {
if (times.size() != img_count())
throw std::runtime_error("List of times provided does not match the number of images!");
image_times = times;
set_time_origin();
float ImageStack::get_zeroed_time(int index) const {
if (index < 0 || index > images.size()) throw std::out_of_range("ImageStack index out of bounds.");
return images[index].get_obstime() - images[0].get_obstime();
}

void ImageStack::reset_images() { images = std::vector<LayeredImage>(); }

std::vector<float> ImageStack::build_zeroed_times() const {
std::vector<float> zeroed_times = std::vector<float>();
if (images.size() > 0) {
float t0 = images[0].get_obstime();
for (auto& i : images) {
zeroed_times.push_back(i.get_obstime() - t0);
}
}
return zeroed_times;
}

void ImageStack::convolve_psf() {
for (auto& i : images) i.convolve_psf();
}
Expand Down Expand Up @@ -141,9 +132,9 @@ namespace search {
.def("get_single_image", &is::get_single_image,
py::return_value_policy::reference_internal,
pydocs::DOC_ImageStack_get_single_image)
.def("set_single_image", &is::set_single_image, pydocs::DOC_ImageStack_set_single_image)
.def("get_times", &is::get_times, pydocs::DOC_ImageStack_get_times)
.def("set_times", &is::set_times, pydocs::DOC_ImageStack_set_times )
.def("get_obstime", &is::get_obstime, pydocs::DOC_ImageStack_get_obstime)
.def("get_zeroed_time", &is::get_zeroed_time, pydocs::DOC_ImageStack_get_zeroed_time)
.def("build_zeroed_times", &is::build_zeroed_times, pydocs::DOC_ImageStack_build_zeroed_times)
.def("img_count", &is::img_count, pydocs::DOC_ImageStack_img_count)
.def("apply_mask_flags", &is::apply_mask_flags, pydocs::DOC_ImageStack_apply_mask_flags)
.def("apply_mask_threshold", &is::apply_mask_threshold, pydocs::DOC_ImageStack_apply_mask_threshold)
Expand Down
13 changes: 4 additions & 9 deletions src/kbmod/search/image_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ namespace search {
unsigned get_height() const { return images.size() > 0 ? images[0].get_height() : 0; }
unsigned get_npixels() const { return images.size() > 0 ? images[0].get_npixels() : 0; }
std::vector<LayeredImage>& get_images() { return images; }
const std::vector<float>& get_times() const { return image_times; }
float* get_timesDataRef() { return image_times.data(); }
LayeredImage& get_single_image(int index);

// Simple setters.
void set_times(const std::vector<float>& times);
void reset_images();
void set_single_image(int index, LayeredImage& img);
// Functions for getting times.
float get_obstime(int index) const;
float get_zeroed_time(int index) const;
std::vector<float> build_zeroed_times() const; // Linear cost.

// Apply makes to all the images.
void apply_global_mask(int flags, int threshold);
Expand All @@ -49,12 +47,9 @@ namespace search {

private:
void load_images(const std::vector<std::string>& filenames, const std::vector<PSF>& psfs);
void extract_image_times();
void set_time_origin();
void create_global_mask(int flags, int threshold);
std::vector<LayeredImage> images;
RawImage global_mask;
std::vector<float> image_times;
bool verbose;
};

Expand Down
26 changes: 14 additions & 12 deletions src/kbmod/search/pydocs/image_stack_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,32 @@ namespace pydocs {
static const auto DOC_ImageStack = R"doc(
todo
)doc";

static const auto DOC_ImageStack_get_images = R"doc(
todo
)doc";

static const auto DOC_ImageStack_get_single_image = R"doc(
todo
static const auto DOC_ImageStack_img_count = R"doc(
Returns the number of images in the stack.
)doc";

static const auto DOC_ImageStack_set_single_image = R"doc(
todo
static const auto DOC_ImageStack_get_single_image = R"doc(
Returns a single LayeredImage for a given index.
)doc";

static const auto DOC_ImageStack_get_times = R"doc(
todo
static const auto DOC_ImageStack_get_obstime = R"doc(
Returns a single image's observation time in MJD.
)doc";

static const auto DOC_ImageStack_set_times = R"doc(
todo
static const auto DOC_ImageStack_get_zeroed_time = R"doc(
Returns a single image's observation time relative to that
of the first image.
)doc";

static const auto DOC_ImageStack_img_count = R"doc(
todo
)doc";
static const auto DOC_ImageStack_build_zeroed_times = R"doc(
Construct an array of time differentials between each image
in the stack and the first image.
")doc";

static const auto DOC_ImageStack_apply_mask_flags = R"doc(
todo
Expand Down
10 changes: 6 additions & 4 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ namespace search {
end_timer();

// Create a data stucture for the per-image data.
std::vector<float> image_times = stack.build_zeroed_times();
PerImageData img_data;
img_data.num_images = stack.img_count();
img_data.image_times = stack.get_timesDataRef();
img_data.image_times = image_times.data();
if (params.use_corr) img_data.bary_corrs = &bary_corrs[0];

// Compute the encoding parameters for psi and phi if needed.
Expand Down Expand Up @@ -428,9 +429,10 @@ namespace search {
const int height = stack.get_height();

// Create a data stucture for the per-image data.
std::vector<float> image_times = stack.build_zeroed_times();
PerImageData img_data;
img_data.num_images = num_images;
img_data.image_times = stack.get_timesDataRef();
img_data.image_times = image_times.data();

// Allocate space for the results.
const int num_trajectories = t_array.size();
Expand Down Expand Up @@ -479,7 +481,7 @@ namespace search {
}

PixelPos StackSearch::get_trajectory_position(const Trajectory& t, int i) const {
float time = stack.get_times()[i];
float time = stack.get_zeroed_time(i);
if (use_corr) {
return {t.x + time * t.vx + bary_corrs[i].dx + t.x * bary_corrs[i].dxdx + t.y * bary_corrs[i].dxdy,
t.y + time * t.vy + bary_corrs[i].dy + t.x * bary_corrs[i].dydx +
Expand Down Expand Up @@ -513,7 +515,7 @@ namespace search {
int img_size = imgs.size();
std::vector<float> lightcurve;
lightcurve.reserve(img_size);
const std::vector<float>& times = stack.get_times();
std::vector<float> times = stack.build_zeroed_times();
for (int i = 0; i < img_size; ++i) {
/* Do not use get_pixel_interp(), because results from create_curves must
* be able to recover the same likelihoods as the ones reported by the
Expand Down
6 changes: 3 additions & 3 deletions tests/test_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ def test_apply_stamp_filter(self):
int(self.img_count / 2),
)

mjds = np.array(stack.get_times())
kb_post_process = PostProcess(self.config, mjds)
zeroed_times = np.array(stack.build_zeroed_times())
kb_post_process = PostProcess(self.config, zeroed_times)

keep = kb_post_process.load_and_filter_results(
search,
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_clustering(self):
cluster_params["y_size"] = self.dim_y
cluster_params["vel_lims"] = [self.min_vel, self.max_vel]
cluster_params["ang_lims"] = [self.min_angle, self.max_angle]
cluster_params["mjd"] = np.array(self.stack.get_times())
cluster_params["mjd"] = np.array(self.stack.build_zeroed_times())

trjs = [
self._make_trajectory(10, 11, 1, 2, 100.0),
Expand Down
19 changes: 8 additions & 11 deletions tests/test_image_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self):
60, # dim_y = 60 pixels,
2.0, # noise_level
4.0, # variance
2.0 * i, # time
2.0 * i + 1.0, # time
self.p[i],
)

Expand All @@ -38,27 +38,24 @@ def test_create(self):
def test_access(self):
# Test we can access an individual image.
img = self.im_stack.get_single_image(1)
self.assertEqual(img.get_obstime(), 2.0)
self.assertEqual(img.get_obstime(), 3.0)
self.assertEqual(img.get_name(), "layered_test_1")

# Test an out of bounds access.
with self.assertRaises(IndexError):
img = self.im_stack.get_single_image(self.num_images + 1)

def test_times(self):
times = self.im_stack.get_times()
# Check that we can access specific times.
self.assertEqual(self.im_stack.get_obstime(1), 3.0)
self.assertEqual(self.im_stack.get_zeroed_time(1), 2.0)

# Check that we can build the full zeroed times list.
times = self.im_stack.build_zeroed_times()
self.assertEqual(len(times), self.num_images)
for i in range(self.num_images):
self.assertEqual(times[i], 2.0 * i)

new_times = [3.0 * i for i in range(self.num_images)]
self.im_stack.set_times(new_times)

times2 = self.im_stack.get_times()
self.assertEqual(len(times2), self.num_images)
for i in range(self.num_images):
self.assertEqual(times2[i], 3.0 * i)

def test_apply_mask(self):
# Nothing is initially masked.
for i in range(self.num_images):
Expand Down
16 changes: 8 additions & 8 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_sci_viz_stamps(self):
sci_stamps = self.search.get_stamps(self.trj, 2)
self.assertEqual(len(sci_stamps), self.imCount)

times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for i in range(self.imCount):
self.assertEqual(sci_stamps[i].get_width(), 5)
self.assertEqual(sci_stamps[i].get_height(), 5)
Expand All @@ -278,7 +278,7 @@ def test_stacked_sci(self):
self.assertEqual(sci.get_height(), 5)

# Compute the true stacked pixel for the middle of the track.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
sum_middle = 0.0
for i in range(self.imCount):
t = times[i]
Expand Down Expand Up @@ -309,7 +309,7 @@ def test_median_stamps_trj(self):
self.assertEqual(medianStamps1.get_height(), 5)

# Compute the true median pixel for the middle of the track.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
pix_values0 = []
pix_values1 = []
for i in range(self.imCount):
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_mean_stamps_trj(self):
self.assertEqual(meanStamp1.get_height(), 5)

# Compute the true median pixel for the middle of the track.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
pix_sum0 = 0.0
pix_sum1 = 0.0
pix_count0 = 0.0
Expand Down Expand Up @@ -605,7 +605,7 @@ def test_coadd_cpu(self):
self.assertEqual(medianStamps[0].get_height(), 2 * params.radius + 1)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down Expand Up @@ -656,7 +656,7 @@ def test_coadd_gpu(self):
self.assertEqual(medianStamps[0].get_height(), 2 * params.radius + 1)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down Expand Up @@ -702,7 +702,7 @@ def test_coadd_cpu_use_inds(self):
meanStamps = self.search.get_coadded_stamps([self.trj, self.trj], inds, params, False)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down Expand Up @@ -751,7 +751,7 @@ def test_coadd_gpu_use_inds(self):
meanStamps = self.search.get_coadded_stamps([self.trj, self.trj], inds, params, True)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down

0 comments on commit 78ff3cb

Please sign in to comment.