Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Commit

Permalink
fix: Don't allow empty values in allowed session types and scan all a…
Browse files Browse the repository at this point in the history
…llowed scaling groups (#542)

* Improve the failure messages and rewrite the exclusion logic.
* Add a new test case to demonstrate how to write test cases with mocking.

Co-authored-by: Joongi Kim <[email protected]>
  • Loading branch information
vesselofgod and achimnol authored Mar 4, 2022
1 parent b947549 commit d04d3c1
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 16 deletions.
1 change: 1 addition & 0 deletions changes/542.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Disallow empty values in `allowed_session_types` of scheduler options and fix the predicate check to scan all scaling groups
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
sa.Column('scheduler_opts', StructuredJSONBColumn(
t.Dict({
t.Key('allowed_session_types', default=['interactive', 'batch']):
t.List(tx.Enum(SessionTypes)),
t.List(tx.Enum(SessionTypes), min_length=1),
}).allow_extra('*'),
), nullable=False, default={}),
)
Expand Down
43 changes: 28 additions & 15 deletions src/ai/backend/manager/scheduler/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,36 +211,49 @@ async def _query():
)

sgroups = await execute_with_retry(_query)
if not sgroups:
return PredicateResult(
False,
"You do not have any scaling groups allowed to use.",
permanent=True,
)
target_sgroup_names: List[str] = []
preferred_sgroup_name = sess_ctx.scaling_group
if preferred_sgroup_name is not None:
# Consider only the preferred scaling group.
for sgroup in sgroups:
if preferred_sgroup_name == sgroup['name']:
break
else:
return PredicateResult(
False,
f"The given preferred scaling group is not allowed to use. "
f"({preferred_sgroup_name})",
f"You do not have access to the scaling group '{preferred_sgroup_name}'.",
permanent=True,
)
# Consider agents only in the preferred scaling group.
target_sgroup_names = [preferred_sgroup_name]
else:
# Consider all agents in all allowed scaling groups.
target_sgroup_names = [sgroup['name'] for sgroup in sgroups]
log.debug('considered scaling groups: {}', target_sgroup_names)
if not target_sgroup_names:
return PredicateResult(
False,
"No available resource in scaling groups.",
)
for sgroup in sgroups:
allowed_session_types = sgroup['scheduler_opts']['allowed_session_types']
if sess_ctx.session_type.value.lower() not in allowed_session_types:
return PredicateResult(
False,
"Not allowed session type in scaling groups.",
f"The scaling group '{preferred_sgroup_name}' does not accept "
f"the session type '{sess_ctx.session_type}'. ",
permanent=True,
)
target_sgroup_names = [preferred_sgroup_name]
else:
# Consider all allowed scaling groups.
usable_sgroups = []
for sgroup in sgroups:
allowed_session_types = sgroup['scheduler_opts']['allowed_session_types']
if sess_ctx.session_type.value.lower() in allowed_session_types:
usable_sgroups.append(sgroup)
if not usable_sgroups:
return PredicateResult(
False,
f"No scaling groups accept the session type '{sess_ctx.session_type}'.",
permanent=True,
)
target_sgroup_names = [sgroup['name'] for sgroup in usable_sgroups]
assert target_sgroup_names
log.debug("scaling groups considered for s:{} are {}", sess_ctx.session_id, target_sgroup_names)
sess_ctx.target_sgroup_names.extend(target_sgroup_names)
return PredicateResult(True)
161 changes: 161 additions & 0 deletions tests/test_predicates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock

import pytest

from ai.backend.common.types import SessionTypes
from ai.backend.manager.scheduler.predicates import check_scaling_group


@pytest.mark.asyncio
@mock.patch('ai.backend.manager.scheduler.predicates.execute_with_retry')
async def test_allowed_session_types_check(mock_query):
mock_query.return_value = [
{
'name': 'a',
'scheduler_opts': {
'allowed_session_types': ['batch'],
},
},
{
'name': 'b',
'scheduler_opts': {
'allowed_session_types': ['interactive'],
},
},
{
'name': 'c',
'scheduler_opts': {
'allowed_session_types': ['batch', 'interactive'],
},
},
]
mock_conn = MagicMock()
mock_sched_ctx = MagicMock()
mock_sess_ctx = MagicMock()

# Preferred scaling group with one match in allowed sgroups

mock_sess_ctx.session_type = SessionTypes.BATCH
mock_sess_ctx.scaling_group = 'a'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert result.passed
assert mock_sess_ctx.target_sgroup_names == ['a']

mock_sess_ctx.session_type = SessionTypes.BATCH
mock_sess_ctx.scaling_group = 'b'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert not result.passed
assert result.message is not None
assert "does not accept" in result.message
assert mock_sess_ctx.target_sgroup_names == []

mock_sess_ctx.session_type = SessionTypes.BATCH
mock_sess_ctx.scaling_group = 'c'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert result.passed
assert mock_sess_ctx.target_sgroup_names == ['c']

mock_sess_ctx.session_type = SessionTypes.INTERACTIVE
mock_sess_ctx.scaling_group = 'a'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert not result.passed
assert result.message is not None
assert "does not accept" in result.message
assert mock_sess_ctx.target_sgroup_names == []

mock_sess_ctx.session_type = SessionTypes.INTERACTIVE
mock_sess_ctx.scaling_group = 'b'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert result.passed
assert mock_sess_ctx.target_sgroup_names == ['b']

mock_sess_ctx.session_type = SessionTypes.INTERACTIVE
mock_sess_ctx.scaling_group = 'c'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert result.passed
assert mock_sess_ctx.target_sgroup_names == ['c']

# Non-existent/disallowed preferred scaling group

mock_sess_ctx.session_type = SessionTypes.INTERACTIVE
mock_sess_ctx.scaling_group = 'x'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert not result.passed
assert result.message is not None
assert "do not have access" in result.message
assert mock_sess_ctx.target_sgroup_names == []

# No preferred scaling group with partially matching allowed sgroups

mock_sess_ctx.session_type = SessionTypes.BATCH
mock_sess_ctx.scaling_group = None
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert result.passed
assert mock_sess_ctx.target_sgroup_names == ['a', 'c']

mock_sess_ctx.session_type = SessionTypes.INTERACTIVE
mock_sess_ctx.scaling_group = None
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert result.passed
assert mock_sess_ctx.target_sgroup_names == ['b', 'c']

# No preferred scaling group with an empty list of allowed sgroups

mock_query.return_value = []

mock_sess_ctx.session_type = SessionTypes.BATCH
mock_sess_ctx.scaling_group = 'x'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert not result.passed
assert result.message is not None
assert "do not have any" in result.message
assert mock_sess_ctx.target_sgroup_names == []

mock_sess_ctx.session_type = SessionTypes.INTERACTIVE
mock_sess_ctx.scaling_group = 'x'
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert not result.passed
assert result.message is not None
assert "do not have any" in result.message
assert mock_sess_ctx.target_sgroup_names == []

# No preferred scaling group with a non-empty list of allowed sgroups

mock_query.return_value = [
{
'name': 'a',
'scheduler_opts': {
'allowed_session_types': ['batch'],
},
},
]

mock_sess_ctx.session_type = SessionTypes.BATCH
mock_sess_ctx.scaling_group = None
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert result.passed
assert mock_sess_ctx.target_sgroup_names == ['a']

mock_sess_ctx.session_type = SessionTypes.INTERACTIVE
mock_sess_ctx.scaling_group = None
mock_sess_ctx.target_sgroup_names = []
result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx)
assert not result.passed
assert result.message is not None
assert "No scaling groups accept" in result.message
assert mock_sess_ctx.target_sgroup_names == []

0 comments on commit d04d3c1

Please sign in to comment.