Skip to content

Commit

Permalink
feat: add h3 intersection logic
Browse files Browse the repository at this point in the history
  • Loading branch information
RaczeQ committed May 13, 2024
1 parent cefe055 commit bc3ba4e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 14 deletions.
20 changes: 10 additions & 10 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"requests",
"polars>=0.19.4",
"rich>=10.11.0",
"h3ronpy>=0.19.0"
]
requires-python = ">=3.9"
readme = "README.md"
Expand Down
56 changes: 56 additions & 0 deletions quackosm/_h3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import platform
from pathlib import Path

import duckdb
import pyarrow as pa
import pyarrow.parquet as pq
from h3ronpy.arrow.vector import ContainmentMode, wkb_to_cells
from pooch import Decompress, retrieve
from shapely.geometry.base import BaseGeometry


def _transform_geometry_filter_to_h3(
    geometry: BaseGeometry, working_directory: Path, h3_resolution: int = 11
) -> Path:
    """
    Fill geometry filter with H3 cells and save them in a dedicated parquet file.

    Args:
        geometry (BaseGeometry): Geometry filter to cover with H3 cells. Multi-part
            geometries (exposing `geoms`) are split into parts; a single-part geometry
            is used as is.
        working_directory (Path): Directory in which the parquet file is written.
        h3_resolution (int, optional): H3 resolution of the generated cells. Defaults to 11.

    Returns:
        Path: Path to the `geometry_filter_h3_indexes.parquet` file containing
            a single `h3` column with unique cell indexes.
    """
    result_file_path = working_directory / "geometry_filter_h3_indexes.parquet"

    # Only multi-part geometries and collections expose `geoms`; a single-part geometry
    # (e.g. Polygon) would raise AttributeError, so wrap it in a one-element list instead.
    sub_geometries = getattr(geometry, "geoms", None)
    if sub_geometries is None:
        sub_geometries = [geometry]

    h3_indexes = wkb_to_cells(
        [sub_geometry.wkb for sub_geometry in sub_geometries],
        resolution=h3_resolution,
        containment_mode=ContainmentMode.Covers,
        flatten=True,
    ).unique()  # the same cell can be produced by more than one part

    pq.write_table(pa.table({"h3": h3_indexes}), result_file_path)
    return result_file_path


# Based on https://github.com/fusedio/udfs/blob/main/public/DuckDB_H3_Example/utils.py
# Will be changed after H3 extension becomes available in the official repository
def _load_h3_duckdb_extension(con: duckdb.DuckDBPyConnection) -> None:
    """
    Load H3 DuckDB extension for current system.

    Downloads the gzipped extension build matching the installed DuckDB version and
    the detected platform, decompresses it into a local cache directory, then installs
    and loads it in the given connection.

    Args:
        con (duckdb.DuckDBPyConnection): Connection in which the extension is installed
            and loaded.
    """
    system = platform.system()
    machine = platform.machine()
    # Normalise machine names reported by Python to the names used in the download URL.
    # NOTE(review): "aarch64" -> "arm64" follows DuckDB's platform naming ("linux_aarch64"
    # is not a DuckDB platform) - confirm that the bucket actually hosts such builds.
    arch = {"x86_64": "amd64", "aarch64": "arm64"}.get(machine, machine)

    if system == "Windows":
        detected_os = "windows_amd64"
    elif system == "Darwin":
        detected_os = f"osx_{arch}"
    else:
        detected_os = f"linux_{arch}"

    # On Linux x86_64 the "gcc4" build variant is requested.
    if detected_os == "linux_amd64":
        detected_os = "linux_amd64_gcc4"

    url = f"https://pub-cc26a6fd5d8240078bd0c2e0623393a5.r2.dev/v{duckdb.__version__}/{detected_os}/h3ext.duckdb_extension.gz"
    # This workaround of downloading in Python is needed because DuckDB cannot load extensions
    # from https (ssl, secure) URLs.
    # The decompressed file is named via `Decompress(name=...)`.
    # NOTE(review): known_hash=None means the downloaded binary is not verified - consider
    # pinning hashes, since the file is loaded as an unsigned extension.
    ungzip_path = retrieve(
        url=url,
        processor=Decompress(name="h3ext.duckdb_extension"),
        known_hash=None,
        path="cache/h3_ext",
    )

    # Double single quotes so a cache path containing `'` cannot break the SQL string literal.
    escaped_path = str(ungzip_path).replace("'", "''")
    con.sql(f"INSTALL '{escaped_path}';")
    con.sql("LOAD h3ext;")
29 changes: 25 additions & 4 deletions quackosm/pbf_file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from quackosm._constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from quackosm._exceptions import EmptyResultWarning, InvalidGeometryFilter
from quackosm._h3 import _load_h3_duckdb_extension, _transform_geometry_filter_to_h3
from quackosm._osm_tags_filters import (
GroupedOsmTagsFilter,
OsmTagsFilter,
Expand Down Expand Up @@ -110,6 +111,7 @@ def __init__(
] = None,
parquet_compression: str = "snappy",
osm_extract_source: Union[OsmExtractSource, str] = OsmExtractSource.geofabrik,
geometry_filter_intersection_h3_resolution: int = 10,
verbosity_mode: Literal["silent", "transient", "verbose"] = "transient",
allow_uncovered_geometry: bool = False,
debug_memory: bool = False,
Expand Down Expand Up @@ -143,6 +145,11 @@ def __init__(
osm_extract_source (Union[OsmExtractSource, str], optional): A source for automatic
downloading of OSM extracts. Can be Geofabrik, BBBike, OSMfr or any.
Defaults to `geofabrik`.
geometry_filter_intersection_h3_resolution (int, optional): A resolution used for faster
intersections. First H3 cells are generated to cover given geometry and then
intersection is done based on H3 indexes, not using ST_Intersects. Higher resolution
will result in finer approximation of the original geometry, but might require more
memory during operations. Defaults to 10.
verbosity_mode (Literal["silent", "transient", "verbose"], optional): Set progress
verbosity mode. Can be one of: silent, transient and verbose. Silent disables
output completely. Transient tracks progress, but removes output after finished.
Expand All @@ -158,6 +165,7 @@ def __init__(
InvalidGeometryFilter: When provided geometry filter has parts without area.
"""
self.geometry_filter = geometry_filter
self.geometry_filter_intersection_h3_resolution = geometry_filter_intersection_h3_resolution
self._check_if_valid_geometry_filter()

self.tags_filter = tags_filter
Expand Down Expand Up @@ -1101,16 +1109,27 @@ def _prefilter_elements_ids(
# - select all from NI with tags filter
filter_osm_node_ids_filter = self._generate_elements_filter(filter_osm_ids, "node")
if is_intersecting:
wkt = cast(BaseGeometry, self.geometry_filter).wkt
intersection_filter = f"ST_Intersects(ST_Point(lon, lat), ST_GeomFromText('{wkt}'))"
h3_cells_path = _transform_geometry_filter_to_h3(
geometry=self.geometry_filter,
working_directory=self.tmp_dir_path,
h3_resolution=self.geometry_filter_intersection_h3_resolution,
)
if self.debug_memory:
log_message(f"Saved to file: {h3_cells_path}")

with self.task_progress_tracker.get_spinner("Filtering nodes - intersection"):
nodes_intersecting_ids = self._sql_to_parquet_file(
sql_query=f"""
SELECT DISTINCT id FROM ({nodes_valid_with_tags.sql_query()}) n
WHERE {intersection_filter} = true
SEMI JOIN '{h3_cells_path}'
ON h3_latlng_to_cell(
lat, lon, {self.geometry_filter_intersection_h3_resolution}
) = h3
""",
file_path=self.tmp_dir_path / "nodes_intersecting_ids",
)

self._delete_directories(h3_cells_path)
with self.task_progress_tracker.get_spinner("Filtering nodes - tags"):
self._sql_to_parquet_file(
sql_query=f"""
Expand Down Expand Up @@ -2534,7 +2553,7 @@ def _set_up_duckdb_connection(
local_db_file = "db.duckdb" if is_main_connection else f"{secrets.token_hex(16)}.duckdb"
connection = duckdb.connect(
database=str(tmp_dir_path / local_db_file),
config=dict(preserve_insertion_order=False),
config=dict(preserve_insertion_order=False, allow_unsigned_extensions=True),
)
connection.sql("SET enable_progress_bar = false;")
connection.sql("SET enable_progress_bar_print = false;")
Expand All @@ -2543,6 +2562,8 @@ def _set_up_duckdb_connection(
for extension_name in ("parquet", "spatial"):
connection.load_extension(extension_name)

_load_h3_duckdb_extension(connection)

connection.sql(
"""
CREATE OR REPLACE MACRO linestring_to_linestring_geometry(ls) AS
Expand Down

0 comments on commit bc3ba4e

Please sign in to comment.