Skip to content

Commit

Permalink
chore: change intersection to dask
Browse files Browse the repository at this point in the history
  • Loading branch information
RaczeQ committed May 17, 2024
1 parent 4e7faf8 commit 9e758a5
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 10 deletions.
201 changes: 200 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ dependencies = [
"requests",
"polars>=0.19.4",
"rich>=10.11.0",
"dask-geopandas",
"dask<2024.3.0",
]
requires-python = ">=3.9"
readme = "README.md"
Expand Down
43 changes: 38 additions & 5 deletions quackosm/_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,45 @@
from queue import Queue
from time import sleep

import pyarrow as pa
import pyarrow.parquet as pq
from shapely import Point, STRtree
from shapely.geometry.base import BaseGeometry
import dask

from quackosm._rich_progress import TaskProgressBar # type: ignore[attr-defined]
# TODO: update dask and dask-geopandas after https://github.com/geopandas/dask-geopandas/issues/284

Check notice on line 8 in quackosm/_intersection.py

View check run for this annotation

codefactor.io / CodeFactor

quackosm/_intersection.py#L8

unresolved comment '# TODO: update dask and dask-geopandas after https://github.com/geopandas/dask-geopandas/issues/284' (C100)
dask.config.set({"dataframe.query-planning-warning": False})

import dask_geopandas # noqa: E402
import pyarrow as pa # noqa: E402
import pyarrow.parquet as pq # noqa: E402
from shapely import Point, STRtree # noqa: E402
from shapely.geometry.base import BaseGeometry # noqa: E402

from quackosm._rich_progress import TaskProgressBar # type: ignore[attr-defined] # noqa: E402

rows_per_partition = 100_000


def intersect_nodes_with_geometry_dask(tmp_dir_path: Path, geometry_filter: BaseGeometry) -> None:
"""
Intersects nodes points with geometry filter using spatial index with dask geopandas.
Args:
tmp_dir_path (Path): Path of the working directory.
geometry_filter (BaseGeometry): Geometry used for filtering.
"""
nodes_intersecting_path = tmp_dir_path / "nodes_intersecting_ids"
nodes_intersecting_path.mkdir(parents=True, exist_ok=True)

pq_ds = pq.ParquetDataset(tmp_dir_path / "nodes_valid_with_tags")
total_rows = sum(frag.count_rows() for frag in pq_ds.fragments)

ddf = dask.dataframe.read_parquet(
tmp_dir_path / "nodes_valid_with_tags",
columns=["id", "lon", "lat"],
).repartition(npartitions=total_rows // rows_per_partition)
ddf = dask_geopandas.from_dask_dataframe(
ddf, geometry=dask_geopandas.points_from_xy(ddf, "lon", "lat")
)
ddf = ddf[ddf.within(geometry_filter)]
ddf[["id"]].to_parquet(nodes_intersecting_path)


def _intersection_worker(
Expand Down
12 changes: 8 additions & 4 deletions quackosm/pbf_file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from quackosm._constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from quackosm._exceptions import EmptyResultWarning, InvalidGeometryFilter
from quackosm._intersection import intersect_nodes_with_geometry
from quackosm._intersection import intersect_nodes_with_geometry_dask
from quackosm._osm_tags_filters import (
GroupedOsmTagsFilter,
OsmTagsFilter,
Expand Down Expand Up @@ -1141,13 +1141,17 @@ def _prefilter_elements_ids(
# - select all from NI with tags filter
filter_osm_node_ids_filter = self._generate_elements_filter(filter_osm_ids, "node")
if is_intersecting:
with self.task_progress_tracker.get_bar("Filtering nodes - intersection") as bar:
intersect_nodes_with_geometry(
with self.task_progress_tracker.get_spinner("Filtering nodes - intersection"):
intersect_nodes_with_geometry_dask(
tmp_dir_path=self.tmp_dir_path,
geometry_filter=self.geometry_filter,
progress_bar=bar,
)

if self.debug_memory:
log_message(
f'Saved to directory: {self.tmp_dir_path / "nodes_intersecting_ids"}'
)

nodes_intersecting_ids = self.connection.read_parquet(
str(self.tmp_dir_path / "nodes_intersecting_ids" / "*.parquet")
)
Expand Down

0 comments on commit 9e758a5

Please sign in to comment.