diff --git a/notebooks/region_search/discrete_piles_e2e.ipynb b/notebooks/region_search/discrete_piles_e2e.ipynb new file mode 100644 index 000000000..fdc1a0acd --- /dev/null +++ b/notebooks/region_search/discrete_piles_e2e.ipynb @@ -0,0 +1,700 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "85168431", + "metadata": {}, + "source": [ + "# Environment\n", + "\n", + "This demo was presented on [baldur.astro.washington.edu/jupyter](https://baldur.astro.washington.edu/jupyter) with the shared Jupyter kernel `kbmod/w_2023_38`.\n", + "\n", + "It assumes that the user has read access to the test data at `/epyc/projects/kbmod/data` on epyc and that the notebook is executed on baldur.\n", + "\n", + "This notebook is currently stored for shared access in `/epyc/projects/kbmod/jupyter/notebooks/e2e`.\n", + "\n", + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b5bc851-cffd-4198-ab51-3ee20fadfa07", + "metadata": {}, + "outputs": [], + "source": [ + "import kbmod\n", + "\n", + "from kbmod.region_search import RegionSearch" + ] + }, + { + "cell_type": "markdown", + "id": "665d26ce-fed1-4d35-be2a-60bb345d8e25", + "metadata": {}, + "source": [ + "# Inspect the butler repo's contents\n", + "While you can inspect the butler repo in a fairly straightforward manner, the `RegionSearch` module provides some static methods that can help you pick which collections and dataset types to query from the butler. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf089f83-dd2e-44ad-8f58-345b371b91b8", + "metadata": {}, + "outputs": [], + "source": [ + "REPO_PATH = \"/epyc/projects/kbmod/data/imdiff_w09_gaiadr3\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3c520cd-c9e4-4398-b2ce-f266c173019c", + "metadata": {}, + "outputs": [], + "source": [ + "RegionSearch.get_collection_names(repo_path=REPO_PATH)" + ] + }, + { + "cell_type": "markdown", + "id": "f0b815b4-30ec-4647-a0e3-4c1897e87783", + "metadata": {}, + "source": [ + "For this example we want one of the collections with fakes, so we'll use 'DECam/withFakes/20210318'.\n", + "\n", + "We now want to inspect how many datarefs are associated with each dataset type we can query from this collection." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f85b1c92-5a46-48dc-a138-3a77793c0ed4", + "metadata": {}, + "outputs": [], + "source": [ + "collections = [\"DECam/withFakes/20210318\"]\n", + "\n", + "RegionSearch.get_dataset_type_freq(repo_path=REPO_PATH, collections=collections)" + ] + }, + { + "cell_type": "markdown", + "id": "ae3dad04-a30a-4d7e-908d-82fdc16c1f9c", + "metadata": {}, + "source": [ + "# Fetch Data from the Butler for Region Search\n", + "\n", + "From the above, 'fakes_calexp' seems a reasonable dataset type to limit our queries to.\n", + "\n", + "In the following, we construct a `RegionSearch` object, which instantiates a butler for our repo and fetches the image data keyed by (Visit, Detector, Region) (aka VDR), along with some associated metadata and derived quantities, into an astropy table."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "315fb948", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_types = [\"fakes_calexp\"]\n", + "rs = RegionSearch(\n", + " REPO_PATH, collections, dataset_types, visit_info_str=\"calexp.visitInfo\", fetch_data_on_start=True\n", + ")\n", + "\n", + "rs.vdr_data" + ] + }, + { + "cell_type": "markdown", + "id": "052f3af7", + "metadata": {}, + "source": [ + "# Find Discrete Piles\n", + "\n", + "In the 10 images above we want to find sets of images whose center coordinates lie within a given uncertainty radius of one another. Each such overlapping set is a discrete pile." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "347036e6-32ac-492a-b448-9b8d3bff12fb", + "metadata": {}, + "outputs": [], + "source": [ + "overlapping_sets = rs.find_overlapping_coords(uncertainty_radius=30)\n", + "print(f\"Found {len(overlapping_sets)} discrete piles\")\n", + "for i in range(len(overlapping_sets)):\n", + " print(\n", + " f\"In overlapping set {i + 1}, we have the following indices for images in the VDR data table: {overlapping_sets[i]}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "8eba4ae1-e8e5-42bb-a1e0-e5b7e1387f78", + "metadata": {}, + "source": [ + "## Create an ImageCollection\n", + "The first pile has the most images, so we'll use it to create a KBMOD ImageCollection from which we can run a search." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "894617d3-d033-4961-9caf-247156033a50", + "metadata": {}, + "outputs": [], + "source": [ + "uris = [rs.vdr_data[\"uri\"][index] for index in overlapping_sets[0]]\n", + "ic = kbmod.ImageCollection.fromTargets(uris)\n", + "ic" + ] + }, + { + "cell_type": "markdown", + "id": "de5baeb4-8ce5-4f43-a8bb-fec2102d3491", + "metadata": {}, + "source": [ + "# Create a KBMOD Workunit from the ImageCollection\n", + "\n", + "Use KBMOD to search for trajectories in one of the identified discrete piles." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aedb63bd-8542-4750-b47d-34c606f94565", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import os\n", + "import numpy as np\n", + "\n", + "results_suffix = \"DEMO\"\n", + "\n", + "res_filepath = \"./demo_results\"\n", + "if not Path(res_filepath).is_dir():\n", + " os.mkdir(res_filepath)\n", + "\n", + "# The velocity search range, in pixels per day, and the number of velocity steps.\n", + "v_min = 3000\n", + "v_max = 4000\n", + "v_steps = 50\n", + "v_arr = [v_min, v_max, v_steps]\n", + "\n", + "# The search angles with respect to the ecliptic, in radians.\n", + "ang_below = 3 * np.pi / 2\n", + "ang_above = 2 * np.pi\n", + "ang_steps = 50\n", + "ang_arr = [ang_below, ang_above, ang_steps]\n", + "\n", + "# Set the minimum number of observations for a valid trajectory. 
Make sure we see\n", + "# the object in at least 2.\n", + "num_obs = 2\n", + "\n", + "input_parameters = {\n", + " # Required\n", + " \"res_filepath\": res_filepath,\n", + " \"output_suffix\": results_suffix,\n", + " \"v_arr\": v_arr,\n", + " \"ang_arr\": ang_arr,\n", + " # Important\n", + " \"num_obs\": 2,\n", + " \"do_mask\": False,\n", + " \"lh_level\": 10.0,\n", + " \"gpu_filter\": True,\n", + " # Fine tuning\n", + " \"sigmaG_lims\": [15, 60],\n", + " \"mom_lims\": [37.5, 37.5, 1.5, 1.0, 1.0],\n", + " \"peak_offset\": [3.0, 3.0],\n", + " \"chunk_size\": 1000000,\n", + " \"stamp_type\": \"cpp_median\",\n", + " \"eps\": 0.03,\n", + " \"clip_negative\": True,\n", + " \"mask_num_images\": 0,\n", + " \"cluster_type\": \"position\",\n", + " \"average_angle\": 0.0,\n", + "}\n", + "\n", + "config = kbmod.configuration.SearchConfiguration()\n", + "config.set_multiple(input_parameters)\n", + "\n", + "wunit = ic.toWorkUnit(config)" + ] + }, + { + "cell_type": "markdown", + "id": "10a6f708-296a-4dfd-9552-317f481543d0", + "metadata": {}, + "source": [ + "# Visualize Our ImageCollection\n", + "\n", + "The following defines some helper functions for visualizing the images in our `WorkUnit`. We can quickly inspect these to sanity check." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70909d23", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from astropy.visualization import astropy_mpl_style\n", + "from astropy.visualization import ZScaleInterval, simple_norm, imshow_norm, ZScaleInterval, SinhStretch\n", + "\n", + "\n", + "def get_image(workunit, n):\n", + " return workunit.im_stack.get_images()[n]\n", + "\n", + "\n", + "def get_science_image(workunit, n):\n", + " return get_image(workunit, n).get_science().image\n", + "\n", + "\n", + "def get_variance_image(workunit, n):\n", + " return get_image(workunit, n).get_variance().image\n", + "\n", + "\n", + "def get_mask_image(workunit, n):\n", + " return get_image(workunit, n).get_mask().image\n", + "\n", + "\n", + "def plot_img(img):\n", + " fig, ax = plt.subplots(figsize=(25, 25))\n", + " _ = imshow_norm(\n", + " img.T, ax, cmap=\"gray\", origin=\"lower\", interval=ZScaleInterval(contrast=0.5), stretch=SinhStretch()\n", + " )\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "227cf312-b595-4f33-9e5f-67c5aab17c99", + "metadata": {}, + "source": [ + "## The Science Images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7c08fe4", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(ic)):\n", + " plot_img(get_science_image(wunit, i))" + ] + }, + { + "cell_type": "markdown", + "id": "37769833-fcc4-49b8-8773-1d075cedf000", + "metadata": {}, + "source": [ + "## The Variance Images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c08c5fe", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(ic)):\n", + " plot_img(get_variance_image(wunit, i))" + ] + }, + { + "cell_type": "markdown", + "id": "9e429a49-b79a-40f3-9f29-c3184341b1b6", + "metadata": {}, + "source": [ + "# Create a Reprojected Workunit\n", + "\n", + "First we'll need to create a new initial work unit so results can be saved in a different directory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6271fd4b-3eb2-438a-93dd-2cbaf82dec34", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import os\n", + "\n", + "results_suffix = 
\"REPROJECT_DEMO\"\n", + "\n", + "res_filepath = \"./reproject_demo_results\"\n", + "if not Path(res_filepath).is_dir():\n", + " os.mkdir(res_filepath)\n", + "\n", + "# The demo data has an object moving at x_v=10 px/day\n", + "# and y_v = 0 px/day. So we search velocities [0, 20].\n", + "v_min = 3000\n", + "v_max = 4000\n", + "v_steps = 50\n", + "v_arr = [v_min, v_max, v_steps]\n", + "\n", + "# angle with respect to ecliptic, in radians\n", + "ang_below = 3 * np.pi / 2 # 0\n", + "ang_above = 2 * np.pi # 1\n", + "ang_steps = 50 # 21\n", + "ang_arr = [ang_below, ang_above, ang_steps]\n", + "\n", + "# There are 3 images in the demo data. Make sure we see\n", + "# the object in at least 2.\n", + "num_obs = 2\n", + "\n", + "input_parameters = {\n", + " # Required\n", + " \"res_filepath\": res_filepath,\n", + " \"output_suffix\": results_suffix,\n", + " \"v_arr\": v_arr,\n", + " \"ang_arr\": ang_arr,\n", + " # Important\n", + " \"num_obs\": 2,\n", + " \"do_mask\": False,\n", + " \"lh_level\": 10.0,\n", + " \"gpu_filter\": True,\n", + " # Fine tuning\n", + " \"sigmaG_lims\": [15, 60],\n", + " \"mom_lims\": [37.5, 37.5, 1.5, 1.0, 1.0],\n", + " \"peak_offset\": [3.0, 3.0],\n", + " \"chunk_size\": 1000000,\n", + " \"stamp_type\": \"cpp_median\",\n", + " \"eps\": 0.03,\n", + " \"clip_negative\": True,\n", + " \"mask_num_images\": 0,\n", + " \"cluster_type\": \"position\",\n", + " \"average_angle\": 0.0,\n", + "}\n", + "\n", + "config = kbmod.configuration.SearchConfiguration()\n", + "config.set_multiple(input_parameters)\n", + "\n", + "new_wunit = ic.toWorkUnit(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4599f776-291c-4238-acc8-f6f8acf95087", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "from kbmod import reprojection\n", + "\n", + "common_wcs = new_wunit._per_image_wcs[0]\n", + "\n", + "uwunit = reprojection.reproject_work_unit(new_wunit, common_wcs)" + ] + }, + { + "cell_type": "markdown", + "id": "dbba31a3-b9d1-48aa-a0f9-29b6817f5b59", + "metadata": {}, + "source": [ + "# Let's visualize our reprojected images.\n", + "\n", + "## The reprojected science images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbab1cdc-62e2-47f6-a03c-68238da9a6cc", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(ic)):\n", + " plot_img(get_science_image(uwunit, i))" + ] + }, + { + "cell_type": "markdown", + "id": "0e0bb547-11eb-4e25-9345-54dfefbde80e", + "metadata": {}, + "source": [ + "## The reprojected variance images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6001d8ad-0669-46d9-bbdd-fcdb090e021c", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(ic)):\n", + " plot_img(get_variance_image(uwunit, i))" + ] + }, + { + "cell_type": "markdown", + "id": "3791c433-62b9-42e0-88d6-0dc0a53ff06d", + "metadata": {}, + "source": [ + "# Run KBMOD Search without Reprojection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18ceb66c-b9f2-4ec1-8c35-a2fc8cdd136a", + "metadata": {}, + "outputs": [], + "source": [ + "res = kbmod.run_search.SearchRunner().run_search_from_work_unit(wunit)" + ] + }, + { + "cell_type": "markdown", + "id": "25f16aac", + "metadata": {}, + "source": [ + "# Inspect the Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c25713c9-27da-4c04-9903-089a4b58b38e", + "metadata": {}, + "outputs": [], + "source": [ + "trajectories = [t.trajectory for t in sorted(res.results, key=lambda x: x.trajectory.lh, 
reverse=True)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf232f31-656a-4dd0-94e2-c408893070ee", + "metadata": {}, + "outputs": [], + "source": [ + "trajectories" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d01b424b", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# We can create stamps for each result\n", + "imgstack = wunit.im_stack\n", + "\n", + "# Create the stamps around remaining results\n", + "nres = len(trajectories)\n", + "fig, axes = plt.subplots(nres, 3, figsize=(10, nres * 3), sharey=True, sharex=True)\n", + "\n", + "stamp_size = 20\n", + "for row, traj in zip(axes, trajectories):\n", + " stamps = kbmod.search.StampCreator.get_stamps(imgstack, traj, stamp_size)\n", + " for ax, stamp in zip(row, stamps):\n", + " ax.imshow(stamp.image, interpolation=None, cmap=\"gist_heat\")\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa81be5b", + "metadata": {}, + "outputs": [], + "source": [ + "# We can further filter these results. We started with a lower likelihood cutoff of 10,\n", + "# but many of the results have much larger likelihoods than that, so we can raise the limit.\n", + "# This is not uncommon, as the number of false positives returned by KBMOD is usually rather large.\n", + "from kbmod.filters.stats_filters import LHFilter\n", + "\n", + "# Filter out all results that have a likelihood < 40.0.\n", + "lhfilter = LHFilter(40.0, None)\n", + "res.apply_filter(lhfilter)\n", + "print(f\"{res.num_results()} results remaining.\")\n", + "\n", + "for result in res.results:\n", + " print(result.trajectory)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a8c15af", + "metadata": {}, + "outputs": [], + "source": [ + "# We can filter on stamps too, for example:\n", + "from kbmod.filters.stamp_filters import StampPeakFilter\n", + "\n", + "filter2 = StampPeakFilter(10, 2.1, 0.1)\n", + "res.apply_filter(filter2)\n", + "print(f\"{res.num_results()} results remaining.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8076f74", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(10, nres * 3), sharey=True, sharex=True)\n", + "\n", + "stamps = kbmod.search.StampCreator.get_stamps(imgstack, res.results[0].trajectory, 20)\n", + "for ax, stamp in zip(axes, stamps):\n", + " ax.imshow(stamp.image, interpolation=None, cmap=\"gist_heat\")\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "bf918af6-987e-47d0-8ea2-ef79545f076e", + "metadata": {}, + "source": [ + "# Run KBMOD Search on the Reprojected Images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1904d07-6b40-45dd-900b-9e45fc4c6ad7", + "metadata": {}, + "outputs": [], + "source": [ + "reproject_res = kbmod.run_search.SearchRunner().run_search_from_work_unit(uwunit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68d3c0b0-21e1-439b-bf45-fb95fd33343f", + "metadata": {}, + "outputs": [], + "source": [ + "reproj_traj = [t.trajectory for t in reproject_res.results]\n", + "reproj_traj" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56fbf395-12ce-48dc-973d-1df97dc993e4", + "metadata": {}, + "outputs": [], + "source": [ + "# We can create stamps for each result\n", + "imgstack = uwunit.im_stack\n", + "\n", + "# Create the stamps around remaining results\n", + "nres = 
len(reproj_traj)\n", + "fig, axes = plt.subplots(nres, 3, figsize=(10, nres * 3), sharey=True, sharex=True)\n", + "\n", + "stamp_size = 20\n", + "for row, traj in zip(axes, reproj_traj):\n", + " stamps = kbmod.search.StampCreator.get_stamps(imgstack, traj, stamp_size)\n", + " for ax, stamp in zip(row, stamps):\n", + " ax.imshow(stamp.image, interpolation=None, cmap=\"gist_heat\")\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf30090e-b533-4b6b-9c1b-e27552e3b3e5", + "metadata": {}, + "outputs": [], + "source": [ + "# We can further filter these results - let's say we had a lower cutoff on likelihood of 10\n", + "# but now that we can see there are many results with a much larger likelihoods than that - we want to increase that limit\n", + "# This is not uncommon as usually the number of false positives returned by KBMOD is rather large\n", + "from kbmod.filters.stats_filters import LHFilter\n", + "\n", + "# Filter out all results that have a likelihood < 40.0.\n", + "lhfilter = LHFilter(40.0, None)\n", + "reproject_res.apply_filter(lhfilter)\n", + "print(f\"{reproject_res.num_results()} results remaining.\")\n", + "\n", + "for result in reproject_res.results:\n", + " print(result.trajectory)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b63a2b3-e39c-4dde-af41-b1aec66435c8", + "metadata": {}, + "outputs": [], + "source": [ + "# We can filter on stamps too, for example:\n", + "from kbmod.filters.stamp_filters import StampPeakFilter\n", + "\n", + "filter2 = StampPeakFilter(10, 2.1, 0.1)\n", + "reproject_res.apply_filter(filter2)\n", + "print(f\"{reproject_res.num_results()} results remaining.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce8eab22-ac89-4a94-90e3-7628ea66af5f", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(10, nres * 3), sharey=True, sharex=True)\n", + "\n", + "stamps = kbmod.search.StampCreator.get_stamps(imgstack, reproject_res.results[0].trajectory, 20)\n", + "for ax, stamp in zip(axes, stamps):\n", + " ax.imshow(stamp.image, interpolation=None, cmap=\"gist_heat\")\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25a638a5-db0c-4380-b6bc-8d2c765aef08", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "kbmod/w_2023_38", + "language": "python", + "name": "kbmod_38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/kbmod/region_search.py b/src/kbmod/region_search.py index 9c7781b8b..9b33ed21e 100644 --- a/src/kbmod/region_search.py +++ b/src/kbmod/region_search.py @@ -5,8 +5,13 @@ from concurrent.futures import ProcessPoolExecutor, as_completed +from astropy import units as u +from astropy.coordinates import SkyCoord + from astropy.table import Table +from urllib.parse import urlparse + def _chunked_data_ids(dataIds, chunk_size=200): """Helper function to yield successive chunk_size chunks from dataIds.""" @@ -14,6 +19,11 @@ def _chunked_data_ids(dataIds, chunk_size=200): yield dataIds[i : i + chunk_size] +def _trim_uri(uri): + """Trim the URI to a more standardized output.""" + return urlparse(uri).path + + class RegionSearch: """ 
A class for searching through a dataset for data suitable for KBMOD processing, @@ -35,7 +45,7 @@ def __init__( butler=None, visit_info_str="Exposure.visitInfo", max_workers=None, - fetch_data=False, + fetch_data_on_start=False, ): """ Parameters @@ -53,7 +63,7 @@ max_workers : `int`, optional The maximum number of workers to use in parallel processing. Note that each parallel worker will instantiate its own Butler objects. If not provided, parallel processing is disabled. - fetch_data: `bool`, optional - If True, fetch the VDR data when the object is created. Default is True. + fetch_data_on_start : `bool`, optional + If True, fetch the VDR data when the object is created. Default is False. """ self.repo_path = repo_path @@ -69,7 +79,7 @@ # Create an empty table to store the VDR (Visit, Detector, Region) data from the butler. self.vdr_data = Table() - if fetch_data: + if fetch_data_on_start: # Fetch the VDR data from the butler self.vdr_data = self.fetch_vdr_data() @@ -134,6 +144,8 @@ def is_parallel(self): def new_butler(self): """Instantiates a new Butler object from the repo_path.""" + if self.butler is not None: + return dafButler.Butler(self.repo_path, registry=self.butler.registry) return dafButler.Butler(self.repo_path) def set_collections(self, collections): @@ -202,7 +214,9 @@ def fetch_vdr_data(self, collections=None, dataset_types=None): vdr_dict["uri"] = self.get_uris(vdr_dict["data_id"]) # return as an Astropy Table - return Table(vdr_dict) + self.vdr_data = Table(vdr_dict) + + return self.vdr_data def get_instruments(self, data_ids=None, first_instrument_only=False): """ @@ -231,7 +245,9 @@ instruments.append(instrument) return instruments - def _get_uris_serial(self, data_ids, dataset_types=None, collections=None, butler=None): + def _get_uris_serial( + self, data_ids, dataset_types=None, collections=None, butler=None, trim_uri_func=_trim_uri + ): """Fetch URIs for a list of dataIds in serial fashion. Parameters @@ -244,6 +260,8 @@ The collections to use when fetching URIs. If None, use self.collections. butler : `lsst.daf.butler.Butler`, optional The Butler object to use for data access. If None, use self.butler. + trim_uri_func : `function`, optional + A function to trim the URIs. Default is _trim_uri. Returns ------- @@ -266,12 +284,12 @@ try: uri = self.butler.getURI(dataset_types[0], dataId=data_id, collections=collections) uri = uri.geturl() # Convert to URL string - uris.append(uri) + uris.append(trim_uri_func(uri)) except Exception as e: print(f"Failed to retrieve path for dataId {data_id}: {e}") return uris - def get_uris(self, data_ids, dataset_types=None, collections=None): + def get_uris(self, data_ids, dataset_types=None, collections=None, trim_uri_func=_trim_uri): """ Get the URIs for the given dataIds. @@ -283,6 +301,8 @@ The dataset types to use when fetching URIs. If None, use self.dataset_types. collections : `list[str]` The collections to use when fetching URIs. If None, use self.collections. + trim_uri_func : `function`, optional + A function to trim the URIs. Default is _trim_uri. 
Returns ------- @@ -314,6 +334,7 @@ dataset_types=dataset_types, collections=collections, butler=self.new_butler(), + trim_uri_func=trim_uri_func, ) for chunk in data_id_chunks ] @@ -342,3 +363,62 @@ def get_center_ra_dec(self, region): ra = bbox_center.getLon().asDegrees() dec = bbox_center.getLat().asDegrees() return ra, dec + + def find_overlapping_coords(self, data=None, uncertainty_radius=30): + """ + Find the overlapping sets of data based on the center coordinates of the data. + + Parameters + ---------- + data : `astropy.table.Table`, optional + The data to search for overlapping sets. If not provided, use the VDR data. + + uncertainty_radius : `float` + The radius in arcseconds to use when determining if two data points overlap. + + Returns + ------- + overlapping_sets : `list[list[int]]` + A list of overlapping sets of data. Each set is a list of the indices within + the VDR (Visit, Detector, Region) table. + """ + if not data: + if len(self.vdr_data) == 0: + self.vdr_data = self.fetch_vdr_data() + data = self.vdr_data + + # Assuming uncertainty_radius is provided as a float in arcseconds + uncertainty_radius_as = uncertainty_radius * u.arcsec + + # Convert the center coordinates to SkyCoord objects + all_ra_dec = SkyCoord( + ra=[x[0] for x in data["center_coord"]] * u.degree, + dec=[x[1] for x in data["center_coord"]] * u.degree, + ) + + # Indices of the data ids that we have already processed + processed_data_ids = set([]) + overlapping_sets = [] + for i in range(len(all_ra_dec) - 1): + coord = all_ra_dec[i] + if i not in processed_data_ids: + # We haven't chosen the current index for a previous pile, which means + # that it was not within the separation distance of any earlier coordinate + # with an index less than 'i'. So we only have to compute the separation + # distances for the coordinates that come after 'i'. + distances = coord.separation(all_ra_dec[i + 1 :]).to(u.arcsec).value + + # Start the candidate overlapping set with the current index. + overlapping_data_ids = [i] + + for j in range(len(distances)): + if distances[j] <= uncertainty_radius_as.value: + # We add the indices of other coordinates within the radius, + # offset by our starting 'all_ra_dec' index of i + 1 + overlapping_data_ids.append(i + 1 + j) + if len(overlapping_data_ids) > 1: + # Add our choice of overlapping set to our results. 
+ processed_data_ids.update(overlapping_data_ids) + overlapping_sets.append(overlapping_data_ids) + + return overlapping_sets diff --git a/tests/test_region_search.py b/tests/test_region_search.py index b89b48801..e03a14319 100644 --- a/tests/test_region_search.py +++ b/tests/test_region_search.py @@ -4,7 +4,19 @@ MOCK_REPO_PATH = "far/far/away" from unittest import mock -from utils import DatasetRef, DatasetId, dafButler, MockButler +from utils import ( + ConvexPolygon, + DatasetRef, + DatasetId, + dafButler, + DimensionRecord, + LonLat, + MockButler, + Registry, +) + +from astropy import units as u +from astropy.coordinates import SkyCoord with mock.patch.dict( "sys.modules", @@ -24,7 +36,8 @@ class TestRegionSearch(unittest.TestCase): """ def setUp(self): - self.butler = MockButler(MOCK_REPO_PATH) + self.registry = Registry() + self.butler = MockButler(MOCK_REPO_PATH, registry=self.registry) # For the default collections and dataset types, we'll just use the first two of each self.default_collections = self.butler.registry.queryCollections()[:2] @@ -41,7 +54,7 @@ def test_init(self): """ Test that the region search object can be initialized. """ - rs = region_search.RegionSearch(MOCK_REPO_PATH, [], [], butler=self.butler, fetch_data=False) + rs = region_search.RegionSearch(MOCK_REPO_PATH, [], [], butler=self.butler, fetch_data_on_start=False) self.assertTrue(rs is not None) self.assertEqual(0, len(rs.vdr_data)) # No data should be fetched @@ -54,7 +67,7 @@ def test_init_with_fetch(self): self.default_collections, self.default_datasetTypes, butler=self.butler, - fetch_data=True, + fetch_data_on_start=True, ) self.assertTrue(rs is not None) @@ -195,6 +208,179 @@ def test_get_center_ra_dec(self): center_ra_dec = self.rs.get_center_ra_dec(region) self.assertTrue(len(center_ra_dec) > 0) + def test_find_overlapping_coords(self): + """ + Tests that we can find discrete piles with overlapping coordinates + """ + # Create a set of regions that we can then greedily convert into discrete + # piles within a radius threshold + regions = [] + regions.append( + ConvexPolygon( + [ + (-0.8572310214106003, 0.5140136573995331, 0.03073981031324692), + (-0.8573126243779648, 0.514061814910292, 0.027486624625501416), + (-0.8603167349539873, 0.5090182552169222, 0.027486931695458513), + (-0.8602353512965948, 0.508969729485055, 0.03074011788419143), + ], + center=LonLat(2.604388763115912, 0.029117535741884262), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8709063227199754, 0.49048670961887364, 0.030740278684830185), + (-0.8709892077685559, 0.49053262854455665, 0.0274870930414856), + (-0.8738549039182039, 0.48540915703478454, 0.027487400111442725), + (-0.8737722280742154, 0.48536286405417617, 0.030740586255774707), + ], + center=LonLat(2.6316151722984484, 0.02911800433559046), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8632807208053553, 0.5017969238521098, 0.054279317408622074), + (-0.8634293237866446, 0.5018822737832265, 0.051029266970199966), + (-0.8663644987606522, 0.49679828690228733, 0.05102957395625042), + (-0.8662161100467123, 0.49671256579311107, 0.05427962489522404), + ], + center=LonLat(2.6180008039930964, 0.05267892959436134), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8572306476090913, 0.5140142807953888, 0.03073981031329064), + (-0.8573122505414348, 0.5140624383654911, 0.027486624625545138), + (-0.8603163647852367, 0.5090188808567737, 0.027486931695502228), + (-0.8602349811631328, 0.5089703550657229, 0.030740117884235148), + ], + center=LonLat(2.604388035895388, 
0.029117535741927998), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8709066920283894, 0.490486083334642, 0.030739808639786412), + (-0.8709895757751559, 0.4905320014532335, 0.027486622951882263), + (-0.8738552681990356, 0.4854085278595403, 0.027486930021839356), + (-0.8737725936565892, 0.4853622356858718, 0.030740116210730917), + ], + center=LonLat(2.631615899518966, 0.029117534067630124), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8632815518374847, 0.5017949720584495, 0.05428414385365578), + (-0.8634301686036089, 0.5018803295307347, 0.051034094243599164), + (-0.8663653328545027, 0.4967963364589571, 0.0510344012296149), + (-0.8662169303549427, 0.49671060780815945, 0.05428445134022294), + ], + center=LonLat(2.618002912932494, 0.052683763175059205), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8572305070104566, 0.514014197049692, 0.030745131028441175), + (-0.8573121248210437, 0.5140623634817817, 0.027491945845133575), + (-0.8603162390634397, 0.5090188059722266, 0.027492252915090963), + (-0.86023484056309, 0.5089702713191868, 0.03074543859938591), + ], + center=LonLat(2.604388035895447, 0.02912285898079983), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8709079891032256, 0.4904834774301412, 0.030744639926530638), + (-0.8709908867073091, 0.4905294029836303, 0.027491454696659565), + (-0.8738565642262993, 0.4854059210532718, 0.027491761766616923), + (-0.8737738758254457, 0.48535962144409595, 0.030744947497475358), + ], + center=LonLat(2.631618808401125, 0.0291223676459133), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.863282787712102, 0.5017930024533144, 0.05428269640419386), + (-0.8634314005967695, 0.5018783572241806, 0.05103264654570269), + (-0.8663665530170519, 0.4967943573222024, 0.051032953531728834), + (-0.8662181543980857, 0.4967086313723255, 0.05428300389077146), + ], + center=LonLat(2.618005240038212, 0.052682313585480325), + ) + ) + + regions.append( + ConvexPolygon( + [ + (-0.8572327018287437, 0.5140106819036223, 0.03074270327029348), + (-0.8573143130502634, 0.5140588439538368, 0.027489517856807404), + (-0.8603184063869375, 0.5090152739921819, 0.027489824926764658), + (-0.8602370144741296, 0.5089667437201095, 0.030743010841238115), + ], + center=LonLat(2.604392181052405, 0.02912043007100976), + ) + ) + + # Take the above regions have and construct them as DimensionRecords within + # our mock butler registry + new_records = [] + for i, region in enumerate(regions): + type = self.default_datasetTypes[i % 2] # Use modulo 2 to alternate through the two dataset types + new_records.append(DimensionRecord(f"dataId{i}", region, "fake_detector", type)) + self.registry.records = new_records + + # Fetch the VDR data for each of our 10 defined thresholds + data = self.rs.fetch_vdr_data() + self.assertEqual(len(data), 10) + + # Test that we can find 3 overlapping sets from the above test data + radius_threshold = 30 # radius in arcseconds + overlapping_sets = self.rs.find_overlapping_coords(data=data, uncertainty_radius=radius_threshold) + self.assertEqual(len(overlapping_sets), 3) + + # Test that none of the indices are repeated across the sets + prior_indices = set([]) + for s in overlapping_sets: + for idx in s: + self.assertNotIn(idx, prior_indices) + prior_indices.add(idx) + + # Test that for each set, the distances between all elements are within + # the uncertainty radius + for s in overlapping_sets: + # For this test data each set should have more than one element + self.assertGreater(len(s), 1) + + # Fetch the center coordinate for the 
index we chose from the VDR data + center_coords = [self.rs.vdr_data[idx]["center_coord"] for idx in s] + + # Convert the center coordinates for this pile to SkyCoord objects + ra_decs = SkyCoord( + ra=[c[0] * u.degree for c in center_coords], + dec=[c[1] * u.degree for c in center_coords], + ) + + # Compute the separation between all pairs of coordinates + for i in range(len(ra_decs)): + distances = ra_decs[i].separation(ra_decs) + for d in distances: + # Check that the separations is within the radius threshold + self.assertLessEqual(d.arcsec, radius_threshold) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils/mock_butler.py b/tests/utils/mock_butler.py index 7f1aa57cc..df0b70e25 100644 --- a/tests/utils/mock_butler.py +++ b/tests/utils/mock_butler.py @@ -1,12 +1,6 @@ -import unittest from unittest import mock -# TODO remove unneeded imports -import os import uuid -import tempfile -import unittest -from unittest import mock from kbmod.standardizers import KBMODV1Config @@ -23,6 +17,9 @@ "DatasetRef", "DatasetId", "dafButler", + "DimensionRecord", + "ConvexPolygon", + "LonLat", ] @@ -80,66 +77,70 @@ def getLat(self): class Box: - def __init__(self, x, y, width, height): - self.x = x - self.y = y - self.width = width - self.height = height + def __init__(self, center): + self.center = center def getCenter(self): - return LonLat(self.x + self.width / 2, self.y + self.height / 2) + return self.center class ConvexPolygon: - def __init__(self, vertices): + def __init__(self, vertices, center=None): self.vertices = vertices + self.center = center def getBoundingBox(self): - x = min([v[0] for v in self.vertices]) - y = min([v[1] for v in self.vertices]) - width = max([v[0] for v in self.vertices]) - x - height = max([v[1] for v in self.vertices]) - y - return Box(x, y, width, height) + return Box(self.center) class DimensionRecord: - def __init__(self, dataId, region, detector): + def __init__( + self, dataId, region, detector, dataset_type="default_type", collection="default_collection" + ): self.dataId = dataId self.region = region self.detector = detector + self.dataset_type = DatasetType(dataset_type) + self.collection = collection class Registry: + + def __init__(self, records=None, **kwargs): + if records is None: + # Create some default records to return + region1 = ConvexPolygon([(0, 0), (0, 1), (1, 1), (1, 0)], LonLat(0.5, 1)) + region2 = ConvexPolygon([(1, 1), (1, 3), (3, 3), (3, 1)], LonLat(0, 0.5)) + records = [ + DimensionRecord(DatasetRef("dataId1"), region1, "fake_detector", "type1", "collection1"), + DimensionRecord(DatasetRef("dataId2"), region2, "fake_detector", "type2", "collection2"), + ] + self.records = records + def getDataset(self, ref): return ref - def queryDimensionRecords(self, type, **kwargs): - region1 = ConvexPolygon([(0, 0), (0, 1), (1, 1), (1, 0)]) - region2 = ConvexPolygon([(1, 1), (1, 3), (3, 3), (3, 1)]) - return [ - DimensionRecord("dataId1", region1, "detector_replace_me"), - DimensionRecord("dataId2", region2, "detector_replace_me"), - ] + def queryDimensionRecords(self, type, datasets=None, **kwargs): + """Query the registry for records of a particular type 'datasets'. 
Optionally filter the records by the ``datasets`` dataset type.""" + if datasets is None: + return self.records + if isinstance(datasets, DatasetType): + datasets = datasets.name + return [record for record in self.records if record.dataset_type.name == datasets] - # Fix queryCollections def queryCollections(self, **kwargs): - return ["replace_me", "replace_me2"] + """Query the registry for all collections.""" + return [record.collection for record in self.records] def queryDatasetTypes(self, **kwargs): - return [ - DatasetType("dataset_type_replace_me"), - DatasetType("dataset_type_replace_me2"), - DatasetType("dataset_type_replace_me3"), - ] + """Query the registry for all dataset types.""" + return [record.dataset_type for record in self.records] def queryDatasets(self, dataset_type, **kwargs): - return DatasetQueryResults( - [ - DatasetRef("dataset_ref_replace_me"), - DatasetRef("dataset_ref_replace_me2"), - DatasetRef("dataset_ref_replace_me3"), - ] - ) + """Query the registry for all datasets of a particular type.""" + if isinstance(dataset_type, DatasetType): + dataset_type = dataset_type.name + return DatasetQueryResults([r.dataId for r in self.records if r.dataset_type.name == dataset_type]) FitsFactory = DECamImdiffFactory() @@ -165,9 +166,9 @@ class MockButler: attributes can be used to customize the returned arrays. """ - def __init__(self, root, ref=None, mock_images_f=None): + def __init__(self, root, ref=None, mock_images_f=None, registry=None): self.datastore = Datastore(root) - self.registry = Registry() + self.registry = Registry() if registry is None else registry self.mockImages = mock_images_f def getURI(self, ref, dataId=None, collections=None): @@ -268,6 +269,7 @@ class dafButler: them to our mocks. """ + DatasetType = DatasetType DatasetRef = DatasetRef DatasetId = DatasetId Butler = MockButler
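
The sketch below ties together the pieces this patch adds, mirroring the notebook flow above; it is illustrative only and not part of the diff. The repo path, collection, and dataset type are the epyc test values used in the notebook, and selecting the largest pile with `max(..., key=len)` is a convenience introduced here rather than anything the patch requires.

# Illustrative only: end-to-end use of the new RegionSearch API, mirroring the notebook.
# Assumes read access to the epyc test repository referenced above.
import kbmod
from kbmod.region_search import RegionSearch

REPO_PATH = "/epyc/projects/kbmod/data/imdiff_w09_gaiadr3"
collections = ["DECam/withFakes/20210318"]
dataset_types = ["fakes_calexp"]

# fetch_data_on_start=True populates rs.vdr_data with one row per
# (Visit, Detector, Region) for the requested collections and dataset types.
rs = RegionSearch(
    REPO_PATH,
    collections,
    dataset_types,
    visit_info_str="calexp.visitInfo",
    fetch_data_on_start=True,
)

# Group rows whose center coordinates fall within 30 arcsec of each other.
piles = rs.find_overlapping_coords(uncertainty_radius=30)

# Use the URIs of the largest pile to build an ImageCollection for a KBMOD search.
largest_pile = max(piles, key=len)
uris = [rs.vdr_data["uri"][i] for i in largest_pile]
ic = kbmod.ImageCollection.fromTargets(uris)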