feat(backend): Add Project data model and creation API endpoint
Integration tests are missing pending a suitable mocking strategy.
AdrianoKF committed Oct 31, 2024
1 parent 7dcdb6a commit 49725e2
Showing 9 changed files with 346 additions and 19 deletions.
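
The commit message notes that integration tests are still missing pending a suitable mocking strategy. Purely as a hedged sketch (not part of this commit), one possible approach is to override the FastAPI dependencies with fakes and point the session at an in-memory SQLite database; the app import path and the fake service behaviour below are assumptions, not confirmed project conventions.

# Hedged sketch only: one candidate mocking strategy for the new /projects endpoint.
from unittest.mock import MagicMock

from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine

from jobq_server.__main__ import app  # assumed location of the FastAPI app
from jobq_server.dependencies import get_session, k8s_service, kueue_service


def test_create_project_sketch():
    # In-memory DB shared across threads, following the SQLModel testing pattern
    engine = create_engine(
        "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
    )
    SQLModel.metadata.create_all(engine)

    # Fake out the Kubernetes and Kueue services so no cluster is needed
    fake_k8s = MagicMock()
    fake_k8s.ensure_namespace.return_value = (MagicMock(), True)
    fake_kueue = MagicMock()
    fake_kueue.get_cluster_queue.return_value = None
    fake_kueue.get_local_queue.return_value = None

    def override_session():
        with Session(engine) as session:
            yield session

    app.dependency_overrides[get_session] = override_session
    app.dependency_overrides[k8s_service] = lambda: fake_k8s
    app.dependency_overrides[kueue_service] = lambda: fake_kueue
    try:
        client = TestClient(app)
        resp = client.post(
            "/projects/",
            json={"name": "demo", "namespace": "demo",
                  "cluster_queue": "cq", "local_queue": "lq"},
        )
        assert resp.status_code == 201
        assert resp.json()["name"] == "demo"
    finally:
        app.dependency_overrides.clear()
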
3 changes: 2 additions & 1 deletion backend/src/jobq_server/__main__.py
@@ -7,7 +7,7 @@

from jobq_server.config import settings
from jobq_server.db import check_migrations, get_engine, upgrade_migrations
from jobq_server.routers import jobs
from jobq_server.routers import jobs, projects


@asynccontextmanager
@@ -36,6 +36,7 @@ async def lifespan(app: FastAPI):
)

app.include_router(jobs.router, prefix="/jobs")
app.include_router(projects.router, prefix="/projects")


@app.get("/health", include_in_schema=False)
@@ -0,0 +1,41 @@
"""Initial schema
Revision ID: 2837c7c54f35
Revises:
Create Date: 2024-10-31 11:27:32.242586
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '2837c7c54f35'
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('project',
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('cluster_queue', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('local_queue', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('namespace', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('id', sa.Uuid(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_project_description'), 'project', ['description'], unique=False)
op.create_index(op.f('ix_project_name'), 'project', ['name'], unique=True)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_project_name'), table_name='project')
op.drop_index(op.f('ix_project_description'), table_name='project')
op.drop_table('project')
# ### end Alembic commands ###
25 changes: 23 additions & 2 deletions backend/src/jobq_server/db.py
@@ -1,12 +1,12 @@
from threading import Lock
from uuid import UUID, uuid4

from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import Engine
from sqlmodel import SQLModel as SQLModel
from sqlmodel import create_engine
from sqlmodel import Field, SQLModel, create_engine

from jobq_server.config import settings

@@ -50,3 +50,24 @@ def upgrade_migrations():

alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")


# --- PROJECT
class ProjectBase(SQLModel):
name: str = Field(index=True, unique=True)
description: str | None = Field(None, index=True)
cluster_queue: str | None = Field(None)
local_queue: str | None = Field(None)
namespace: str | None = Field(None)


class Project(ProjectBase, table=True):
id: UUID = Field(default_factory=uuid4, primary_key=True)


class ProjectCreate(ProjectBase):
pass


class ProjectPublic(ProjectBase):
id: UUID
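
As an aside on the model split above: this follows the usual SQLModel layering, where ProjectCreate validates request payloads, Project is the persisted table row, and ProjectPublic shapes API responses. A minimal sketch of how the pieces fit together (save_project is a hypothetical helper, mirroring what the new router does):

from sqlmodel import Session


def save_project(session: Session, data: ProjectCreate) -> ProjectPublic:
    row = Project.model_validate(data)  # copy the shared ProjectBase fields into a table row
    session.add(row)
    session.commit()
    session.refresh(row)  # load the generated UUID primary key
    return ProjectPublic.model_validate(row)
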
10 changes: 9 additions & 1 deletion backend/src/jobq_server/dependencies/__init__.py
@@ -7,13 +7,21 @@
from jobq_server.db import get_engine
from jobq_server.models import JobId
from jobq_server.services.k8s import KubernetesService
from jobq_server.services.kueue import KueueService
from jobq_server.utils.kueue import KueueWorkload


def k8s_service() -> KubernetesService:
return KubernetesService()


KubernetesDep = Annotated[KubernetesService, Depends(k8s_service)]


def kueue_service(k8s: KubernetesDep) -> KueueService:
return KueueService(k8s)


def managed_workload(
k8s: Annotated[KubernetesService, Depends(k8s_service)],
uid: JobId,
@@ -31,5 +39,5 @@ def get_session() -> Generator[Session, None, None]:


ManagedWorkload = Annotated[KueueWorkload, Depends(managed_workload)]
Kubernetes = Annotated[KubernetesService, Depends(k8s_service)]
KueueDep = Annotated[KueueService, Depends(kueue_service)]
DBSessionDep = Annotated[Session, Depends(get_session)]
10 changes: 5 additions & 5 deletions backend/src/jobq_server/routers/jobs.py
@@ -8,7 +8,7 @@
from fastapi.responses import StreamingResponse
from jobq import Job

from jobq_server.dependencies import Kubernetes, ManagedWorkload
from jobq_server.dependencies import KubernetesDep, ManagedWorkload
from jobq_server.exceptions import PodNotReadyError
from jobq_server.models import (
CreateJobModel,
@@ -28,7 +28,7 @@
@router.post("")
async def submit_job(
opts: CreateJobModel,
k8s: Kubernetes,
k8s: KubernetesDep,
) -> WorkloadIdentifier:
# FIXME: Having to define a function just to set the job name is ugly
def job_fn(): ...
@@ -72,7 +72,7 @@ async def status(
@router.get("/{uid}/logs")
async def logs(
workload: ManagedWorkload,
k8s: Kubernetes,
k8s: KubernetesDep,
params: Annotated[LogOptions, Depends(make_dependable(LogOptions))],
):
try:
@@ -164,7 +164,7 @@ async def stream_response(
async def stop_workload(
uid: JobId,
workload: ManagedWorkload,
k8s: Kubernetes,
k8s: KubernetesDep,
):
try:
workload.stop(k8s)
@@ -182,7 +182,7 @@ async def stop_workload(

@router.get("", response_model_exclude_unset=True)
async def list_jobs(
k8s: Kubernetes,
k8s: KubernetesDep,
include_metadata: Annotated[bool, Query()] = False,
) -> list[ListWorkloadModel]:
workloads = k8s.list_workloads()
87 changes: 87 additions & 0 deletions backend/src/jobq_server/routers/projects.py
@@ -0,0 +1,87 @@
import logging

from fastapi import APIRouter
from kubernetes import client
from sqlmodel import select

from jobq_server.db import Project, ProjectCreate, ProjectPublic
from jobq_server.dependencies import DBSessionDep, KubernetesDep, KueueDep
from jobq_server.utils.kueue import ClusterQueue, ClusterQueueSpec, LocalQueue

router = APIRouter()


@router.get("/")
async def list_projects(db: DBSessionDep):
return db.exec(select(Project)).all()


@router.post("/", status_code=201)
async def create_project(
project: ProjectCreate,
db: DBSessionDep,
k8s: KubernetesDep,
kueue: KueueDep,
) -> ProjectPublic:
# Create namespace if it doesn't exist
ns, created = k8s.ensure_namespace(project.namespace)
if created:
logging.info(f"Created Kubernetes namespace {ns.metadata.name}")

# Create cluster queue if it doesn't exist
cluster_queue = kueue.get_cluster_queue(project.cluster_queue)
if cluster_queue is None:
default_spec = {
"namespaceSelector": {},
"preemption": {
"reclaimWithinCohort": "Any",
"borrowWithinCohort": {
"policy": "LowerPriority",
"maxPriorityThreshold": 100,
},
"withinClusterQueue": "LowerPriority",
},
"resourceGroups": [
{
"coveredResources": ["cpu", "memory"],
"flavors": [
{
"name": "default-flavor",
"resources": [
{"name": "cpu", "nominalQuota": 4},
{"name": "memory", "nominalQuota": 6},
],
}
],
}
],
}
cluster_queue = ClusterQueue(
metadata=client.V1ObjectMeta(name=project.cluster_queue),
spec=ClusterQueueSpec.model_validate(default_spec),
)
kueue.create_cluster_queue(cluster_queue)
logging.info(f"Created cluster queue {project.cluster_queue!r}")

# Create local queue if it doesn't exist
local_queue = kueue.get_local_queue(project.local_queue, project.namespace)
if local_queue is None:
local_queue = LocalQueue(
metadata=client.V1ObjectMeta(
name=project.local_queue, namespace=project.namespace
),
spec={"clusterQueue": project.cluster_queue},
)

kueue.create_local_queue(local_queue)
logging.info(
f"Created user queue {project.local_queue!r} in namespace {project.namespace!r}"
)

# TODO: Apply finalizers to Kubernetes resources to prevent deletion while the project exists

db_obj = Project.model_validate(project)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
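
For illustration, a call against the new endpoint might look like the following (host, port, and field values are made up; the response shape follows ProjectPublic above):

import httpx

payload = {
    "name": "nlp-experiments",
    "description": "Scratch space for NLP jobs",
    "namespace": "nlp-experiments",
    "cluster_queue": "cluster-queue",
    "local_queue": "user-queue",
}
resp = httpx.post("http://localhost:8000/projects/", json=payload)
resp.raise_for_status()
print(resp.json())  # the submitted fields plus a server-generated "id"
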
19 changes: 18 additions & 1 deletion backend/src/jobq_server/services/k8s.py
@@ -23,7 +23,6 @@ def __init__(self):
)
config.load_kube_config()
self._in_cluster = False

self._core_v1_api = client.CoreV1Api()

@property
@@ -129,3 +128,21 @@ def list_workloads(self, namespace: str | None = None) -> list[KueueWorkload]:
KueueWorkload.model_validate(workload)
for workload in workloads.get("items", [])
]

def ensure_namespace(self, name: str) -> tuple[client.V1Namespace, bool]:
"""Create or look up a namespace by name
Returns
-------
tuple[client.V1Namespace, bool]
The namespace object and a boolean indicating whether it was created
"""

try:
return self._core_v1_api.read_namespace(name), False
except client.ApiException as e:
if e.status == 404:
return self._core_v1_api.create_namespace(
client.V1Namespace(metadata=client.V1ObjectMeta(name=name))
), True
raise
82 changes: 82 additions & 0 deletions backend/src/jobq_server/services/kueue.py
@@ -0,0 +1,82 @@
from kubernetes import client

from jobq_server.services.k8s import KubernetesService
from jobq_server.utils.kueue import ClusterQueue, LocalQueue


class KueueService:
def __init__(self, k8s: KubernetesService):
self.k8s = k8s
self.custom_obj_api = client.CustomObjectsApi()

def get_cluster_queue(self, name: str) -> ClusterQueue | None:
"""Get a cluster queue by name.
Returns
-------
ClusterQueue | None
The cluster queue if it exists, otherwise None.
"""
try:
k8s_obj = self.custom_obj_api.get_cluster_custom_object(
"kueue.x-k8s.io",
"v1beta1",
"clusterqueues",
name,
)
return ClusterQueue.model_validate(k8s_obj)
except client.ApiException as e:
if e.status == 404:
return None
raise

def get_local_queue(self, name: str, namespace: str) -> LocalQueue | None:
"""Get a local queue by name and namespace.
Returns
-------
LocalQueue | None
The local queue if it exists, otherwise None.
"""
try:
k8s_obj = self.custom_obj_api.get_namespaced_custom_object(
"kueue.x-k8s.io",
"v1beta1",
namespace,
"localqueues",
name,
)
return LocalQueue.model_validate(k8s_obj)
except client.ApiException as e:
if e.status == 404:
return None
raise

def create_local_queue(self, queue: LocalQueue) -> None:
_ = self.k8s.ensure_namespace(queue.metadata.namespace)

data = {
"apiVersion": "kueue.x-k8s.io/v1beta1",
"kind": "LocalQueue",
**queue.model_dump(),
}
return self.custom_obj_api.create_namespaced_custom_object(
"kueue.x-k8s.io",
"v1beta1",
queue.metadata.namespace,
"localqueues",
body=data,
)

def create_cluster_queue(self, queue: ClusterQueue) -> None:
data = {
"apiVersion": "kueue.x-k8s.io/v1beta1",
"kind": "ClusterQueue",
**queue.model_dump(),
}
return self.custom_obj_api.create_cluster_custom_object(
"kueue.x-k8s.io",
"v1beta1",
"clusterqueues",
body=data,
)
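
For reference, the body submitted by create_local_queue ends up looking roughly like the dict below; the exact metadata/spec fields depend on the LocalQueue model, so treat this as an assumed example rather than the definitive payload:

# Approximate request body for a LocalQueue custom object (values are examples)
{
    "apiVersion": "kueue.x-k8s.io/v1beta1",
    "kind": "LocalQueue",
    "metadata": {"name": "user-queue", "namespace": "nlp-experiments"},
    "spec": {"clusterQueue": "cluster-queue"},
}
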