This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

feat: add start_service API #479

Merged
merged 21 commits into main from feature/direct-wsproxy-connection
Dec 17, 2021
Changes from 6 commits
Commits
21 commits
cc38070
Add new start_service API
kyujin-cho Sep 16, 2021
4080ded
Fix flake8 error
kyujin-cho Sep 16, 2021
2995ff4
Add towncrier changelog
kyujin-cho Sep 16, 2021
3a7d42b
Call wsproxy coordinator directly rather than exposing kernel info to…
kyujin-cho Sep 19, 2021
eb9d29d
Fix linting error
kyujin-cho Sep 19, 2021
5636605
Rename confusing keywords
kyujin-cho Sep 19, 2021
7cbb87f
Accept PR Reviews
kyujin-cho Oct 12, 2021
a08a250
Merge branch 'main' into feature/direct-wsproxy-connection
kyujin-cho Oct 13, 2021
c5b7e94
Add wsproxy-version API which returns target wsproxy version string
kyujin-cho Oct 13, 2021
36bb3e6
Fix invalid Redis API usage
kyujin-cho Oct 13, 2021
880ed52
Rename wsproxy_address to wsproxy_addr
kyujin-cho Oct 14, 2021
220814b
fix: typo in alembic downgrade function for wsproxy_addr
adrysn Oct 17, 2021
0693aa4
fix: error when scaling_groups.c.wsproxy_addr is an empty string ("")
adrysn Oct 17, 2021
5ba6d96
Merge branch 'main' into feature/direct-wsproxy-connection
adrysn Oct 17, 2021
2ae9e01
fix: flake8 errors (trailing comma)
adrysn Oct 17, 2021
7c11987
docs: Update news fragment
achimnol Oct 18, 2021
9e398de
Use aiotools' cache helper
kyujin-cho Oct 18, 2021
0834444
Set cache expire time to 30 seconds
kyujin-cho Oct 18, 2021
9ae4f07
Merge branch 'main' into feature/direct-wsproxy-connection
kyujin-cho Dec 13, 2021
505e19a
Merge branch 'main' into feature/direct-wsproxy-connection
kyujin-cho Dec 16, 2021
8890a43
Use dataclass instead of DummyHashObject
kyujin-cho Dec 17, 2021
1 change: 1 addition & 0 deletions changes/479.feature
@@ -0,0 +1 @@
Add `session.start_service` API
99 changes: 99 additions & 0 deletions src/ai/backend/manager/api/session.py
@@ -28,6 +28,7 @@
    TYPE_CHECKING,
    cast,
)
from urllib.parse import urlparse
import uuid

import aiohttp
@@ -44,6 +45,8 @@
import sqlalchemy as sa
from sqlalchemy.sql.expression import true, null
import trafaret as t

from ai.backend.manager.models.scaling_group import scaling_groups
if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection

@@ -76,6 +79,7 @@
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
from ai.backend.common.types import (
    AccessKey,
    AgentId,
    KernelId,
    ClusterMode,
@@ -103,10 +107,12 @@
from ..models.kernel import match_session_ids
from ..models.utils import execute_with_retry
from .exceptions import (
    AppNotFound,
    InvalidAPIParameters,
    GenericNotFound,
    ImageNotFound,
    InsufficientPrivilege,
    ServiceUnavailable,
    SessionNotFound,
    SessionAlreadyExists,
    TooManySessionsMatched,
@@ -1048,6 +1054,98 @@ async def create_cluster(request: web.Request, params: Any) -> web.Response:
    return web.json_response(resp, status=201)


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['app', 'service']): t.String,
        # The port argument is only required to use secondary ports
        # when the target app listens on multiple TCP ports.
        # Otherwise it should be omitted or set to the same value as
        # the actual port number used by the app.
        tx.AliasedKey(['port'], default=None): t.Null | t.Int[1024:65535],
        tx.AliasedKey(['envs'], default=None): t.Null | t.String,  # stringified JSON
        # e.g., '{"PASSWORD": "12345"}'
        tx.AliasedKey(['arguments'], default=None): t.Null | t.String,  # stringified JSON
        # e.g., '{"-P": "12345"}'
        # The value can be one of:
        # None, str, List[str]
    }))
async def start_service(request: web.Request, params: Mapping[str, Any]) -> web.Response:
    root_ctx: RootContext = request.app['_root.context']
    session_name: str = request.match_info['session_name']
    access_key: AccessKey = request['keypair']['access_key']
    service: str = params['app']
    myself = asyncio.current_task()
    assert myself is not None
    try:
        kernel = await asyncio.shield(root_ctx.registry.get_session(session_name, access_key))
    except (SessionNotFound, TooManySessionsMatched):
        raise

    query = (sa.select([scaling_groups.c.wsproxy_address])
             .select_from(scaling_groups)
             .where((scaling_groups.c.name == kernel['scaling_group'])))

    async with root_ctx.db.begin() as conn:
        result = await conn.execute(query)
        sgroup = result.first()
    wsproxy_address = sgroup['wsproxy_address']
    if not wsproxy_address:
        raise ServiceUnavailable('No coordinator configured for this resource group')

    if kernel['kernel_host'] is None:
        kernel_host = urlparse(kernel['agent_addr']).hostname
    else:
        kernel_host = kernel['kernel_host']
    for sport in kernel['service_ports']:
        if sport['name'] == service:
            if params['port']:
                # using one of the primary/secondary ports of the app
                try:
                    hport_idx = sport['container_ports'].index(params['port'])
                except ValueError:
                    raise InvalidAPIParameters(
                        f"Service {service} does not open the port number {params['port']}.")
                host_port = sport['host_ports'][hport_idx]
            else:
                # using the default (primary) port of the app
                if 'host_ports' not in sport:
                    host_port = sport['host_port']  # legacy kernels
                else:
                    host_port = sport['host_ports'][0]
            break
    else:
        raise AppNotFound(f'{session_name}:{service}')

    await asyncio.shield(root_ctx.registry.increment_session_usage(session_name, access_key))

    opts: MutableMapping[str, Union[None, str, List[str]]] = {}
    if params['arguments'] is not None:
        opts['arguments'] = json.loads(params['arguments'])
    if params['envs'] is not None:
        opts['envs'] = json.loads(params['envs'])

    result = await asyncio.shield(
        root_ctx.registry.start_service(session_name, access_key, service, opts)
    )
    if result['status'] == 'failed':
        raise InternalServerError(
            "Failed to launch the app service",
            extra_data=result['error'])

    async with aiohttp.ClientSession() as session:
        async with session.post(f'{wsproxy_address}/v2/conf', json={
            'kernel_host': kernel_host,
            'kernel_port': host_port,
        }) as resp:
            token_json = await resp.json()
            return web.json_response({
                'token': token_json['token'],
                'wsproxy_address': wsproxy_address,
            })


async def handle_kernel_creation_lifecycle(
    app: web.Application,
    source: AgentId,
@@ -1992,4 +2090,5 @@ def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iter
    cors.add(app.router.add_route('GET', '/{session_name}/download', download_files))
    cors.add(app.router.add_route('GET', '/{session_name}/download_single', download_single))
    cors.add(app.router.add_route('GET', '/{session_name}/files', list_files))
    cors.add(app.router.add_route('POST', '/{session_name}/start-service', start_service))
    return app, []
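
For reviewers who want to try the new endpoint end to end, here is a minimal client-side sketch of the intended flow: POST to the new `/{session_name}/start-service` route, then hand the returned `token` to the wsproxy coordinator at the returned `wsproxy_address`. The manager base URL, the `/session` path prefix, and the authorization header below are placeholder assumptions; real clients go through the Backend.AI client SDK's request signing.

```python
# Sketch only: placeholder base URL, path prefix, and auth header; the real
# client SDK signs requests instead of sending a bare Authorization header.
import asyncio

import aiohttp

MANAGER_URL = "http://127.0.0.1:8080"   # hypothetical manager endpoint
SESSION_NAME = "mysession"


async def main() -> None:
    async with aiohttp.ClientSession() as http:
        # Ask the manager to start the in-container app and issue a wsproxy token.
        async with http.post(
            f"{MANAGER_URL}/session/{SESSION_NAME}/start-service",
            json={"app": "jupyter", "envs": '{"PASSWORD": "12345"}'},
            headers={"Authorization": "<signed request headers go here>"},
        ) as resp:
            data = await resp.json()
    # The manager replies with the coordinator address and a one-time token;
    # the client then opens its proxied connection against that coordinator.
    print(data["wsproxy_address"], data["token"])


asyncio.run(main())
```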
@@ -0,0 +1,24 @@
"""Add wsproxy_address column on scaling_group

Revision ID: 60a1effa77d2
Revises: 8679d0a7e22b
Create Date: 2021-09-17 13:19:57.525513

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '60a1effa77d2'
down_revision = '8679d0a7e22b'
branch_labels = None
depends_on = None


def upgrade():
    op.add_column('scaling_groups', sa.Column('wsproxy_address', sa.String(length=1024), nullable=True))

@adrysn adrysn Oct 11, 2021

This is just a question. Is there a reason to set wsproxy_address per scaling group? I originally thought that the wsproxy endpoint could be configured globally from manager.toml or somewhere else, so I wonder whether there is a scenario that requires a different wsproxy for each scaling group.

For more scalability and distributed deployments, I think it would be better to have this option, although in most cases it would fall back to the manager.toml configuration value.
Note: in the future this would be refactored using RBAC (along with the default storage proxy for scaling groups and the allowed storage proxies for user groups).
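
A minimal sketch of the fallback behaviour described above, assuming a hypothetical `local_config['manager']['wsproxy_addr']` entry for the global manager.toml value; the actual config key and wiring are not part of this PR:

```python
# Sketch only: the global config key is hypothetical.
from typing import Any, Mapping, Optional


def resolve_wsproxy_address(
    sgroup_row: Mapping[str, Any],
    local_config: Mapping[str, Any],
) -> Optional[str]:
    # The per-scaling-group setting wins; NULL or "" counts as "not configured".
    per_sgroup = sgroup_row.get('wsproxy_address')
    if per_sgroup:
        return per_sgroup
    # Fall back to a (hypothetical) global value from manager.toml.
    return local_config.get('manager', {}).get('wsproxy_addr')
```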



def downgrade():
    op.drop_column('scaling_groups', 'wsproxy_address')
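
To sanity-check the new revision locally, it can be applied and rolled back with Alembic's command API (equivalent to `alembic upgrade head` / `alembic downgrade -1`); the `alembic.ini` path below is an assumption about the local checkout layout:

```python
# Apply and revert the 60a1effa77d2 revision programmatically.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")           # assumed path to the manager's Alembic config
command.upgrade(cfg, "60a1effa77d2")  # adds scaling_groups.wsproxy_address
command.downgrade(cfg, "-1")          # drops the column again
```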
7 changes: 7 additions & 0 deletions src/ai/backend/manager/models/scaling_group.py
@@ -59,6 +59,7 @@
    sa.Column('is_active', sa.Boolean, index=True, default=True),
    sa.Column('created_at', sa.DateTime(timezone=True),
              server_default=sa.func.now()),
    sa.Column('wsproxy_address', sa.String(length=1024), nullable=True),
    sa.Column('driver', sa.String(length=64), nullable=False),
    sa.Column('driver_opts', pgsql.JSONB(), nullable=False, default={}),
    sa.Column('scheduler', sa.String(length=64), nullable=False),
@@ -168,6 +169,7 @@ class ScalingGroup(graphene.ObjectType):
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    wsproxy_address = graphene.String()
    driver = graphene.String()
    driver_opts = graphene.JSONString()
    scheduler = graphene.String()
@@ -186,6 +188,7 @@ def from_row(
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            wsproxy_address=row['wsproxy_address'],
            driver=row['driver'],
            driver_opts=row['driver_opts'],
            scheduler=row['scheduler'],
@@ -324,6 +327,7 @@ async def batch_load_by_name(
class CreateScalingGroupInput(graphene.InputObjectType):
    description = graphene.String(required=False, default='')
    is_active = graphene.Boolean(required=False, default=True)
    wsproxy_address = graphene.String(required=False)
    driver = graphene.String(required=True)
    driver_opts = graphene.JSONString(required=False, default={})
    scheduler = graphene.String(required=True)
@@ -333,6 +337,7 @@ class CreateScalingGroupInput(graphene.InputObjectType):
class ModifyScalingGroupInput(graphene.InputObjectType):
    description = graphene.String(required=False)
    is_active = graphene.Boolean(required=False)
    wsproxy_address = graphene.String(required=False)
    driver = graphene.String(required=False)
    driver_opts = graphene.JSONString(required=False)
    scheduler = graphene.String(required=False)
@@ -363,6 +368,7 @@ async def mutate(
            'name': name,
            'description': props.description,
            'is_active': bool(props.is_active),
            'wsproxy_address': props.wsproxy_address,
            'driver': props.driver,
            'driver_opts': props.driver_opts,
            'scheduler': props.scheduler,
@@ -399,6 +405,7 @@ async def mutate(
        set_if_set(props, data, 'description')
        set_if_set(props, data, 'is_active')
        set_if_set(props, data, 'driver')
        set_if_set(props, data, 'wsproxy_address')
        set_if_set(props, data, 'driver_opts')
        set_if_set(props, data, 'scheduler')
        set_if_set(props, data, 'scheduler_opts')
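
With the new input fields, an administrator can set the coordinator address per scaling group through the GraphQL admin API. A rough sketch, assuming the mutation is exposed as `modify_scaling_group(name, props)` and returns the usual `ok`/`msg` fields; `send_graphql()` stands in for whatever authenticated GraphQL client is used:

```python
# Sketch only: mutation/argument names are inferred from the classes above,
# and send_graphql() is a hypothetical authenticated helper.
MODIFY_SGROUP_WSPROXY = """
mutation($name: String!, $input: ModifyScalingGroupInput!) {
  modify_scaling_group(name: $name, props: $input) {
    ok
    msg
  }
}
"""

variables = {
    "name": "default",
    "input": {"wsproxy_address": "http://wsproxy.example.com:5050"},
}

# result = send_graphql(MODIFY_SGROUP_WSPROXY, variables)
```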