Skip to content

Commit

Permalink
Improve deployment concurrency GCL management (#15426)
Browse files Browse the repository at this point in the history
  • Loading branch information
collincchoy authored Sep 19, 2024
1 parent bb96af8 commit 4ee7d12
Show file tree
Hide file tree
Showing 19 changed files with 489 additions and 58 deletions.
19 changes: 16 additions & 3 deletions docs/3.0/api-ref/rest-api/server/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -15449,7 +15449,8 @@
"type": "null"
}
],
"description": "Filter criteria for `Deployment.concurrency`"
"description": "DEPRECATED: Prefer `Deployment.concurrency_limit_id` over `Deployment.concurrency_limit`. If provided, will be ignored for backwards-compatibility. Will be removed after December 2024.",
"deprecated": true
}
},
"additionalProperties": false,
Expand Down Expand Up @@ -15499,7 +15500,7 @@
"additionalProperties": false,
"type": "object",
"title": "DeploymentFilterConcurrencyLimit",
"description": "Filter by `Deployment.concurrency_limit`."
"description": "DEPRECATED: Prefer `Deployment.concurrency_limit_id` over `Deployment.concurrency_limit`."
},
"DeploymentFilterId": {
"properties": {
Expand Down Expand Up @@ -15933,7 +15934,19 @@
}
],
"title": "Concurrency Limit",
"description": "The maximum number of flow runs that can be active at once."
"description": "DEPRECATED: Prefer `global_concurrency_limit`. Will always be None for backwards compatibility. Will be removed after December 2024.",
"deprecated": true
},
"global_concurrency_limit": {
"anyOf": [
{
"$ref": "#/components/schemas/GlobalConcurrencyLimitResponse"
},
{
"type": "null"
}
],
"description": "The global concurrency limit object for enforcing the maximum number of flow runs that can be active at once."
},
"concurrency_options": {
"anyOf": [
Expand Down
6 changes: 4 additions & 2 deletions src/prefect/client/schemas/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class DeploymentFilterTags(PrefectBaseModel, OperatorMixin):


class DeploymentFilterConcurrencyLimit(PrefectBaseModel):
"""Filter by `Deployment.concurrency_limit`."""
"""DEPRECATED: Prefer `Deployment.concurrency_limit_id` over `Deployment.concurrency_limit`."""

ge_: Optional[int] = Field(
default=None,
Expand Down Expand Up @@ -538,7 +538,9 @@ class DeploymentFilter(PrefectBaseModel, OperatorMixin):
default=None, description="Filter criteria for `Deployment.work_queue_name`"
)
concurrency_limit: Optional[DeploymentFilterConcurrencyLimit] = Field(
default=None, description="Filter criteria for `Deployment.concurrency_limit`"
default=None,
description="DEPRECATED: Prefer `Deployment.concurrency_limit_id` over `Deployment.concurrency_limit`. If provided, will be ignored for backwards-compatibility. Will be removed after December 2024.",
deprecated=True,
)


Expand Down
8 changes: 7 additions & 1 deletion src/prefect/client/schemas/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,13 @@ class DeploymentResponse(ObjectBaseModel):
default=..., description="The flow id associated with the deployment."
)
concurrency_limit: Optional[int] = Field(
default=None, description="The concurrency limit for the deployment."
default=None,
description="DEPRECATED: Prefer `global_concurrency_limit`. Will always be None for backwards compatibility. Will be removed after December 2024.",
deprecated=True,
)
global_concurrency_limit: Optional["GlobalConcurrencyLimitResponse"] = Field(
default=None,
description="The global concurrency limit object for enforcing the maximum number of flow runs that can be active at once.",
)
concurrency_options: Optional[objects.ConcurrencyOptions] = Field(
default=None,
Expand Down
10 changes: 2 additions & 8 deletions src/prefect/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,15 +1049,9 @@ async def _submit_run_and_capture_errors(

if flow_run.deployment_id:
deployment = await self._client.read_deployment(flow_run.deployment_id)
if deployment and deployment.concurrency_limit:
limit_name = f"deployment:{deployment.id}"
if deployment and deployment.global_concurrency_limit:
limit_name = deployment.global_concurrency_limit.name
concurrency_ctx = concurrency

# ensure that the global concurrency limit is available
# and up-to-date before attempting to acquire a slot
await self._client.upsert_global_concurrency_limit_by_name(
limit_name, deployment.concurrency_limit
)
else:
limit_name = ""
concurrency_ctx = asyncnullcontext
Expand Down
5 changes: 4 additions & 1 deletion src/prefect/server/database/migrations/MIGRATION-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ Each time a database migration is written, an entry is included here with:

This gives us a history of changes and will create merge conflicts if two migrations are made at once, flagging situations where a branch needs to be updated before merging.

# Adds `concurrency_options` to `Deployments`
# Migrate `Deployment.concurrency_limit` to a foreign key `Deployment.concurrency_limit_id`
SQLite: `4ad4658cbefe`
Postgres: `eaec5004771f`

# Adds `concurrency_options` to `Deployments`
SQLite: `7d6350aea855`
Postgres: `555ed31b284d`

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Add deployment to global concurrency limit FK
Revision ID: eaec5004771f
Revises: 555ed31b284d
Create Date: 2024-09-16 15:20:51.582204
"""
import sqlalchemy as sa
from alembic import op

import prefect

# revision identifiers, used by Alembic.
revision = "eaec5004771f"
down_revision = "555ed31b284d"
branch_labels = None
depends_on = None


def upgrade():
    """Link deployments to their global concurrency limits.

    Adds a nullable ``deployment.concurrency_limit_id`` column, creates a
    foreign key from it to ``concurrency_limit_v2.id`` (``ON DELETE SET NULL``),
    and then backfills the column for every deployment that has a limit row
    named ``deployment:<deployment id>``.
    """
    limit_id_column = sa.Column(
        "concurrency_limit_id",
        prefect.server.utilities.database.UUID(),
        nullable=True,
    )
    op.add_column("deployment", limit_id_column)

    fk_name = op.f("fk_deployment__concurrency_limit_id__concurrency_limit_v2")
    op.create_foreign_key(
        fk_name,
        "deployment",
        "concurrency_limit_v2",
        ["concurrency_limit_id"],
        ["id"],
        ondelete="SET NULL",
    )

    # Backfill: match each deployment to the limit whose name embeds its id.
    backfill = sa.text(
        """
WITH deployment_limit_mapping AS (
SELECT d.id AS deployment_id, l.id AS limit_id
FROM deployment d
JOIN concurrency_limit_v2 l ON l.name = 'deployment:' || d.id::text
)
UPDATE deployment
SET concurrency_limit_id = dlm.limit_id
FROM deployment_limit_mapping dlm
WHERE deployment.id = dlm.deployment_id;
"""
    )
    op.execute(backfill)


def downgrade():
    """Revert the upgrade: drop the foreign key first, then its column."""
    fk_name = op.f("fk_deployment__concurrency_limit_id__concurrency_limit_v2")
    op.drop_constraint(fk_name, "deployment", type_="foreignkey")
    op.drop_column("deployment", "concurrency_limit_id")
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Add deployment to global concurrency limit FK
Revision ID: 4ad4658cbefe
Revises: 7d6350aea855
Create Date: 2024-09-16 16:27:19.451150
"""
import sqlalchemy as sa
from alembic import op

import prefect

# revision identifiers, used by Alembic.
revision = "4ad4658cbefe"
down_revision = "7d6350aea855"
branch_labels = None
depends_on = None


def upgrade():
    """Link deployments to their global concurrency limits (SQLite).

    Inside a batch alter of ``deployment`` this adds a nullable
    ``concurrency_limit_id`` column and a foreign key from it to
    ``concurrency_limit_v2.id`` (``ON DELETE SET NULL``). It then backfills the
    column for every deployment that has a limit row named
    ``deployment:<deployment id>``.
    """
    with op.batch_alter_table("deployment", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                "concurrency_limit_id",
                prefect.server.utilities.database.UUID(),
                nullable=True,
            )
        )
        batch_op.create_foreign_key(
            batch_op.f("fk_deployment__concurrency_limit_id__concurrency_limit_v2"),
            "concurrency_limit_v2",
            ["concurrency_limit_id"],
            ["id"],
            ondelete="SET NULL",
        )

    # migrate existing data
    #
    # Written as a correlated subquery rather than `UPDATE ... FROM`, because
    # SQLite only accepts `UPDATE ... FROM` from version 3.33.0 onwards. The
    # EXISTS guard restricts the update to deployments that actually have a
    # matching limit, so the same rows are touched as with the join form.
    # NOTE(review): confirm the minimum SQLite version the project supports.
    sql = sa.text(
        """
        UPDATE deployment
        SET concurrency_limit_id = (
            SELECT l.id
            FROM concurrency_limit_v2 l
            WHERE l.name = 'deployment:' || deployment.id
        )
        WHERE EXISTS (
            SELECT 1
            FROM concurrency_limit_v2 l
            WHERE l.name = 'deployment:' || deployment.id
        );
        """
    )
    op.execute(sql)


def downgrade():
    """Revert the upgrade: drop the foreign key and its column in one batch."""
    with op.batch_alter_table("deployment", schema=None) as batch_op:
        fk_name = batch_op.f(
            "fk_deployment__concurrency_limit_id__concurrency_limit_v2"
        )
        batch_op.drop_constraint(fk_name, type_="foreignkey")
        batch_op.drop_column("concurrency_limit_id")
19 changes: 14 additions & 5 deletions src/prefect/server/database/orm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,11 +877,19 @@ def job_variables(self):
order_by=sa.desc(sa.text("updated")),
)

concurrency_limit: Mapped[Union[int, None]] = mapped_column(
sa.Integer,
server_default=None,
# deprecated in favor of `concurrency_limit_id` FK
_concurrency_limit: Mapped[Union[int, None]] = mapped_column(
sa.Integer, default=None, nullable=True, name="concurrency_limit"
)
concurrency_limit_id: Mapped[Union[uuid.UUID, None]] = mapped_column(
UUID,
sa.ForeignKey("concurrency_limit_v2.id", ondelete="SET NULL"),
nullable=True,
default=None,
)
global_concurrency_limit: Mapped[
Union["ConcurrencyLimitV2", None]
] = sa.orm.relationship(
lazy="selectin",
)
concurrency_options: Mapped[
Union[schemas.core.ConcurrencyOptions, None]
Expand All @@ -891,6 +899,7 @@ def job_variables(self):
nullable=True,
default=None,
)

tags: Mapped[List[str]] = mapped_column(
JSON, server_default="[]", default=list, nullable=False
)
Expand Down Expand Up @@ -984,7 +993,7 @@ class ConcurrencyLimitV2(Base):
active = sa.Column(sa.Boolean, nullable=False, default=True)
name = sa.Column(sa.String, nullable=False)
limit = sa.Column(sa.Integer, nullable=False)
active_slots = sa.Column(sa.Integer, nullable=False)
active_slots = sa.Column(sa.Integer, nullable=False, default=0)
denied_slots = sa.Column(sa.Integer, nullable=False, default=0)

slot_decay_per_second = sa.Column(sa.Float, default=0.0, nullable=False)
Expand Down
55 changes: 53 additions & 2 deletions src/prefect/server/models/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def create_deployment(

schedules = deployment.schedules
insert_values = deployment.model_dump_for_orm(
exclude_unset=True, exclude={"schedules"}
exclude_unset=True, exclude={"schedules", "concurrency_limit"}
)

# The job_variables field in client and server schemas is named
Expand Down Expand Up @@ -155,6 +155,10 @@ async def create_deployment(
],
)

await _create_or_update_deployment_concurrency_limit(
session, deployment_id, deployment.concurrency_limit
)

query = (
sa.select(orm_models.Deployment)
.where(
Expand Down Expand Up @@ -194,7 +198,7 @@ async def update_deployment(
# the user, ignoring any defaults on the model
update_data = deployment.model_dump_for_orm(
exclude_unset=True,
exclude={"work_pool_name"},
exclude={"work_pool_name", "concurrency_limit"},
)

# The job_variables field in client and server schemas is named
Expand Down Expand Up @@ -263,9 +267,41 @@ async def update_deployment(
],
)

await _create_or_update_deployment_concurrency_limit(
session, deployment_id, deployment.concurrency_limit
)

return result.rowcount > 0


async def _create_or_update_deployment_concurrency_limit(
    session: AsyncSession, deployment_id: UUID, limit: Optional[int]
):
    """Bring a deployment's global concurrency limit in line with ``limit``.

    Does nothing when the deployment already has a limit with the requested
    value, or when it has no limit and ``limit`` is ``None``. Otherwise the
    value is mirrored into the deprecated ``_concurrency_limit`` column and:

    * ``limit is None``: the related limit row is deleted and the deployment
      is refreshed from the database;
    * a limit object already exists: its ``limit`` value is updated in place;
    * no limit object exists: a new ``ConcurrencyLimitV2`` named
      ``deployment:<deployment_id>`` is attached to the deployment.

    Args:
        session: database session used for the lookup and all writes.
        deployment_id: id of the deployment to modify.
        limit: desired concurrency limit, or ``None`` to remove it.
    """
    deployment = await session.get(orm_models.Deployment, deployment_id)
    # NOTE(review): `assert` is stripped under `python -O`; callers are
    # expected to pass the id of an existing deployment.
    assert deployment is not None

    existing = deployment.global_concurrency_limit
    if existing:
        unchanged = existing.limit == limit
    else:
        unchanged = existing is None and limit is None
    if unchanged:
        return

    deployment._concurrency_limit = limit
    if limit is not None and existing:
        existing.limit = limit
    elif limit is not None:
        deployment.global_concurrency_limit = orm_models.ConcurrencyLimitV2(
            name=f"deployment:{deployment_id}", limit=limit
        )
    else:
        # Removing the limit: delete the row, then reload the deployment.
        await _delete_related_concurrency_limit(
            session=session, deployment_id=deployment_id
        )
        await session.refresh(deployment)

    session.add(deployment)


async def read_deployment(
session: AsyncSession, deployment_id: UUID
) -> Optional[orm_models.Deployment]:
Expand Down Expand Up @@ -482,12 +518,27 @@ async def delete_deployment(session: AsyncSession, deployment_id: UUID) -> bool:
session=session, deployment_id=deployment_id, auto_scheduled_only=False
)

await _delete_related_concurrency_limit(
session=session, deployment_id=deployment_id
)

result = await session.execute(
delete(orm_models.Deployment).where(orm_models.Deployment.id == deployment_id)
)
return result.rowcount > 0


async def _delete_related_concurrency_limit(session: AsyncSession, deployment_id: UUID):
    """Delete the ``ConcurrencyLimitV2`` row referenced by a deployment.

    The row is identified by a scalar subquery that selects the deployment's
    ``concurrency_limit_id``. Returns the result of executing the DELETE.
    """
    linked_limit_id = (
        sa.select(orm_models.Deployment.concurrency_limit_id)
        .where(orm_models.Deployment.id == deployment_id)
        .scalar_subquery()
    )
    stmt = delete(orm_models.ConcurrencyLimitV2).where(
        orm_models.ConcurrencyLimitV2.id == linked_limit_id
    )
    return await session.execute(stmt)


async def schedule_runs(
session: AsyncSession,
deployment_id: UUID,
Expand Down
Loading

0 comments on commit 4ee7d12

Please sign in to comment.