
Commit

Use context manager with multiprocessing pool
obulat committed Mar 21, 2023
1 parent 9df0688 commit ccdd6bb
Showing 1 changed file with 11 additions and 23 deletions.
34 changes: 11 additions & 23 deletions ingestion_server/ingestion_server/cleanup.py
@@ -266,31 +266,20 @@ def _clean_data_worker(rows, temp_table, sources_config, all_fields: list[str]):
     return cleaned_values
 
 
-def save_cleaned_data(results):
+def save_cleaned_data(result: dict) -> dict[str, int]:
     log.info("Saving cleaned data...")
     start_time = time.time()
 
-    results_to_save: dict[str, list[tuple[str, str | Json]]] = {}
-    # Results is a list of dicts, where each dict is a mapping of field name to
-    # a list of tuples of (identifier, cleaned_value). There are as many dicts
-    # as there are workers. We need to merge the lists of tuples for each field
-    # name.
-    for result in results:
-        for field, values in result.items():
-            if field not in results_to_save:
-                results_to_save[field] = []
-            results_to_save[field].extend(values)
-    cleanup_counts = {}
-    for field, cleaned_items in results_to_save.items():
-        cleanup_counts[field] = len(cleaned_items) if cleaned_items else 0
+    cleanup_counts = {field: len(items) for field, items in result.items()}
+    for field, cleaned_items in result.items():
         if cleaned_items:
             with open(f"{field}.tsv", "a") as f:
                 csv_writer = csv.writer(f, delimiter="\t")
                 csv_writer.writerows(cleaned_items)
 
     end_time = time.time()
     total_time = end_time - start_time
-    log.info(f"Finished saving cleaned data in {total_time}")
+    log.info(f"Finished saving cleaned data in {total_time},\n{cleanup_counts}")
     return cleanup_counts
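
For reference, a self-contained sketch of the new save path (simplified: logging, timing, and the Json value type are omitted, and the sample data below is invented). Each worker now hands back a single dict mapping a field name to (identifier, cleaned_value) tuples; the function counts the rows per field and appends them to per-field TSV files:

import csv


def save_cleaned_data(result: dict[str, list[tuple[str, str]]]) -> dict[str, int]:
    # Count the cleaned rows per field, then append each field's rows to its TSV.
    cleanup_counts = {field: len(items) for field, items in result.items()}
    for field, cleaned_items in result.items():
        if cleaned_items:
            with open(f"{field}.tsv", "a") as f:
                csv.writer(f, delimiter="\t").writerows(cleaned_items)
    return cleanup_counts


# One worker's result: field name -> list of (identifier, cleaned_value) rows.
print(save_cleaned_data({"url": [("some-identifier", "https://example.com/a.jpg")]}))
# {'url': 1}  (and one tab-separated row appended to url.tsv)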


@@ -362,14 +351,13 @@ def clean_image_data(table):
                     cleanable_fields_for_table,
                 )
             )
-        pool = multiprocessing.Pool(processes=num_workers)
-        log.info(f"Starting {len(jobs)} cleaning jobs")
-        conn.commit()
-        results = pool.starmap(_clean_data_worker, jobs)
-        batch_cleaned_counts = save_cleaned_data(results)
-        for field in batch_cleaned_counts:
-            cleaned_counts_by_field[field] += batch_cleaned_counts[field]
-        pool.close()
+        with multiprocessing.Pool(processes=num_workers) as pool:
+            log.info(f"Starting {len(jobs)} cleaning jobs")
+
+            for result in pool.starmap(_clean_data_worker, jobs):
+                batch_cleaned_counts = save_cleaned_data(result)
+                for field in batch_cleaned_counts:
+                    cleaned_counts_by_field[field] += batch_cleaned_counts[field]
         num_cleaned += len(batch)
         batch_end_time = time.time()
         rate = len(batch) / (batch_end_time - batch_start_time)
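
A minimal, runnable sketch of the pattern this commit adopts (the worker function and job tuples below are illustrative stand-ins, not from the repository): multiprocessing.Pool used as a context manager, so the pool is cleaned up when the block exits (Pool.__exit__() calls terminate()), instead of relying on an explicit pool.close():

import multiprocessing


def _square_and_offset(value: int, offset: int) -> int:
    # Stand-in for _clean_data_worker; must be a module-level function so it
    # can be pickled and sent to the worker processes.
    return value * value + offset


if __name__ == "__main__":
    jobs = [(n, 1) for n in range(8)]  # one argument tuple per starmap job
    # Exiting the with-block tears the pool down even if starmap raises,
    # which replaces the old explicit pool.close() call.
    with multiprocessing.Pool(processes=4) as pool:
        for result in pool.starmap(_square_and_offset, jobs):
            print(result)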
