Skip to content

Commit

Permalink
refactor: change geometry filtering step
Browse files Browse the repository at this point in the history
  • Loading branch information
RaczeQ committed May 15, 2024
1 parent cefe055 commit 5193f90
Showing 1 changed file with 53 additions and 3 deletions.
56 changes: 53 additions & 3 deletions quackosm/pbf_file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def convert_pbf_to_parquet(
ignore_cache: bool = False,
filter_osm_ids: Optional[list[str]] = None,
save_as_wkt: bool = False,
pbf_extract_geometry: Optional[Union[BaseGeometry, Iterable[BaseGeometry]]] = None,
) -> Path:
"""
Convert PBF file to GeoParquet file.
Expand All @@ -268,6 +269,9 @@ def convert_pbf_to_parquet(
save_as_wkt (bool): Whether to save the file with geometry in the WKT form instead
of WKB. If `True`, it will be saved as a `.parquet` file, because it won't be
in the GeoParquet standard. Defaults to `False`.
pbf_extract_geometry (Optional[Union[BaseGeometry, Iterable[BaseGeometry]]], optional):
A single geometry, or an iterable of geometries with exactly one entry per
PBF file, defining the extent of the PBF extract(s). Used internally to speed up
intersections for complex filters. Defaults to `None`.
Returns:
Path: Path to the generated GeoParquet file.
Expand All @@ -277,6 +281,18 @@ def convert_pbf_to_parquet(
else:
pbf_path = list(pbf_path)

# Normalise `pbf_extract_geometry` into a list holding one geometry per PBF path.
# The argument must not be reset to `None` here: doing so discards the caller's
# value and makes every extract-geometry code path below unreachable.
if pbf_extract_geometry is not None:
    if isinstance(pbf_extract_geometry, BaseGeometry):
        # A single geometry is wrapped so it can be indexed like the paths list.
        pbf_extract_geometry = [pbf_extract_geometry]
    else:
        # Materialise first: a generic Iterable is not guaranteed to support len().
        pbf_extract_geometry = list(pbf_extract_geometry)

    # Validate in both cases, so a mismatch fails here with a clear message
    # instead of as an IndexError when the geometries are indexed per file.
    if len(pbf_extract_geometry) != len(pbf_path):
        raise AttributeError(
            "Provided pbf_extract_geometry has a different length "
            "than the list of pbf paths."
        )

if filter_osm_ids is None:
filter_osm_ids = []

Expand All @@ -293,6 +309,9 @@ def convert_pbf_to_parquet(
debug=self.debug_memory,
)
if total_files == 1:
single_pbf_extract_geometry = None
if pbf_extract_geometry is not None:
single_pbf_extract_geometry = pbf_extract_geometry[0]
parsed_geoparquet_file = self._convert_single_pbf_to_parquet(
pbf_path[0],
result_file_path=result_file_path,
Expand All @@ -301,6 +320,7 @@ def convert_pbf_to_parquet(
ignore_cache=ignore_cache,
filter_osm_ids=filter_osm_ids,
save_as_wkt=save_as_wkt,
pbf_extract_geometry=single_pbf_extract_geometry,
)
self.task_progress_tracker.stop()
return parsed_geoparquet_file
Expand Down Expand Up @@ -334,13 +354,19 @@ def convert_pbf_to_parquet(

# Convert every PBF file in turn and collect the resulting parquet paths.
for file_idx, single_pbf_path in enumerate(pbf_path):
    self.task_progress_tracker.reset_steps(file_idx + 1)

    # Pair this file with its extract geometry when extract geometries were given.
    single_pbf_extract_geometry = (
        None if pbf_extract_geometry is None else pbf_extract_geometry[file_idx]
    )

    parsed_geoparquet_file = self._convert_single_pbf_to_parquet(
        single_pbf_path,
        keep_all_tags=keep_all_tags,
        explode_tags=explode_tags,
        ignore_cache=ignore_cache,
        filter_osm_ids=filter_osm_ids,
        save_as_wkt=save_as_wkt,
        pbf_extract_geometry=single_pbf_extract_geometry,
    )
    parsed_geoparquet_files.append(parsed_geoparquet_file)

Expand Down Expand Up @@ -404,6 +430,7 @@ def _convert_single_pbf_to_parquet(
ignore_cache: bool = False,
filter_osm_ids: Optional[list[str]] = None,
save_as_wkt: bool = False,
pbf_extract_geometry: Optional[BaseGeometry] = None,
) -> Path:
if filter_osm_ids is None:
filter_osm_ids = []
Expand All @@ -422,6 +449,14 @@ def _convert_single_pbf_to_parquet(
try:
self.encountered_query_exception = False
self.connection = _set_up_duckdb_connection(tmp_dir_path=self.tmp_dir_path)

# Remember the configured filter; it is assigned back after the conversion.
# NOTE(review): that restore happens before the `finally` clause, so an exception
# raised during conversion leaves the narrowed filter on the instance.
original_geometry_filter = self.geometry_filter

# Narrow the filter to the part that overlaps this extract. When no geometry
# filter is set there is nothing to narrow, so it is left as `None` rather than
# calling `.intersection` on `None`.
if pbf_extract_geometry is not None and self.geometry_filter is not None:
    self.geometry_filter = cast(BaseGeometry, self.geometry_filter).intersection(
        pbf_extract_geometry
    )

result_file_path = result_file_path or self._generate_result_file_path(
pbf_path,
filter_osm_ids=filter_osm_ids,
Expand All @@ -439,6 +474,8 @@ def _convert_single_pbf_to_parquet(
save_as_wkt=save_as_wkt,
)

self.geometry_filter = original_geometry_filter

return parsed_geoparquet_file
finally:
if self.connection is not None:
Expand Down Expand Up @@ -1101,16 +1138,29 @@ def _prefilter_elements_ids(
# - select all from NI with tags filter
filter_osm_node_ids_filter = self._generate_elements_filter(filter_osm_ids, "node")
if is_intersecting:
# Load the filter geometry into a DuckDB table and semi-join the nodes against it,
# instead of inlining a single WKT literal into the SQL text. The query must not
# also carry the old `WHERE <wkt filter>` clause: a WHERE placed before the
# SEMI JOIN is a syntax error.
with self.task_progress_tracker.get_spinner("Filtering nodes - intersection"):
    if isinstance(self.geometry_filter, BaseMultipartGeometry):
        # One row per part of a multipart geometry.
        geometry_array = ga.array(
            [sub_geom.wkb for sub_geom in self.geometry_filter.geoms]
        )
    else:
        geometry_array = ga.array([cast(BaseGeometry, self.geometry_filter).wkb])

    tab = pa.table([geometry_array], names=["geometry"])

    # Decode the WKB column into geometries and store it as `geometry_filter`.
    self.connection.from_arrow(tab).project(
        "ST_GeomFromWKB(geometry) geometry"
    ).to_table("geometry_filter")

    # A node is kept when its point intersects at least one filter row.
    nodes_intersecting_ids = self._sql_to_parquet_file(
        sql_query=f"""
        SELECT DISTINCT id FROM ({nodes_valid_with_tags.sql_query()}) n
        SEMI JOIN geometry_filter gf
        ON ST_Intersects(ST_Point(n.lon, n.lat), gf.geometry)
        """,
        file_path=self.tmp_dir_path / "nodes_intersecting_ids",
    )

with self.task_progress_tracker.get_spinner("Filtering nodes - tags"):
self._sql_to_parquet_file(
sql_query=f"""
Expand Down

0 comments on commit 5193f90

Please sign in to comment.