diff --git a/data_rentgen/consumer/handlers.py b/data_rentgen/consumer/handlers.py
index 82bf9add..5c542c7d 100644
--- a/data_rentgen/consumer/handlers.py
+++ b/data_rentgen/consumer/handlers.py
@@ -35,9 +35,8 @@ async def runs_handler(
 
 
 async def save_to_db(data: BatchExtractionResult, unit_of_work: UnitOfWork, logger: Logger) -> None:  # noqa: WPS217
-    # To avoid issues when parallel consumer instances create the same object, and then fail at the end,
-    # commit changes as soon as possible. Yes, this is quite slow, but it's fine for a prototype.
-    # TODO: rewrite this to create objects in batch.
+    # To avoid deadlocks when parallel consumer instances insert/update the same row,
+    # commit changes for each row instead of committing the whole batch. Yes, this could be slow.
     logger.debug("Creating locations")
     for location_dto in data.locations():
@@ -75,24 +74,24 @@ async def save_to_db(data: BatchExtractionResult, unit_of_work: UnitOfWork, logg
             schema = await unit_of_work.schema.get_or_create(schema_dto)
             schema_dto.id = schema.id
 
+    # Some events related to a specific run are sent to the same Kafka partition,
+    # but at the same time we have a parent_run which may already be inserted/updated by another worker
+    # (the Kafka key may be different for a run and its parent).
+    # In this case we cannot insert all the rows in one transaction, as it may lead to deadlocks.
     logger.debug("Creating runs")
     for run_dto in data.runs():
         async with unit_of_work:
             await unit_of_work.run.create_or_update(run_dto)
 
-    logger.debug("Creating operations")
-    for operation_dto in data.operations():
-        async with unit_of_work:
-            await unit_of_work.operation.create_or_update(operation_dto)
+    # All events related to the same operation are always sent to the same Kafka partition,
+    # so other workers never insert/update the same operation in parallel.
+    # These rows can be inserted/updated in bulk, in one transaction.
+    async with unit_of_work:
+        logger.debug("Creating operations")
+        await unit_of_work.operation.create_or_update_bulk(data.operations())
 
-    logger.debug("Creating inputs")
-    for input_dto in data.inputs():
-        async with unit_of_work:
-            input = await unit_of_work.input.create_or_update(input_dto)
-            input_dto.id = input.id
+        logger.debug("Creating inputs")
+        await unit_of_work.input.create_or_update_bulk(data.inputs())
 
-    logger.debug("Creating outputs")
-    for output_dto in data.outputs():
-        async with unit_of_work:
-            output = await unit_of_work.output.create_or_update(output_dto)
-            output_dto.id = output.id
+        logger.debug("Creating outputs")
+        await unit_of_work.output.create_or_update_bulk(data.outputs())
diff --git a/data_rentgen/db/repositories/dataset.py b/data_rentgen/db/repositories/dataset.py
index 939893da..f469ecb6 100644
--- a/data_rentgen/db/repositories/dataset.py
+++ b/data_rentgen/db/repositories/dataset.py
@@ -113,6 +113,7 @@ async def _create(self, dataset: DatasetDTO) -> Dataset:
         return result
 
     async def _update(self, existing: Dataset, new: DatasetDTO) -> Dataset:
+        # almost all fields are immutable, so we can avoid UPDATE statements if the row is unchanged
         if new.format:
             existing.format = new.format
         await self._session.flush([existing])
diff --git a/data_rentgen/db/repositories/dataset_symlink.py b/data_rentgen/db/repositories/dataset_symlink.py
index 2513e87c..6ee3d6e9 100644
--- a/data_rentgen/db/repositories/dataset_symlink.py
+++ b/data_rentgen/db/repositories/dataset_symlink.py
@@ -54,6 +54,7 @@ async def _create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink:
         return result
 
     async def _update(self, existing: DatasetSymlink, new: DatasetSymlinkDTO) -> DatasetSymlink:
+        # almost all fields are immutable, so we can avoid UPDATE statements if the row is unchanged
         existing.type = DatasetSymlinkType(new.type)
         await self._session.flush([existing])
         return existing
diff --git a/data_rentgen/db/repositories/input.py b/data_rentgen/db/repositories/input.py
index b8d76b2c..a3c54f97 100644
--- a/data_rentgen/db/repositories/input.py
+++ b/data_rentgen/db/repositories/input.py
@@ -6,6 +6,7 @@
 from uuid import UUID
 
 from sqlalchemy import Select, any_, func, literal_column, select
+from sqlalchemy.dialects.postgresql import insert
 
 from data_rentgen.db.models import Input
 from data_rentgen.db.repositories.base import Repository
@@ -17,7 +18,7 @@
 
 
 class InputRepository(Repository[Input]):
-    async def create_or_update(self, input: InputDTO) -> Input:
+    def get_id(self, input: InputDTO) -> UUID:
         # `created_at' field of input should be the same as operation's,
         # to avoid scanning all partitions and speed up queries
         created_at = extract_timestamp_from_uuid(input.operation.id)
@@ -29,27 +30,41 @@ async def create_or_update(self, input: InputDTO) -> Input:
             str(input.dataset.id),
             str(input.schema.id) if input.schema else "",
         ]
-        input_id = generate_incremental_uuid(created_at, ".".join(id_components).encode("utf-8"))
-
-        result = await self._get(created_at, input_id)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(input_id)
-            result = await self._get(created_at, input_id)
-
-        if not result:
-            return await self._create(
-                created_at=created_at,
-                input_id=input_id,
-                input=input,
-                operation_id=input.operation.id,
-                run_id=input.operation.run.id,
-                job_id=input.operation.run.job.id,  # type: ignore[arg-type]
-                dataset_id=input.dataset.id,  # type: ignore[arg-type]
-                schema_id=input.schema.id if input.schema else None,
-            )
-        return await self._update(result, input)
+        return generate_incremental_uuid(created_at, ".".join(id_components).encode("utf-8"))
+
+    async def create_or_update_bulk(self, inputs: list[InputDTO]) -> list[Input]:
+        if not inputs:
+            return []
+
+        insert_statement = insert(Input)
+        statement = insert_statement.on_conflict_do_update(
+            index_elements=[Input.created_at, Input.id],
+            set_={
+                "num_bytes": func.coalesce(insert_statement.excluded.num_bytes, Input.num_bytes),
+                "num_rows": func.coalesce(insert_statement.excluded.num_rows, Input.num_rows),
+                "num_files": func.coalesce(insert_statement.excluded.num_files, Input.num_files),
+            },
+        ).returning(Input)
+
+        result = await self._session.execute(
+            statement,
+            [
+                {
+                    "id": self.get_id(input),
+                    "created_at": extract_timestamp_from_uuid(input.operation.id),
+                    "operation_id": input.operation.id,
+                    "run_id": input.operation.run.id,
+                    "job_id": input.operation.run.job.id,  # type: ignore[arg-type]
+                    "dataset_id": input.dataset.id,  # type: ignore[arg-type]
+                    "schema_id": input.schema.id if input.schema else None,
+                    "num_bytes": input.num_bytes,
+                    "num_rows": input.num_rows,
+                    "num_files": input.num_files,
+                }
+                for input in inputs
+            ],
+        )
+        return list(result.scalars().all())
 
     async def list_by_operation_ids(
         self,
@@ -185,44 +200,3 @@ def _get_select(
             Input.job_id,
             Input.dataset_id,
         )
-
-    async def _get(self, created_at: datetime, input_id: UUID) -> Input | None:
-        query = select(Input).where(Input.created_at == created_at, Input.id == input_id)
-        return await self._session.scalar(query)
-
-    async def _create(
-        self,
-        created_at: datetime,
-        input_id: UUID,
-        input: InputDTO,
-        operation_id: UUID,
-        run_id: UUID,
-        job_id: int,
-        dataset_id: int,
-        schema_id: int | None = None,
-    ) -> Input:
-        result = Input(
-            created_at=created_at,
-            id=input_id,
-            operation_id=operation_id,
-            run_id=run_id,
-            job_id=job_id,
-            dataset_id=dataset_id,
-            schema_id=schema_id,
-            num_bytes=input.num_bytes,
-            num_rows=input.num_rows,
-            num_files=input.num_files,
-        )
-        self._session.add(result)
-        await self._session.flush([result])
-        return result
-
-    async def _update(self, existing: Input, new: InputDTO) -> Input:
-        if new.num_bytes is not None:
-            existing.num_bytes = new.num_bytes
-        if new.num_rows is not None:
-            existing.num_rows = new.num_rows
-        if new.num_files is not None:
-            existing.num_files = new.num_files
-        await self._session.flush([existing])
-        return existing
diff --git a/data_rentgen/db/repositories/job.py b/data_rentgen/db/repositories/job.py
index 12337311..253b7921 100644
--- a/data_rentgen/db/repositories/job.py
+++ b/data_rentgen/db/repositories/job.py
@@ -116,6 +116,7 @@ async def _create(self, job: JobDTO) -> Job:
         return result
 
     async def _update(self, existing: Job, new: JobDTO) -> Job:
+        # almost all fields are immutable, so we can avoid UPDATE statements if the row is unchanged
         if new.type:
             existing.type = JobType(new.type)
         await self._session.flush([existing])
diff --git a/data_rentgen/db/repositories/location.py b/data_rentgen/db/repositories/location.py
index 99c54693..2c29927e 100644
--- a/data_rentgen/db/repositories/location.py
+++ b/data_rentgen/db/repositories/location.py
@@ -49,6 +49,7 @@ async def _create(self, location: LocationDTO) -> Location:
 
     async def _update_addresses(self, existing: Location, new: LocationDTO) -> Location:
         existing_urls = {address.url for address in existing.addresses}
         new_urls = new.addresses - existing_urls
+        # in most cases, Location is unchanged, so we can avoid UPDATE statements
         if not new_urls:
             return existing
diff --git a/data_rentgen/db/repositories/operation.py b/data_rentgen/db/repositories/operation.py
index 8b314ecb..7eed8414 100644
--- a/data_rentgen/db/repositories/operation.py
+++ b/data_rentgen/db/repositories/operation.py
@@ -4,7 +4,8 @@
 from datetime import datetime, timezone
 from typing import Sequence
 
-from sqlalchemy import any_, select
+from sqlalchemy import any_, func, select
+from sqlalchemy.dialects.postgresql import insert
 
 from data_rentgen.db.models import Operation, OperationType, Status
 from data_rentgen.db.repositories.base import Repository
@@ -14,19 +15,45 @@
 
 
 class OperationRepository(Repository[Operation]):
-    async def create_or_update(self, operation: OperationDTO) -> Operation:
-        # avoid calculating created_at twice
-        created_at = extract_timestamp_from_uuid(operation.id)
-        result = await self._get(created_at, operation.id)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(operation.id)
-            result = await self._get(created_at, operation.id)
+    async def create_or_update_bulk(self, operations: list[OperationDTO]) -> list[Operation]:
+        if not operations:
+            return []
 
-        if not result:
-            return await self._create(created_at, operation)
-        return await self._update(result, operation)
+        insert_statement = insert(Operation)
+        statement = insert_statement.on_conflict_do_update(
+            index_elements=[Operation.created_at, Operation.id],
+            set_={
+                "name": func.coalesce(insert_statement.excluded.name, Operation.name),
+                "type": func.coalesce(insert_statement.excluded.type, Operation.type),
+                "status": func.coalesce(insert_statement.excluded.status, Operation.status),
+                "started_at": func.coalesce(insert_statement.excluded.started_at, Operation.started_at),
+                "ended_at": func.coalesce(insert_statement.excluded.ended_at, Operation.ended_at),
+                "description": func.coalesce(insert_statement.excluded.description, Operation.description),
+                "group": func.coalesce(insert_statement.excluded.group, Operation.group),
+                "position": func.coalesce(insert_statement.excluded.position, Operation.position),
+            },
+        ).returning(Operation)
+
+        result = await self._session.execute(
+            statement,
+            [
+                {
+                    "id": operation.id,
+                    "created_at": extract_timestamp_from_uuid(operation.id),
+                    "run_id": operation.run.id,
+                    "name": operation.name,
+                    "type": OperationType(operation.type) if operation.type else None,
+                    "status": Status(operation.status) if operation.status else None,
+                    "started_at": operation.started_at,
+                    "ended_at": operation.ended_at,
+                    "description": operation.description,
+                    "group": operation.group,
+                    "position": operation.position,
+                }
+                for operation in operations
+            ],
+        )
+        return list(result.scalars().all())
 
     async def paginate(
         self,
@@ -104,42 +131,3 @@ async def list_by_ids(self, operation_ids: Sequence[UUID]) -> list[Operation]:
         )
         result = await self._session.scalars(query)
         return list(result.all())
-
-    async def _get(self, created_at: datetime, operation_id: UUID) -> Operation | None:
-        query = select(Operation).where(Operation.created_at == created_at, Operation.id == operation_id)
-        return await self._session.scalar(query)
-
-    async def _create(self, created_at: datetime, operation: OperationDTO) -> Operation:
-        result = Operation(
-            created_at=created_at,
-            id=operation.id,
-            run_id=operation.run.id,
-            name=operation.name,
-            type=OperationType(operation.type),
-            status=Status(operation.status) if operation.status else Status.UNKNOWN,
-            started_at=operation.started_at,
-            ended_at=operation.ended_at,
-            description=operation.description,
-            group=operation.group,
-            position=operation.position,
-        )
-        self._session.add(result)
-        await self._session.flush([result])
-        return result
-
-    async def _update(self, existing: Operation, new: OperationDTO) -> Operation:
-        optional_fields = {
-            "type": OperationType(new.type) if new.type else None,
-            "status": Status(new.status) if new.status else None,
-            "started_at": new.started_at,
-            "ended_at": new.ended_at,
-            "description": new.description,
-            "group": new.group,
-            "position": new.position,
-        }
-        for column, value in optional_fields.items():
-            if value is not None:
-                setattr(existing, column, value)
-
-        await self._session.flush([existing])
-        return existing
diff --git a/data_rentgen/db/repositories/output.py b/data_rentgen/db/repositories/output.py
index 16e57217..82b2baae 100644
--- a/data_rentgen/db/repositories/output.py
+++ b/data_rentgen/db/repositories/output.py
@@ -6,6 +6,7 @@
 from uuid import UUID
 
 from sqlalchemy import Select, any_, func, literal_column, select
+from sqlalchemy.dialects.postgresql import insert
 
 from data_rentgen.db.models import Output, OutputType
 from data_rentgen.db.repositories.base import Repository
@@ -17,7 +18,7 @@
 
 
 class OutputRepository(Repository[Output]):
-    async def create_or_update(self, output: OutputDTO) -> Output:
+    def get_id(self, output: OutputDTO) -> UUID:
         # `created_at' field of output should be the same as operation's,
         # to avoid scanning all partitions and speed up queries
         created_at = extract_timestamp_from_uuid(output.operation.id)
@@ -29,27 +30,42 @@ async def create_or_update(self, output: OutputDTO) -> Output:
             str(output.dataset.id),
             str(output.schema.id) if output.schema else "",
         ]
-        output_id = generate_incremental_uuid(created_at, ".".join(id_components).encode("utf-8"))
-
-        result = await self._get(created_at, output_id)
-        if not result:
-            # try one more time, but with lock acquired.
-            # if another worker already created the same row, just use it. if not - create with holding the lock.
-            await self._lock(output_id)
-            result = await self._get(created_at, output_id)
-
-        if not result:
-            return await self._create(
-                created_at=created_at,
-                output_id=output_id,
-                output=output,
-                operation_id=output.operation.id,
-                run_id=output.operation.run.id,
-                job_id=output.operation.run.job.id,  # type: ignore[arg-type]
-                dataset_id=output.dataset.id,  # type: ignore[arg-type]
-                schema_id=output.schema.id if output.schema else None,
-            )
-        return await self._update(result, output)
+        return generate_incremental_uuid(created_at, ".".join(id_components).encode("utf-8"))
+
+    async def create_or_update_bulk(self, outputs: list[OutputDTO]) -> list[Output]:
+        if not outputs:
+            return []
+
+        insert_statement = insert(Output)
+        statement = insert_statement.on_conflict_do_update(
+            index_elements=[Output.created_at, Output.id],
+            set_={
+                "num_bytes": func.coalesce(insert_statement.excluded.num_bytes, Output.num_bytes),
+                "num_rows": func.coalesce(insert_statement.excluded.num_rows, Output.num_rows),
+                "num_files": func.coalesce(insert_statement.excluded.num_files, Output.num_files),
+            },
+        ).returning(Output)
+
+        result = await self._session.execute(
+            statement,
+            [
+                {
+                    "id": self.get_id(output),
+                    "created_at": extract_timestamp_from_uuid(output.operation.id),
+                    "type": OutputType(output.type),
+                    "operation_id": output.operation.id,
+                    "run_id": output.operation.run.id,
+                    "job_id": output.operation.run.job.id,  # type: ignore[arg-type]
+                    "dataset_id": output.dataset.id,  # type: ignore[arg-type]
+                    "schema_id": output.schema.id if output.schema else None,
+                    "num_bytes": output.num_bytes,
+                    "num_rows": output.num_rows,
+                    "num_files": output.num_files,
+                }
+                for output in outputs
+            ],
+        )
+        return list(result.scalars().all())
 
     async def list_by_operation_ids(
         self,
@@ -190,45 +206,3 @@ def _get_select(
             Output.dataset_id,
             Output.type,
         )
-
-    async def _get(self, created_at: datetime, output_id: UUID) -> Output | None:
-        query = select(Output).where(Output.created_at == created_at, Output.id == output_id)
-        return await self._session.scalar(query)
-
-    async def _create(
-        self,
-        created_at: datetime,
-        output_id: UUID,
-        output: OutputDTO,
-        operation_id: UUID,
-        run_id: UUID,
-        job_id: int,
-        dataset_id: int,
-        schema_id: int | None = None,
-    ) -> Output:
-        result = Output(
-            created_at=created_at,
-            id=output_id,
-            operation_id=operation_id,
-            run_id=run_id,
-            job_id=job_id,
-            dataset_id=dataset_id,
-            type=OutputType(output.type),
-            schema_id=schema_id,
-            num_bytes=output.num_bytes,
-            num_rows=output.num_rows,
-            num_files=output.num_files,
-        )
-        self._session.add(result)
-        await self._session.flush([result])
-        return result
-
-    async def _update(self, existing: Output, new: OutputDTO) -> Output:
-        if new.num_bytes is not None:
-            existing.num_bytes = new.num_bytes
-        if new.num_rows is not None:
-            existing.num_rows = new.num_rows
-        if new.num_files is not None:
-            existing.num_files = new.num_files
-        await self._session.flush([existing])
-        return existing
diff --git a/data_rentgen/db/repositories/run.py b/data_rentgen/db/repositories/run.py
index 708cce5b..77b63c90 100644
--- a/data_rentgen/db/repositories/run.py
+++ b/data_rentgen/db/repositories/run.py
@@ -37,7 +37,6 @@ async def create_or_update(self, run: RunDTO) -> Run:
 
         if not result:
             return await self._create(created_at, run)
-
         return await self._update(result, run)
 
     async def paginate(
@@ -183,6 +182,7 @@ async def _update(
         existing: Run,
         new: RunDTO,
     ) -> Run:
+        # for parent_run, most fields are None, so we can avoid UPDATE statements if the row is unchanged
         optional_fields = {
             "status": Status(new.status) if new.status else None,
             "parent_run_id": new.parent_run.id if new.parent_run else None,
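
Note for reviewers: every `create_or_update_bulk` above relies on the same PostgreSQL `INSERT .. ON CONFLICT DO UPDATE` + `coalesce(EXCLUDED.col, existing.col)` pattern, executed once per batch with a list of parameter dicts. Below is a minimal, self-contained sketch of that pattern with SQLAlchemy 2.0 async. The `Metric` table, column names and DSN are made up for illustration and are not part of this repository; only the statement shape mirrors the PR. The `coalesce` keeps the already-stored value whenever the incoming row carries NULL, which is what lets partial events be merged without a separate UPDATE code path.

import asyncio

from sqlalchemy import BigInteger, func
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Metric(Base):
    # hypothetical table, stands in for Input/Output/Operation
    __tablename__ = "metric"

    id: Mapped[int] = mapped_column(primary_key=True)
    num_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    num_rows: Mapped[int | None] = mapped_column(BigInteger, nullable=True)


async def upsert_metrics(session: AsyncSession, rows: list[dict]) -> list[Metric]:
    # one INSERT .. ON CONFLICT DO UPDATE statement for the whole batch, in one transaction
    if not rows:
        return []

    insert_statement = insert(Metric)
    statement = insert_statement.on_conflict_do_update(
        index_elements=[Metric.id],
        set_={
            # prefer the incoming value, fall back to the stored one when the incoming value is NULL
            "num_bytes": func.coalesce(insert_statement.excluded.num_bytes, Metric.num_bytes),
            "num_rows": func.coalesce(insert_statement.excluded.num_rows, Metric.num_rows),
        },
    ).returning(Metric)

    # executemany-style call: one statement, many parameter sets
    result = await session.execute(statement, rows)
    return list(result.scalars().all())


async def main() -> None:
    # placeholder DSN; requires a running PostgreSQL instance and asyncpg installed
    engine = create_async_engine("postgresql+asyncpg://user:password@localhost/dbname")
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    async with AsyncSession(engine) as session, session.begin():
        # first event carries only num_bytes, a later event only num_rows;
        # after the second upsert the row is (1, 100, 5) thanks to coalesce
        await upsert_metrics(session, [{"id": 1, "num_bytes": 100, "num_rows": None}])
        saved = await upsert_metrics(session, [{"id": 1, "num_bytes": None, "num_rows": 5}])
        print([(m.id, m.num_bytes, m.num_rows) for m in saved])


if __name__ == "__main__":
    asyncio.run(main())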