[DOP-20958] Insert operations, inputs & outputs in one statement
dolfinus committed Nov 2, 2024
1 parent 65108ea commit e5d87f0
Showing 9 changed files with 136 additions and 197 deletions.
33 changes: 16 additions & 17 deletions data_rentgen/consumer/handlers.py
@@ -35,9 +35,8 @@ async def runs_handler(


async def save_to_db(data: BatchExtractionResult, unit_of_work: UnitOfWork, logger: Logger) -> None: # noqa: WPS217
# To avoid issues when parallel consumer instances create the same object, and then fail at the end,
# commit changes as soon as possible. Yes, this is quite slow, but it's fine for a prototype.
# TODO: rewrite this to create objects in batch.
# To avoid deadlocks when parallel consumer instances insert/update the same row,
# commit changes for each row instead of committing the whole batch. Yes, this could be slow.

logger.debug("Creating locations")
for location_dto in data.locations():
@@ -75,24 +74,24 @@ async def save_to_db(data: BatchExtractionResult, unit_of_work: UnitOfWork, logg
schema = await unit_of_work.schema.get_or_create(schema_dto)
schema_dto.id = schema.id

# Some events related to a specific run are sent to the same Kafka partition,
# but at the same time we have a parent_run which may already be inserted/updated by another worker
# (the Kafka key may be different for a run and its parent).
# In this case we cannot insert all the rows in one transaction, as it may lead to deadlocks.
logger.debug("Creating runs")
for run_dto in data.runs():
async with unit_of_work:
await unit_of_work.run.create_or_update(run_dto)

logger.debug("Creating operations")
for operation_dto in data.operations():
async with unit_of_work:
await unit_of_work.operation.create_or_update(operation_dto)
# All events related to the same operation are always sent to the same Kafka partition,
# so other workers never insert/update the same operation in parallel.
# These rows can be inserted/updated in bulk, in one transaction.
async with unit_of_work:
logger.debug("Creating operations")
await unit_of_work.operation.create_or_update_bulk(data.operations())

logger.debug("Creating inputs")
for input_dto in data.inputs():
async with unit_of_work:
input = await unit_of_work.input.create_or_update(input_dto)
input_dto.id = input.id
logger.debug("Creating inputs")
await unit_of_work.input.create_or_update_bulk(data.inputs())

logger.debug("Creating outputs")
for output_dto in data.outputs():
async with unit_of_work:
output = await unit_of_work.output.create_or_update(output_dto)
output_dto.id = output.id
logger.debug("Creating outputs")
await unit_of_work.output.create_or_update_bulk(data.outputs())
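The handler treats unit_of_work as an async context manager that commits on successful exit and rolls back otherwise, which is what makes the transaction boundaries above explicit: one transaction per run, and a single transaction for the whole batch of operations, inputs and outputs. A minimal sketch of such a pattern on top of SQLAlchemy's AsyncSession (the class below is an illustrative assumption, not this repository's actual UnitOfWork):

from sqlalchemy.ext.asyncio import AsyncSession


class UnitOfWorkSketch:
    # illustrative only: commit the wrapped work on success, roll back on failure
    def __init__(self, session: AsyncSession):
        self._session = session

    async def __aenter__(self) -> "UnitOfWorkSketch":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        if exc_type is None:
            await self._session.commit()
        else:
            await self._session.rollback()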
1 change: 1 addition & 0 deletions data_rentgen/db/repositories/dataset.py
@@ -113,6 +113,7 @@ async def _create(self, dataset: DatasetDTO) -> Dataset:
return result

async def _update(self, existing: Dataset, new: DatasetDTO) -> Dataset:
# almost all fields are immutable, so we can avoid UPDATE statements if the row is unchanged
if new.format:
existing.format = new.format
await self._session.flush([existing])
1 change: 1 addition & 0 deletions data_rentgen/db/repositories/dataset_symlink.py
@@ -54,6 +54,7 @@ async def _create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink:
return result

async def _update(self, existing: DatasetSymlink, new: DatasetSymlinkDTO) -> DatasetSymlink:
# almost all fields are immutable, so we can avoid UPDATE statements if the row is unchanged
existing.type = DatasetSymlinkType(new.type)
await self._session.flush([existing])
return existing
100 changes: 37 additions & 63 deletions data_rentgen/db/repositories/input.py
@@ -6,6 +6,7 @@
from uuid import UUID

from sqlalchemy import Select, any_, func, literal_column, select
from sqlalchemy.dialects.postgresql import insert

from data_rentgen.db.models import Input
from data_rentgen.db.repositories.base import Repository
@@ -17,7 +18,7 @@


class InputRepository(Repository[Input]):
async def create_or_update(self, input: InputDTO) -> Input:
def get_id(self, input: InputDTO) -> UUID:
# `created_at` field of input should be the same as operation's,
# to avoid scanning all partitions and speed up queries
created_at = extract_timestamp_from_uuid(input.operation.id)
@@ -29,27 +30,41 @@ async def create_or_update(self, input: InputDTO) -> Input:
str(input.dataset.id),
str(input.schema.id) if input.schema else "",
]
input_id = generate_incremental_uuid(created_at, ".".join(id_components).encode("utf-8"))

result = await self._get(created_at, input_id)
if not result:
# try one more time, but with lock acquired.
# if another worker already created the same row, just use it. if not - create with holding the lock.
await self._lock(input_id)
result = await self._get(created_at, input_id)

if not result:
return await self._create(
created_at=created_at,
input_id=input_id,
input=input,
operation_id=input.operation.id,
run_id=input.operation.run.id,
job_id=input.operation.run.job.id, # type: ignore[arg-type]
dataset_id=input.dataset.id, # type: ignore[arg-type]
schema_id=input.schema.id if input.schema else None,
)
return await self._update(result, input)
return generate_incremental_uuid(created_at, ".".join(id_components).encode("utf-8"))

async def create_or_update_bulk(self, inputs: list[InputDTO]) -> list[Input]:
if not inputs:
return []

insert_statement = insert(Input)
statement = insert_statement.on_conflict_do_update(
index_elements=[Input.created_at, Input.id],
set_={
"num_bytes": func.coalesce(insert_statement.excluded.num_bytes, Input.num_bytes),
"num_rows": func.coalesce(insert_statement.excluded.num_rows, Input.num_rows),
"num_files": func.coalesce(insert_statement.excluded.num_files, Input.num_files),
},
).returning(Input)

result = await self._session.execute(
statement,
[
{
"id": self.get_id(input),
"created_at": extract_timestamp_from_uuid(input.operation.id),
"operation_id": input.operation.id,
"run_id": input.operation.run.id,
"job_id": input.operation.run.job.id, # type: ignore[arg-type]
"dataset_id": input.dataset.id, # type: ignore[arg-type]
"schema_id": input.schema.id if input.schema else None,
"num_bytes": input.num_bytes,
"num_rows": input.num_rows,
"num_files": input.num_files,
}
for input in inputs
],
)
return list(result.scalars().all())

async def list_by_operation_ids(
self,
@@ -185,44 +200,3 @@ def _get_select(
Input.job_id,
Input.dataset_id,
)

async def _get(self, created_at: datetime, input_id: UUID) -> Input | None:
query = select(Input).where(Input.created_at == created_at, Input.id == input_id)
return await self._session.scalar(query)

async def _create(
self,
created_at: datetime,
input_id: UUID,
input: InputDTO,
operation_id: UUID,
run_id: UUID,
job_id: int,
dataset_id: int,
schema_id: int | None = None,
) -> Input:
result = Input(
created_at=created_at,
id=input_id,
operation_id=operation_id,
run_id=run_id,
job_id=job_id,
dataset_id=dataset_id,
schema_id=schema_id,
num_bytes=input.num_bytes,
num_rows=input.num_rows,
num_files=input.num_files,
)
self._session.add(result)
await self._session.flush([result])
return result

async def _update(self, existing: Input, new: InputDTO) -> Input:
if new.num_bytes is not None:
existing.num_bytes = new.num_bytes
if new.num_rows is not None:
existing.num_rows = new.num_rows
if new.num_files is not None:
existing.num_files = new.num_files
await self._session.flush([existing])
return existing
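The new create_or_update_bulk above replaces the per-row get/lock/create/update sequence with a single PostgreSQL INSERT ... ON CONFLICT DO UPDATE, where coalesce(excluded.column, current.column) keeps the existing value whenever the incoming row carries NULL. A self-contained sketch of the same upsert pattern against a hypothetical table (the model and data below are assumptions for illustration only):

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Metric(Base):
    # hypothetical table, used only to demonstrate the upsert pattern
    __tablename__ = "metric"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str | None]
    value: Mapped[int | None]


insert_statement = insert(Metric)
upsert = insert_statement.on_conflict_do_update(
    index_elements=[Metric.id],
    set_={
        # keep the current value when the incoming row has NULL in this column
        "name": func.coalesce(insert_statement.excluded.name, Metric.name),
        "value": func.coalesce(insert_statement.excluded.value, Metric.value),
    },
).returning(Metric)

rows = [{"id": 1, "name": "a", "value": 10}, {"id": 2, "name": None, "value": 20}]
# passing a list of parameter dicts sends the whole batch as one executemany-style
# statement; RETURNING hands back the resulting ORM rows, e.g. inside an AsyncSession:
# result = await session.execute(upsert, rows)
# metrics = list(result.scalars().all())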
1 change: 1 addition & 0 deletions data_rentgen/db/repositories/job.py
@@ -116,6 +116,7 @@ async def _create(self, job: JobDTO) -> Job:
return result

async def _update(self, existing: Job, new: JobDTO) -> Job:
# almost all fields are immutable, so we can avoid UPDATE statements if the row is unchanged
if new.type:
existing.type = JobType(new.type)
await self._session.flush([existing])
1 change: 1 addition & 0 deletions data_rentgen/db/repositories/location.py
@@ -49,6 +49,7 @@ async def _create(self, location: LocationDTO) -> Location:
async def _update_addresses(self, existing: Location, new: LocationDTO) -> Location:
existing_urls = {address.url for address in existing.addresses}
new_urls = new.addresses - existing_urls
# in most cases, Location is unchanged, so we can avoid UPDATE statements
if not new_urls:
return existing

92 changes: 40 additions & 52 deletions data_rentgen/db/repositories/operation.py
@@ -4,7 +4,8 @@
from datetime import datetime, timezone
from typing import Sequence

from sqlalchemy import any_, select
from sqlalchemy import any_, func, select
from sqlalchemy.dialects.postgresql import insert

from data_rentgen.db.models import Operation, OperationType, Status
from data_rentgen.db.repositories.base import Repository
@@ -14,19 +15,45 @@


class OperationRepository(Repository[Operation]):
async def create_or_update(self, operation: OperationDTO) -> Operation:
# avoid calculating created_at twice
created_at = extract_timestamp_from_uuid(operation.id)
result = await self._get(created_at, operation.id)
if not result:
# try one more time, but with lock acquired.
# if another worker already created the same row, just use it. if not - create with holding the lock.
await self._lock(operation.id)
result = await self._get(created_at, operation.id)
async def create_or_update_bulk(self, operations: list[OperationDTO]) -> list[Operation]:
if not operations:
return []

if not result:
return await self._create(created_at, operation)
return await self._update(result, operation)
insert_statement = insert(Operation)
statement = insert_statement.on_conflict_do_update(
index_elements=[Operation.created_at, Operation.id],
set_={
"name": func.coalesce(insert_statement.excluded.name, Operation.name),
"type": func.coalesce(insert_statement.excluded.type, Operation.type),
"status": func.coalesce(insert_statement.excluded.status, Operation.status),
"started_at": func.coalesce(insert_statement.excluded.started_at, Operation.started_at),
"ended_at": func.coalesce(insert_statement.excluded.ended_at, Operation.ended_at),
"description": func.coalesce(insert_statement.excluded.description, Operation.description),
"group": func.coalesce(insert_statement.excluded.group, Operation.group),
"position": func.coalesce(insert_statement.excluded.position, Operation.position),
},
).returning(Operation)

result = await self._session.execute(
statement,
[
{
"id": operation.id,
"created_at": extract_timestamp_from_uuid(operation.id),
"run_id": operation.run.id,
"name": operation.name,
"type": OperationType(operation.type) if operation.type else None,
"status": Status(operation.status) if operation.status else None,
"started_at": operation.started_at,
"ended_at": operation.ended_at,
"description": operation.description,
"group": operation.group,
"position": operation.position,
}
for operation in operations
],
)
return list(result.scalars().all())

async def paginate(
self,
@@ -104,42 +131,3 @@ async def list_by_ids(self, operation_ids: Sequence[UUID]) -> list[Operation]:
)
result = await self._session.scalars(query)
return list(result.all())

async def _get(self, created_at: datetime, operation_id: UUID) -> Operation | None:
query = select(Operation).where(Operation.created_at == created_at, Operation.id == operation_id)
return await self._session.scalar(query)

async def _create(self, created_at: datetime, operation: OperationDTO) -> Operation:
result = Operation(
created_at=created_at,
id=operation.id,
run_id=operation.run.id,
name=operation.name,
type=OperationType(operation.type),
status=Status(operation.status) if operation.status else Status.UNKNOWN,
started_at=operation.started_at,
ended_at=operation.ended_at,
description=operation.description,
group=operation.group,
position=operation.position,
)
self._session.add(result)
await self._session.flush([result])
return result

async def _update(self, existing: Operation, new: OperationDTO) -> Operation:
optional_fields = {
"type": OperationType(new.type) if new.type else None,
"status": Status(new.status) if new.status else None,
"started_at": new.started_at,
"ended_at": new.ended_at,
"description": new.description,
"group": new.group,
"position": new.position,
}
for column, value in optional_fields.items():
if value is not None:
setattr(existing, column, value)

await self._session.flush([existing])
return existing
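Both bulk upserts conflict on the composite (created_at, id) key, with created_at always derived from the entity's time-ordered UUID; the comment in input.py about avoiding a scan of all partitions suggests these tables are partitioned by created_at, so a conflicting insert only has to touch one partition. A sketch of what such a table declaration could look like (an assumption about the schema, not taken from the repository's migrations):

from datetime import datetime
from uuid import UUID

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class OperationSketch(Base):
    # assumption: a range-partitioned table with a composite primary key,
    # so ON CONFLICT (created_at, id) resolves within a single partition
    __tablename__ = "operation_sketch"
    __table_args__ = {"postgresql_partition_by": "RANGE (created_at)"}

    created_at: Mapped[datetime] = mapped_column(primary_key=True)
    id: Mapped[UUID] = mapped_column(primary_key=True)
    name: Mapped[str | None]


# individual partitions are created with plain DDL, for example:
# CREATE TABLE operation_sketch_2024_11 PARTITION OF operation_sketch
#     FOR VALUES FROM ('2024-11-01') TO ('2024-12-01');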