Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Commit

Permalink
Add start_service API (#479)
Browse files Browse the repository at this point in the history
This API is dedicated to starting a container application from the container's
service definitions for wsproxy v2, whereas the existing `stream_proxy` API performs
both `start_service` and the actual websocket proxying.

Note: wsproxy v2 directly connects the user clients and the container
apps to provide better horizontal scalability. The separation of those
two functions is the key starting point.

Co-authored-by: Jonghyun Park <[email protected]>
Backported-From: main (22.03)
Backported-To: 21.09
  • Loading branch information
2 people authored and achimnol committed Dec 17, 2021
1 parent 08675dc commit 38d54bb
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 0 deletions.
1 change: 1 addition & 0 deletions changes/479.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `session.start_service` API to support wsproxy v2
54 changes: 54 additions & 0 deletions src/ai/backend/manager/api/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
)

from aiohttp import web
import aiohttp
import aiohttp_cors
import aiotools
from dataclasses import dataclass, field
import trafaret as t

from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter

from ai.backend.manager.api.exceptions import GenericNotFound

from ai.backend.manager.models.utils import ExtendedAsyncSAEngine

from ..models import (
query_allowed_sgroups,
)
Expand All @@ -29,6 +36,35 @@
log = BraceStyleAdapter(logging.getLogger(__name__))


@dataclass(unsafe_hash=True)
class WSProxyVersionQueryParams:
    """
    Per-request values needed by ``query_wsproxy_version``.

    Instances are passed as an argument of that cached coroutine, so they
    must be hashable.  Equality compares all three fields; the hash is
    derived from ``access_key`` and ``domain_name`` so that instances for
    different callers do not all share a single hash value.
    """

    # Database engine handle.  Excluded from the hash (still compared in
    # ``__eq__``).
    db_ctx: ExtendedAsyncSAEngine = field(hash=False)
    # Access key of the requesting keypair; part of the hash.
    access_key: str
    # Domain of the requesting user; part of the hash.
    domain_name: str


@aiotools.lru_cache(expire_after=30)  # cached results are kept for 30 seconds
async def query_wsproxy_version(
    params: WSProxyVersionQueryParams,
    group_id_or_name: str,
) -> str:
    """
    Look up which wsproxy API version serves the caller's scaling group.

    The scaling groups returned by ``query_allowed_sgroups`` are fetched in a
    read-only transaction and the first returned row decides the answer.

    :param params: database handle plus the caller's access key and domain.
    :param group_id_or_name: forwarded as-is to ``query_allowed_sgroups``.
    :returns: ``'v1'`` when the row has no ``wsproxy_addr``; otherwise the
        ``api_version`` field of the JSON served at ``<wsproxy_addr>/status``.
    :raises GenericNotFound: when no scaling group is returned.
    """
    async with params.db_ctx.begin_readonly() as db_conn:
        allowed_sgroups = await query_allowed_sgroups(
            db_conn,
            params.domain_name,
            group_id_or_name,
            params.access_key,
        )
        if not allowed_sgroups:
            raise GenericNotFound

        coordinator_addr = allowed_sgroups[0]['wsproxy_addr']
        if not coordinator_addr:
            # No coordinator address is set for the group: report 'v1'.
            return 'v1'

        status_url = coordinator_addr + '/status'
        async with aiohttp.ClientSession() as http, http.get(status_url) as status_resp:
            status_body = await status_resp.json()
        return status_body['api_version']


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
Expand All @@ -53,6 +89,23 @@ async def list_available_sgroups(request: web.Request, params: Any) -> web.Respo
}, status=200)


@auth_required
@server_status_required(READ_ALLOWED)
async def get_wsproxy_version(request: web.Request) -> web.Response:
    """
    Report the wsproxy API version for the scaling group named in the URL.

    The caller's access key and domain are taken from the request and passed,
    together with the ``scaling_group`` path segment, to
    ``query_wsproxy_version``.  The reply is a JSON object of the form
    ``{"version": <str>}``.
    """
    root_ctx: RootContext = request.app['_root.context']
    query_params = WSProxyVersionQueryParams(
        db_ctx=root_ctx.db,
        access_key=request['keypair']['access_key'],
        domain_name=request['user']['domain_name'],
    )
    version = await query_wsproxy_version(
        query_params,
        request.match_info['scaling_group'],
    )
    return web.json_response({'version': version})


async def init(app: web.Application) -> None:
pass

Expand All @@ -70,4 +123,5 @@ def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iter
cors = aiohttp_cors.setup(app, defaults=default_cors_options)
root_resource = cors.add(app.router.add_resource(r''))
cors.add(root_resource.add_route('GET', list_available_sgroups))
cors.add(app.router.add_route('GET', '/{scaling_group}/wsproxy-version', get_wsproxy_version))
return app, []
99 changes: 99 additions & 0 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TYPE_CHECKING,
cast,
)
from urllib.parse import urlparse
import uuid

import aiohttp
Expand All @@ -43,6 +44,7 @@
import sqlalchemy as sa
from sqlalchemy.sql.expression import true, null
import trafaret as t

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection

Expand Down Expand Up @@ -75,6 +77,7 @@
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
from ai.backend.common.types import (
AccessKey,
AgentId,
KernelId,
ClusterMode,
Expand All @@ -91,6 +94,7 @@
association_groups_users as agus, groups,
keypairs, kernels, query_bootstrap_script,
keypair_resource_policies,
scaling_groups,
users, UserRole,
vfolders,
AgentStatus, KernelStatus,
Expand All @@ -102,10 +106,12 @@
from ..models.kernel import match_session_ids
from ..models.utils import execute_with_retry
from .exceptions import (
AppNotFound,
InvalidAPIParameters,
GenericNotFound,
ImageNotFound,
InsufficientPrivilege,
ServiceUnavailable,
SessionNotFound,
SessionAlreadyExists,
TooManySessionsMatched,
Expand Down Expand Up @@ -1076,6 +1082,98 @@ async def create_cluster(request: web.Request, params: Any) -> web.Response:
return web.json_response(resp, status=201)


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['app', 'service']): t.String,
        # The port argument is only required to use secondary ports
        # when the target app listens multiple TCP ports.
        # Otherwise it should be omitted or set to the same value of
        # the actual port number used by the app.
        tx.AliasedKey(['port'], default=None): t.Null | t.Int[1024:65535],
        # stringified JSON, e.g., '{"PASSWORD": "12345"}'
        tx.AliasedKey(['envs'], default=None): t.Null | t.String,
        # stringified JSON, e.g., '{"-P": "12345"}'
        # The value can be one of: None, str, List[str]
        tx.AliasedKey(['arguments'], default=None): t.Null | t.String,
    }))
async def start_service(request: web.Request, params: Mapping[str, Any]) -> web.Response:
    """
    Start an app service inside a session and register it with the wsproxy
    coordinator of the session's scaling group.

    Steps: look up the session, read the scaling group's ``wsproxy_addr``,
    resolve the host port of the requested service, start the service through
    the registry, then POST the kernel host/port to ``<wsproxy_addr>/v2/conf``.

    :param request: the request; ``session_name`` comes from the URL and the
        access key from the authenticated keypair.
    :param params: validated parameters: ``app``, optional ``port``, and the
        optional JSON strings ``envs`` and ``arguments``.
    :returns: a JSON response with the coordinator-issued ``token`` and the
        ``wsproxy_addr``.
    :raises ServiceUnavailable: when the scaling group has no wsproxy address.
    :raises InvalidAPIParameters: when the service does not expose the
        requested port, or ``envs``/``arguments`` is not valid JSON.
    :raises AppNotFound: when the session has no service with the given name.
    :raises InternalServerError: when the registry reports a failed start.
    """
    root_ctx: RootContext = request.app['_root.context']
    session_name: str = request.match_info['session_name']
    access_key: AccessKey = request['keypair']['access_key']
    service: str = params['app']

    # SessionNotFound / TooManySessionsMatched simply propagate to the caller.
    kernel = await asyncio.shield(root_ctx.registry.get_session(session_name, access_key))

    query = (
        sa.select([scaling_groups.c.wsproxy_addr])
        .select_from(scaling_groups)
        .where(scaling_groups.c.name == kernel['scaling_group'])
    )
    async with root_ctx.db.begin_readonly() as conn:
        sgroup_result = await conn.execute(query)
        sgroup = sgroup_result.first()
        # A missing scaling-group row is handled like an unset address
        # instead of failing on ``None['wsproxy_addr']``.
        wsproxy_addr = sgroup['wsproxy_addr'] if sgroup is not None else None
    if not wsproxy_addr:
        raise ServiceUnavailable('No coordinator configured for this resource group')

    if kernel['kernel_host'] is None:
        # Fall back to the host part of the agent address.
        kernel_host = urlparse(kernel['agent_addr']).hostname
    else:
        kernel_host = kernel['kernel_host']

    for sport in kernel['service_ports']:
        if sport['name'] != service:
            continue
        if params['port'] is not None:
            # using one of the primary/secondary ports of the app
            try:
                hport_idx = sport['container_ports'].index(params['port'])
            except ValueError:
                raise InvalidAPIParameters(
                    f"Service {service} does not open the port number {params['port']}.")
            host_port = sport['host_ports'][hport_idx]
        elif 'host_ports' not in sport:
            host_port = sport['host_port']  # legacy kernels
        else:
            # using the default (primary) port of the app
            host_port = sport['host_ports'][0]
        break
    else:
        raise AppNotFound(f'{session_name}:{service}')

    # Decode the JSON options before touching any session state so that a
    # malformed request is rejected as a client error without side effects.
    opts: MutableMapping[str, Union[None, str, List[str]]] = {}
    for opt_name in ('arguments', 'envs'):
        raw_value = params[opt_name]
        if raw_value is None:
            continue
        try:
            opts[opt_name] = json.loads(raw_value)
        except json.JSONDecodeError as exc:
            raise InvalidAPIParameters(
                f"The {opt_name} parameter is not a valid JSON string.") from exc

    await asyncio.shield(root_ctx.registry.increment_session_usage(session_name, access_key))

    start_result = await asyncio.shield(
        root_ctx.registry.start_service(session_name, access_key, service, opts),
    )
    if start_result['status'] == 'failed':
        raise InternalServerError(
            "Failed to launch the app service",
            extra_data=start_result['error'])

    # NOTE(review): the coordinator's HTTP status is not checked here; a reply
    # without a JSON body or a 'token' key surfaces as an unhandled error.
    # Confirm whether that is intended.
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{wsproxy_addr}/v2/conf', json={
            'kernel_host': kernel_host,
            'kernel_port': host_port,
        }) as resp:
            token_json = await resp.json()
            return web.json_response({
                'token': token_json['token'],
                'wsproxy_addr': wsproxy_addr,
            })


async def handle_kernel_creation_lifecycle(
app: web.Application,
source: AgentId,
Expand Down Expand Up @@ -2010,4 +2108,5 @@ def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iter
cors.add(app.router.add_route('GET', '/{session_name}/download', download_files))
cors.add(app.router.add_route('GET', '/{session_name}/download_single', download_single))
cors.add(app.router.add_route('GET', '/{session_name}/files', list_files))
cors.add(app.router.add_route('POST', '/{session_name}/start-service', start_service))
return app, []
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Add wsproxy_addr column on scaling_group
Revision ID: 60a1effa77d2
Revises: 8679d0a7e22b
Create Date: 2021-09-17 13:19:57.525513
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` is the ID it revises
# (see "Revision ID" / "Revises" in the module docstring).
revision = '60a1effa77d2'
down_revision = '8679d0a7e22b'
# No branch label or cross-revision dependency is declared for this migration.
branch_labels = None
depends_on = None


def upgrade():
    """Add the nullable ``wsproxy_addr`` string column to ``scaling_groups``."""
    wsproxy_addr_column = sa.Column(
        'wsproxy_addr',
        sa.String(length=1024),
        nullable=True,
    )
    op.add_column('scaling_groups', wsproxy_addr_column)


def downgrade():
    """Remove the ``wsproxy_addr`` column from ``scaling_groups`` again."""
    table, column = 'scaling_groups', 'wsproxy_addr'
    op.drop_column(table, column)
7 changes: 7 additions & 0 deletions src/ai/backend/manager/models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
sa.Column('is_active', sa.Boolean, index=True, default=True),
sa.Column('created_at', sa.DateTime(timezone=True),
server_default=sa.func.now()),
sa.Column('wsproxy_addr', sa.String(length=1024), nullable=True),
sa.Column('driver', sa.String(length=64), nullable=False),
sa.Column('driver_opts', pgsql.JSONB(), nullable=False, default={}),
sa.Column('scheduler', sa.String(length=64), nullable=False),
Expand Down Expand Up @@ -168,6 +169,7 @@ class ScalingGroup(graphene.ObjectType):
description = graphene.String()
is_active = graphene.Boolean()
created_at = GQLDateTime()
wsproxy_addr = graphene.String()
driver = graphene.String()
driver_opts = graphene.JSONString()
scheduler = graphene.String()
Expand All @@ -186,6 +188,7 @@ def from_row(
description=row['description'],
is_active=row['is_active'],
created_at=row['created_at'],
wsproxy_addr=row['wsproxy_addr'],
driver=row['driver'],
driver_opts=row['driver_opts'],
scheduler=row['scheduler'],
Expand Down Expand Up @@ -324,6 +327,7 @@ async def batch_load_by_name(
class CreateScalingGroupInput(graphene.InputObjectType):
description = graphene.String(required=False, default='')
is_active = graphene.Boolean(required=False, default=True)
wsproxy_addr = graphene.String(required=False)
driver = graphene.String(required=True)
driver_opts = graphene.JSONString(required=False, default={})
scheduler = graphene.String(required=True)
Expand All @@ -333,6 +337,7 @@ class CreateScalingGroupInput(graphene.InputObjectType):
class ModifyScalingGroupInput(graphene.InputObjectType):
description = graphene.String(required=False)
is_active = graphene.Boolean(required=False)
wsproxy_addr = graphene.String(required=False)
driver = graphene.String(required=False)
driver_opts = graphene.JSONString(required=False)
scheduler = graphene.String(required=False)
Expand Down Expand Up @@ -363,6 +368,7 @@ async def mutate(
'name': name,
'description': props.description,
'is_active': bool(props.is_active),
'wsproxy_addr': props.wsproxy_addr,
'driver': props.driver,
'driver_opts': props.driver_opts,
'scheduler': props.scheduler,
Expand Down Expand Up @@ -399,6 +405,7 @@ async def mutate(
set_if_set(props, data, 'description')
set_if_set(props, data, 'is_active')
set_if_set(props, data, 'driver')
set_if_set(props, data, 'wsproxy_addr')
set_if_set(props, data, 'driver_opts')
set_if_set(props, data, 'scheduler')
set_if_set(props, data, 'scheduler_opts')
Expand Down

0 comments on commit 38d54bb

Please sign in to comment.