From 78caa4efd546e788c79ed59f14cbc4969ab3128e Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Thu, 16 May 2024 08:47:41 +0200 Subject: [PATCH] refactor: add multiprocessing to STRtree method --- pdm.lock | 28 +---------- pyproject.toml | 1 - quackosm/_intersection.py | 96 +++++++++++++++++++++++++++++++++++++ quackosm/_rich_progress.py | 63 +++++++++++++----------- quackosm/pbf_file_reader.py | 89 +++++----------------------------- 5 files changed, 145 insertions(+), 132 deletions(-) create mode 100644 quackosm/_intersection.py diff --git a/pdm.lock b/pdm.lock index 0a79bc3..31bbe14 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "license", "lint", "test", "cli", "cli-dev"] strategy = ["cross_platform"] lock_version = "4.4.1" -content_hash = "sha256:9d9d20905512cf957a5128e77b5860495fcb163a9b5f0c33fe9691088324e62b" +content_hash = "sha256:18fee9a786c0aec34cb18f9b7237e359068dbb412a2efb5c8a42d2b428875cad" [[package]] name = "adjusttext" @@ -731,32 +731,6 @@ files = [ {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, ] -[[package]] -name = "fast-crossing" -version = "0.0.8" -summary = "fast crossing" -dependencies = [ - "numpy", -] -files = [ - {file = "fast_crossing-0.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:831c7939ff9698f3cbea5d2f926d6e71b550204110bf5123e0789f2a388de992"}, - {file = "fast_crossing-0.0.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fb7f16676aa46dc2f83a46460f67d5e112b2ed92cc6a9b37d4bc782e24caed4"}, - {file = "fast_crossing-0.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de3680339f3929819fe7a074c9c83e4e4516ae4ef502b504ce811375bc8c1a6d"}, - {file = "fast_crossing-0.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f758621425fe2fdac51cf11878d77480afc6c95e67950b1645a1dbd13437ad"}, - {file = "fast_crossing-0.0.8-cp310-cp310-win_amd64.whl", hash = 
"sha256:14816b6bfbe5496291b7623056d612c48f92d3aeab9efdcbd646a774885ddb9c"}, - {file = "fast_crossing-0.0.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:086ffe4c34bc5132b8d28387194261197f981cec6aa5060c6c3800e3fa4317f3"}, - {file = "fast_crossing-0.0.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff1e0bb4f8f8af288399165e8f6a5bc4e8e2bdcac41c7344f4f93049ee09e3ca"}, - {file = "fast_crossing-0.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2897b42103f0c00d38328a59c1a35f00be47534ff093db5ae58566d1fd66434"}, - {file = "fast_crossing-0.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f74f99083941290557d9334ef0abdb2cea65efd5f48b7928f1a600ba873c3a1"}, - {file = "fast_crossing-0.0.8-cp311-cp311-win_amd64.whl", hash = "sha256:4d9a48759c6784f78f04c5139434628c25e6571bee791744bbc24bab2e29639d"}, - {file = "fast_crossing-0.0.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a6031f0b0c9cbe24b373a4f9e1538da715e28df6aca65af196762224ec99a09e"}, - {file = "fast_crossing-0.0.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:19a2b135b4c65c5e02034b4635949d140eb7713941c8401498b177217f4ee76a"}, - {file = "fast_crossing-0.0.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fbeece624113f428da6d15eeb234cb6985e7d44e0206ececda23e473d4269607"}, - {file = "fast_crossing-0.0.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f9d03bba21b6411a538267b86b95491e177cc7f17c865a1af1e8eb425554ac0"}, - {file = "fast_crossing-0.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:f219abe7cd33d031e319434984e1f1dd15815d57e95b5665b7fec1e710507967"}, - {file = "fast_crossing-0.0.8.tar.gz", hash = "sha256:37c3d6c0e75c51e3e9ad4fb0b6a26dbe0f4ce5879f1913994bd638090860b846"}, -] - [[package]] name = "fastjsonschema" version = "2.19.1" diff --git a/pyproject.toml b/pyproject.toml index 85d7b51..a73c7b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ dependencies = [ 
"requests", "polars>=0.19.4", "rich>=10.11.0", - "fast-crossing", ] requires-python = ">=3.9" readme = "README.md" diff --git a/quackosm/_intersection.py b/quackosm/_intersection.py new file mode 100644 index 0000000..8b2f0ff --- /dev/null +++ b/quackosm/_intersection.py @@ -0,0 +1,96 @@ +import multiprocessing +from pathlib import Path +from queue import Queue +from time import sleep + +import pyarrow as pa +import pyarrow.parquet as pq +from shapely import Point, STRtree +from shapely.geometry.base import BaseGeometry + +from quackosm._rich_progress import TaskProgressBar # type: ignore[attr-defined] + + +def _intersection_worker( + queue: Queue[tuple[str, int]], save_path: Path, geometry_filter: BaseGeometry +) -> None: + current_pid = multiprocessing.current_process().pid + + filepath = save_path / f"{current_pid}.parquet" + writer = None + while not queue.empty(): + try: + file_name, row_group_index = queue.get(block=True, timeout=1) + + pq_file = pq.ParquetFile(file_name) + row_group_table = pq_file.read_row_group(row_group_index, ["id", "lat", "lon"]) + if len(row_group_table) == 0: + continue + + tree = STRtree( + [ + Point(lon.as_py(), lat.as_py()) + for lon, lat in zip(row_group_table["lon"], row_group_table["lat"]) + ] + ) + + intersecting_ids_array = row_group_table["id"].take( + tree.query(geometry_filter, predicate="intersects") + ) + + table = pa.table({"id": intersecting_ids_array}) + + if not writer: + writer = pq.ParquetWriter(filepath, table.schema) + + writer.write_table(table) + except Exception as ex: + print(ex) + + if writer: + writer.close() + + +def intersect_nodes_with_geometry( + tmp_dir_path: Path, geometry_filter: BaseGeometry, progress_bar: TaskProgressBar +) -> None: + """ + Intersects nodes points with geometry filter using spatial index with multiprocessing. + + Args: + tmp_dir_path (Path): Path of the working directory. + geometry_filter (BaseGeometry): Geometry used for filtering. 
+ progress_bar (TaskProgressBar): Progress bar to show task status. + """ + manager = multiprocessing.Manager() + queue: Queue[tuple[str, int]] = manager.Queue() + + dataset = pq.ParquetDataset(tmp_dir_path / "nodes_valid_with_tags") + + for pq_file in dataset.files: + for row_group in range(pq.ParquetFile(pq_file).num_row_groups): + queue.put((pq_file, row_group)) + + total = queue.qsize() + + nodes_intersecting_path = tmp_dir_path / "nodes_intersecting_ids" + nodes_intersecting_path.mkdir(parents=True, exist_ok=True) + + processes = [ + multiprocessing.Process( + target=_intersection_worker, + args=(queue, nodes_intersecting_path, geometry_filter), + ) + for _ in range(multiprocessing.cpu_count()) + ] + + # Run processes + for p in processes: + p.start() + + progress_bar.create_manual_bar(total=total) + while any(process.is_alive() for process in processes): + progress_bar.update_manual_bar(current_progress=total - queue.qsize()) + sleep(1) + + progress_bar.update_manual_bar(current_progress=total) diff --git a/quackosm/_rich_progress.py b/quackosm/_rich_progress.py index ca15036..d210bea 100644 --- a/quackosm/_rich_progress.py +++ b/quackosm/_rich_progress.py @@ -147,38 +147,39 @@ def __init__( self.progress_cls = progress_cls self.live_obj = live_obj + def _create_progress(self): + columns = [ + SpinnerColumn(), + TextColumn(self.step_number), + TextColumn( + "[progress.description]{task.description}" + " [progress.percentage]{task.percentage:>3.0f}%" + ), + BarColumn(), + MofNCompleteColumn(), + TextColumn("•"), + TimeElapsedColumn(), + TextColumn("<"), + TimeRemainingColumn(), + TextColumn("•"), + SpeedColumn(), + ] + + if self.skip_step_number: + columns.pop(1) + + self.progress = self.progress_cls( + *columns, + live_obj=self.live_obj, + transient=self.transient_mode, + speed_estimate_period=1800, + ) + def __enter__(self): if self.silent_mode: self.progress = None else: - - columns = [ - SpinnerColumn(), - TextColumn(self.step_number), - TextColumn( - 
"[progress.description]{task.description}" - " [progress.percentage]{task.percentage:>3.0f}%" - ), - BarColumn(), - MofNCompleteColumn(), - TextColumn("•"), - TimeElapsedColumn(), - TextColumn("<"), - TimeRemainingColumn(), - TextColumn("•"), - SpeedColumn(), - ] - - if self.skip_step_number: - columns.pop(1) - - self.progress = self.progress_cls( - *columns, - live_obj=self.live_obj, - transient=self.transient_mode, - speed_estimate_period=1800, - ) - + self._create_progress() self.progress.__enter__() return self @@ -189,6 +190,12 @@ def __exit__(self, exc_type, exc_value, exc_tb): self.progress = None + def create_manual_bar(self, total: int): + self.progress.add_task(description=self.step_name, total=total) + + def update_manual_bar(self, current_progress: int): + self.progress.update(task_id=self.progress.task_ids[0], completed=current_progress) + def track(self, iterable: Iterable): if self.progress is not None: for i in self.progress.track(list(iterable), description=self.step_name): diff --git a/quackosm/pbf_file_reader.py b/quackosm/pbf_file_reader.py index cc8e02b..52097f5 100644 --- a/quackosm/pbf_file_reader.py +++ b/quackosm/pbf_file_reader.py @@ -25,18 +25,17 @@ import polars as pl import psutil import pyarrow as pa -import pyarrow.dataset as ds import pyarrow.parquet as pq import shapely.wkt as wktlib from geoarrow.pyarrow import io from pandas.util._decorators import deprecate, deprecate_kwarg from pyarrow_ops import drop_duplicates -from shapely import STRtree -from shapely.geometry import LinearRing, Point, Polygon +from shapely.geometry import LinearRing, Polygon from shapely.geometry.base import BaseGeometry, BaseMultipartGeometry from quackosm._constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS from quackosm._exceptions import EmptyResultWarning, InvalidGeometryFilter +from quackosm._intersection import intersect_nodes_with_geometry from quackosm._osm_tags_filters import ( GroupedOsmTagsFilter, OsmTagsFilter, @@ -577,6 +576,9 @@ def 
convert_geometry_to_parquet( ignore_cache=ignore_cache, filter_osm_ids=filter_osm_ids, save_as_wkt=save_as_wkt, + pbf_extract_geometry=[ + matching_extract.geometry for matching_extract in matching_extracts + ], ) @deprecate_kwarg(old_arg_name="file_paths", new_arg_name="pbf_path") # type: ignore @@ -1141,75 +1143,15 @@ def _prefilter_elements_ids( filter_osm_node_ids_filter = self._generate_elements_filter(filter_osm_ids, "node") if is_intersecting: with self.task_progress_tracker.get_bar("Filtering nodes - intersection") as bar: - # if isinstance(self.geometry_filter, BaseMultipartGeometry): - # geometry_filter_geoms = list(self.geometry_filter.geoms) - # else: - # geometry_filter_geoms = [self.geometry_filter] - - # polygons_points = [ - # list(geometry_filter_geom.exterior.coords) - # for geometry_filter_geom in geometry_filter_geoms - # ] - - pq_dataset = ds.dataset(self.tmp_dir_path / "nodes_valid_with_tags") - - writer = None - for batch in bar.track(pq_dataset.to_batches()): - if len(batch) == 0: - continue - - ids = batch["id"] - # points = [ - # (lon.as_py(), lat.as_py()) for lon, lat in zip(batch["lon"], batch["lat"]) - # ] - - tree = STRtree( - [ - Point(lon.as_py(), lat.as_py()) - for lon, lat in zip(batch["lon"], batch["lat"]) - ] - ) - - # tree = STRtree([GeocodeGeometryParser().convert("Monaco-Ville, Monaco")]) - - intersecting_ids_array = ids.take( - tree.query(self.geometry_filter, predicate="intersects") - ) - - # mask = contains_xy(self.geometry_filter, x=batch["lon"], y=batch["lat"]) - - # pool.imap(f_wrapped, zip(da, repeat(db))), total=len(da) - # with Pool() as pool: - # masks = pool.map( - # partial(_check_points_in_polygon, points=points), polygons_points - # ) - - # masks = [ - # point_in_polygon(points=points, polygon=polygon_points) == 1 - # for polygon_points in polygons_points - # ] - - # total_mask = reduce(np.logical_or, masks) - # intersecting_ids_array = ids.filter(pa.array(mask)) - # intersecting_ids_array = 
ids.filter(pa.array(mask)) - intersecting_ids_batch = pa.RecordBatch.from_arrays( - [intersecting_ids_array], names=["id"] - ) - if not writer: - nodes_intersecting_path = ( - self.tmp_dir_path / "nodes_intersecting_ids" / "data.parquet" - ) - nodes_intersecting_path.parent.mkdir(parents=True, exist_ok=True) - writer = pq.ParquetWriter( - nodes_intersecting_path, - intersecting_ids_batch.schema, - ) - - writer.write_batch(intersecting_ids_batch) - if writer: - writer.close() + intersect_nodes_with_geometry( + tmp_dir_path=self.tmp_dir_path, + geometry_filter=self.geometry_filter, + progress_bar=bar, + ) - nodes_intersecting_ids = self.connection.read_parquet(str(nodes_intersecting_path)) + nodes_intersecting_ids = self.connection.read_parquet( + str(self.tmp_dir_path / "nodes_intersecting_ids" / "*.parquet") + ) with self.task_progress_tracker.get_spinner("Filtering nodes - tags"): self._sql_to_parquet_file( @@ -2679,8 +2621,3 @@ def _group_ways_with_polars(current_ways_group_path: Path, current_destination_p ).write_parquet( current_destination_path ) - - -# def _check_points_in_polygon(polygon, points)-> np.ndarray[bool]: -# mask = point_in_polygon(points, polygon) == 1 -# return mask