From ccdd6bbae55e0ece25a1f8f34dd54b7f99718674 Mon Sep 17 00:00:00 2001
From: Olga Bulat
Date: Tue, 21 Mar 2023 15:40:29 +0300
Subject: [PATCH] Use context manager with multiprocessing pool

---
 ingestion_server/ingestion_server/cleanup.py | 34 +++++++-------------
 1 file changed, 11 insertions(+), 23 deletions(-)

diff --git a/ingestion_server/ingestion_server/cleanup.py b/ingestion_server/ingestion_server/cleanup.py
index 63990de7dfe..731a1009b9f 100644
--- a/ingestion_server/ingestion_server/cleanup.py
+++ b/ingestion_server/ingestion_server/cleanup.py
@@ -266,23 +266,12 @@ def _clean_data_worker(rows, temp_table, sources_config, all_fields: list[str]):
     return cleaned_values
 
 
-def save_cleaned_data(results):
+def save_cleaned_data(result: dict) -> dict[str, int]:
     log.info("Saving cleaned data...")
     start_time = time.time()
 
-    results_to_save: dict[str, list[tuple[str, str | Json]]] = {}
-    # Results is a list of dicts, where each dict is a mapping of field name to
-    # a list of tuples of (identifier, cleaned_value). There are as many dicts
-    # as there are workers. We need to merge the lists of tuples for each field
-    # name.
-    for result in results:
-        for field, values in result.items():
-            if field not in results_to_save:
-                results_to_save[field] = []
-            results_to_save[field].extend(values)
-    cleanup_counts = {}
-    for field, cleaned_items in results_to_save.items():
-        cleanup_counts[field] = len(cleaned_items) if cleaned_items else 0
+    cleanup_counts = {field: len(items) for field, items in result.items()}
+    for field, cleaned_items in result.items():
         if cleaned_items:
             with open(f"{field}.tsv", "a") as f:
                 csv_writer = csv.writer(f, delimiter="\t")
@@ -290,7 +279,7 @@ def save_cleaned_data(results):
 
     end_time = time.time()
     total_time = end_time - start_time
-    log.info(f"Finished saving cleaned data in {total_time}")
+    log.info(f"Finished saving cleaned data in {total_time},\n{cleanup_counts}")
     return cleanup_counts
 
 
@@ -362,14 +351,13 @@ def clean_image_data(table):
                 cleanable_fields_for_table,
             )
         )
-        pool = multiprocessing.Pool(processes=num_workers)
-        log.info(f"Starting {len(jobs)} cleaning jobs")
-        conn.commit()
-        results = pool.starmap(_clean_data_worker, jobs)
-        batch_cleaned_counts = save_cleaned_data(results)
-        for field in batch_cleaned_counts:
-            cleaned_counts_by_field[field] += batch_cleaned_counts[field]
-        pool.close()
+        with multiprocessing.Pool(processes=num_workers) as pool:
+            log.info(f"Starting {len(jobs)} cleaning jobs")
+
+            for result in pool.starmap(_clean_data_worker, jobs):
+                batch_cleaned_counts = save_cleaned_data(result)
+                for field in batch_cleaned_counts:
+                    cleaned_counts_by_field[field] += batch_cleaned_counts[field]
         num_cleaned += len(batch)
         batch_end_time = time.time()
         rate = len(batch) / (batch_end_time - batch_start_time)
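
The pattern this patch adopts looks like the following in isolation. This is a
minimal sketch, not Openverse code: the worker function and job tuples are
made-up stand-ins for _clean_data_worker and the jobs list. One caveat worth
knowing: Pool.__exit__ calls terminate() rather than close()/join(), which is
safe here only because starmap blocks until every job has returned its result.

import multiprocessing


def scale(value: int, factor: int) -> int:
    # Stand-in for _clean_data_worker: any picklable, top-level function works.
    return value * factor


if __name__ == "__main__":
    # One argument tuple per worker invocation, as with the jobs list above.
    jobs = [(n, 10) for n in range(4)]

    # The `with` block replaces the manual pool.close() from the old code and
    # releases the pool's processes even if starmap raises.
    with multiprocessing.Pool(processes=2) as pool:
        # starmap blocks until all jobs complete and returns one result per
        # argument tuple, so the results can be consumed inside the block.
        for result in pool.starmap(scale, jobs):
            print(result)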