Skip to content

Commit

Permalink
refactor: add multiprocessing to STRtree method
Browse files Browse the repository at this point in the history
  • Loading branch information
RaczeQ committed May 16, 2024
1 parent 6123e27 commit 78caa4e
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 132 deletions.
28 changes: 1 addition & 27 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ dependencies = [
"requests",
"polars>=0.19.4",
"rich>=10.11.0",
"fast-crossing",
]
requires-python = ">=3.9"
readme = "README.md"
Expand Down
96 changes: 96 additions & 0 deletions quackosm/_intersection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import multiprocessing
from pathlib import Path
from queue import Empty, Queue
from time import sleep

import pyarrow as pa
import pyarrow.parquet as pq
from shapely import Point, STRtree
from shapely.geometry.base import BaseGeometry

from quackosm._rich_progress import TaskProgressBar  # type: ignore[attr-defined]


def _intersection_worker(
    queue: Queue[tuple[str, int]], save_path: Path, geometry_filter: BaseGeometry
) -> None:
    """
    Worker entry point: filter node points by geometry using an STRtree.

    Consumes ``(parquet file name, row group index)`` pairs from the shared
    queue, builds a spatial index of node points for each row group and
    writes the ids of nodes intersecting ``geometry_filter`` to a
    per-process parquet file inside ``save_path``.

    Args:
        queue (Queue[tuple[str, int]]): Shared queue with pending work items.
        save_path (Path): Directory where the result parquet file is written.
        geometry_filter (BaseGeometry): Geometry used for filtering.
    """
    current_pid = multiprocessing.current_process().pid
    filepath = save_path / f"{current_pid}.parquet"
    writer = None
    try:
        while not queue.empty():
            try:
                file_name, row_group_index = queue.get(block=True, timeout=1)
            except Empty:
                # Another worker drained the queue between empty() and get();
                # re-check the loop condition instead of reporting an error.
                continue
            try:
                pq_file = pq.ParquetFile(file_name)
                row_group_table = pq_file.read_row_group(row_group_index, ["id", "lat", "lon"])
                if len(row_group_table) == 0:
                    continue

                tree = STRtree(
                    [
                        Point(lon.as_py(), lat.as_py())
                        for lon, lat in zip(row_group_table["lon"], row_group_table["lat"])
                    ]
                )

                intersecting_ids_array = row_group_table["id"].take(
                    tree.query(geometry_filter, predicate="intersects")
                )

                table = pa.table({"id": intersecting_ids_array})

                # Writer is created lazily so its schema comes from real data.
                if not writer:
                    writer = pq.ParquetWriter(filepath, table.schema)

                writer.write_table(table)
            except Exception as ex:
                # Best-effort worker: report the failure and keep processing
                # remaining row groups instead of killing the whole pool.
                print(ex)
    finally:
        # Close the writer even if the worker is interrupted mid-batch,
        # otherwise the parquet footer is never written.
        if writer:
            writer.close()


def intersect_nodes_with_geometry(
    tmp_dir_path: Path, geometry_filter: BaseGeometry, progress_bar: TaskProgressBar
) -> None:
    """
    Intersects nodes points with geometry filter using spatial index with multiprocessing.

    Args:
        tmp_dir_path (Path): Path of the working directory.
        geometry_filter (BaseGeometry): Geometry used for filtering.
        progress_bar (TaskProgressBar): Progress bar to show task status.
    """
    # NOTE: intentionally not chained (FURB184) — the manager reference must
    # outlive the workers, or its server process could be garbage collected
    # while the queue proxy is still in use.
    manager = multiprocessing.Manager()
    queue: Queue[tuple[str, int]] = manager.Queue()

    dataset = pq.ParquetDataset(tmp_dir_path / "nodes_valid_with_tags")

    # Each work item is a single parquet row group.
    for pq_file in dataset.files:
        for row_group in range(pq.ParquetFile(pq_file).num_row_groups):
            queue.put((pq_file, row_group))

    total = queue.qsize()

    nodes_intersecting_path = tmp_dir_path / "nodes_intersecting_ids"
    nodes_intersecting_path.mkdir(parents=True, exist_ok=True)

    processes = [
        multiprocessing.Process(
            target=_intersection_worker,
            args=(queue, nodes_intersecting_path, geometry_filter),
        )
        for _ in range(multiprocessing.cpu_count())
    ]

    # Run processes
    for p in processes:
        p.start()

    progress_bar.create_manual_bar(total=total)
    while any(process.is_alive() for process in processes):
        progress_bar.update_manual_bar(current_progress=total - queue.qsize())
        sleep(1)

    # Reap finished workers so no zombie processes are left behind.
    for p in processes:
        p.join()

    progress_bar.update_manual_bar(current_progress=total)
63 changes: 35 additions & 28 deletions quackosm/_rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,38 +147,39 @@ def __init__(
self.progress_cls = progress_cls
self.live_obj = live_obj

def _create_progress(self):
    """Build the rich progress object with the configured column layout."""
    description_column = TextColumn(
        "[progress.description]{task.description}"
        " [progress.percentage]{task.percentage:>3.0f}%"
    )
    columns = [
        SpinnerColumn(),
        TextColumn(self.step_number),
        description_column,
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("<"),
        TimeRemainingColumn(),
        TextColumn("•"),
        SpeedColumn(),
    ]

    # Drop the step-number column when it is not wanted.
    if self.skip_step_number:
        del columns[1]

    self.progress = self.progress_cls(
        *columns,
        live_obj=self.live_obj,
        transient=self.transient_mode,
        speed_estimate_period=1800,
    )

def __enter__(self):
    """Create and start the progress bar unless running in silent mode."""
    if self.silent_mode:
        self.progress = None
    else:
        # Column construction is delegated to _create_progress.
        self._create_progress()
        self.progress.__enter__()

    return self
Expand All @@ -189,6 +190,12 @@ def __exit__(self, exc_type, exc_value, exc_tb):

self.progress = None

def create_manual_bar(self, total: int):
    """Add a manually-driven task with the given total to the progress bar.

    No-op in silent mode (``self.progress`` is ``None``), matching the
    guard used by ``track``.
    """
    if self.progress is not None:
        self.progress.add_task(description=self.step_name, total=total)

def update_manual_bar(self, current_progress: int):
    """Set the completed amount on the first (manual) task.

    No-op in silent mode (``self.progress`` is ``None``), matching the
    guard used by ``track``.
    """
    if self.progress is not None:
        self.progress.update(task_id=self.progress.task_ids[0], completed=current_progress)

def track(self, iterable: Iterable):
if self.progress is not None:
for i in self.progress.track(list(iterable), description=self.step_name):
Expand Down
89 changes: 13 additions & 76 deletions quackosm/pbf_file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,17 @@
import polars as pl
import psutil
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import shapely.wkt as wktlib
from geoarrow.pyarrow import io
from pandas.util._decorators import deprecate, deprecate_kwarg
from pyarrow_ops import drop_duplicates
from shapely import STRtree
from shapely.geometry import LinearRing, Point, Polygon
from shapely.geometry import LinearRing, Polygon
from shapely.geometry.base import BaseGeometry, BaseMultipartGeometry

from quackosm._constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from quackosm._exceptions import EmptyResultWarning, InvalidGeometryFilter
from quackosm._intersection import intersect_nodes_with_geometry
from quackosm._osm_tags_filters import (
GroupedOsmTagsFilter,
OsmTagsFilter,
Expand Down Expand Up @@ -577,6 +576,9 @@ def convert_geometry_to_parquet(
ignore_cache=ignore_cache,
filter_osm_ids=filter_osm_ids,
save_as_wkt=save_as_wkt,
pbf_extract_geometry=[
matching_extract.geometry for matching_extract in matching_extracts
],
)

@deprecate_kwarg(old_arg_name="file_paths", new_arg_name="pbf_path") # type: ignore
Expand Down Expand Up @@ -1141,75 +1143,15 @@ def _prefilter_elements_ids(
filter_osm_node_ids_filter = self._generate_elements_filter(filter_osm_ids, "node")
if is_intersecting:
with self.task_progress_tracker.get_bar("Filtering nodes - intersection") as bar:
# if isinstance(self.geometry_filter, BaseMultipartGeometry):
# geometry_filter_geoms = list(self.geometry_filter.geoms)
# else:
# geometry_filter_geoms = [self.geometry_filter]

# polygons_points = [
# list(geometry_filter_geom.exterior.coords)
# for geometry_filter_geom in geometry_filter_geoms
# ]

pq_dataset = ds.dataset(self.tmp_dir_path / "nodes_valid_with_tags")

writer = None
for batch in bar.track(pq_dataset.to_batches()):
if len(batch) == 0:
continue

ids = batch["id"]
# points = [
# (lon.as_py(), lat.as_py()) for lon, lat in zip(batch["lon"], batch["lat"])
# ]

tree = STRtree(
[
Point(lon.as_py(), lat.as_py())
for lon, lat in zip(batch["lon"], batch["lat"])
]
)

# tree = STRtree([GeocodeGeometryParser().convert("Monaco-Ville, Monaco")])

intersecting_ids_array = ids.take(
tree.query(self.geometry_filter, predicate="intersects")
)

# mask = contains_xy(self.geometry_filter, x=batch["lon"], y=batch["lat"])

# pool.imap(f_wrapped, zip(da, repeat(db))), total=len(da)
# with Pool() as pool:
# masks = pool.map(
# partial(_check_points_in_polygon, points=points), polygons_points
# )

# masks = [
# point_in_polygon(points=points, polygon=polygon_points) == 1
# for polygon_points in polygons_points
# ]

# total_mask = reduce(np.logical_or, masks)
# intersecting_ids_array = ids.filter(pa.array(mask))
# intersecting_ids_array = ids.filter(pa.array(mask))
intersecting_ids_batch = pa.RecordBatch.from_arrays(
[intersecting_ids_array], names=["id"]
)
if not writer:
nodes_intersecting_path = (
self.tmp_dir_path / "nodes_intersecting_ids" / "data.parquet"
)
nodes_intersecting_path.parent.mkdir(parents=True, exist_ok=True)
writer = pq.ParquetWriter(
nodes_intersecting_path,
intersecting_ids_batch.schema,
)

writer.write_batch(intersecting_ids_batch)
if writer:
writer.close()
intersect_nodes_with_geometry(
tmp_dir_path=self.tmp_dir_path,
geometry_filter=self.geometry_filter,
progress_bar=bar,
)

nodes_intersecting_ids = self.connection.read_parquet(str(nodes_intersecting_path))
nodes_intersecting_ids = self.connection.read_parquet(
str(self.tmp_dir_path / "nodes_intersecting_ids" / "*.parquet")
)

with self.task_progress_tracker.get_spinner("Filtering nodes - tags"):
self._sql_to_parquet_file(
Expand Down Expand Up @@ -2679,8 +2621,3 @@ def _group_ways_with_polars(current_ways_group_path: Path, current_destination_p
).write_parquet(
current_destination_path
)


# def _check_points_in_polygon(polygon, points)-> np.ndarray[bool]:
# mask = point_in_polygon(points, polygon) == 1
# return mask

0 comments on commit 78caa4e

Please sign in to comment.