Fixes for prediction API #89

Merged: 6 commits, Sep 16, 2024
docs/api.md: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ The spec sent to the endpoint should have the following format:
pkg_name@pkg_version +variant1+variant2%compiler@compiler_version
```

Be sure that the spec string is URL-encoded; for instance, Python's `urllib.parse.quote` function produces the proper encoding. Without it, the allocation algorithm may return inaccurate results.

**There must be a space after the package version in order to account for variant parsing.**

If the request does not contain a valid spec, the API will respond with `400 Bad Request`. The maximum allowed size of the `GET` request is 8190 bytes.
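
To illustrate the encoding note above, here is a minimal client-side sketch (not part of this PR) using `urllib.parse.quote`; the endpoint URL, query-parameter name, and compiler version are placeholders rather than values taken from the docs:

```python
from urllib.parse import quote

# spec format from the docs:
# pkg_name@pkg_version +variant1+variant2%compiler@compiler_version
# (the compiler version below is illustrative)
spec = "emacs@29.2 +json+native+treesitter%gcc@12.3.0"

# quote() percent-encodes the space, '+', '%', and '@' characters so the
# server can parse the variants correctly
encoded = quote(spec, safe="")

# hypothetical endpoint; substitute the real gantry URL and parameter name
url = f"https://gantry.example.com/v1/allocation?spec={encoded}"
print(url)
```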
gantry/__main__.py: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ async def apply_migrations(db: aiosqlite.Connection):
# they are applied in the correct order
# and not inadvertently added to the migrations folder
("001_initial.sql", 1),
("002_spec_index.sql", 2),
]

# apply migrations that have not been applied
Expand Down
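
The body of `apply_migrations` sits outside the visible hunk, so the following is only a generic sketch of how an ordered `(filename, version)` list might be applied with aiosqlite, using SQLite's `user_version` pragma for bookkeeping; the real project may track applied migrations differently:

```python
import aiosqlite

MIGRATIONS = [
    ("001_initial.sql", 1),
    ("002_spec_index.sql", 2),
]


async def apply_migrations(db: aiosqlite.Connection):
    # read the schema version stored in the SQLite database header
    async with db.execute("PRAGMA user_version") as cur:
        (current_version,) = await cur.fetchone()

    for filename, version in MIGRATIONS:
        if version <= current_version:
            continue  # this migration was already applied

        with open(f"migrations/{filename}") as f:
            await db.executescript(f.read())

        # PRAGMA does not support parameter binding; version is a trusted int
        await db.execute(f"PRAGMA user_version = {version}")
        await db.commit()
```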
gantry/routes/prediction/prediction.py: 10 additions & 8 deletions
@@ -8,7 +8,7 @@
logger = logging.getLogger(__name__)

IDEAL_SAMPLE = 5
DEFAULT_CPU_REQUEST = 1.0
DEFAULT_CPU_REQUEST = 1
DEFAULT_MEM_REQUEST = 2 * 1_000_000_000 # 2GB in bytes
EXPENSIVE_VARIANTS = {
"sycl",
@@ -49,17 +49,16 @@ async def predict(db: aiosqlite.Connection, spec: dict, strategy: str = None) ->
# mapping of sample: [0] cpu_mean, [1] cpu_max, [2] mem_mean, [3] mem_max
predictions = {
# averages the respective metric in the sample
# cpu should always be whole number
"cpu_request": round(sum([build[0] for build in sample]) / len(sample)),
"cpu_request": sum([build[0] for build in sample]) / len(sample),
"mem_request": sum([build[2] for build in sample]) / len(sample),
}

if strategy == "ensure_higher":
ensure_higher_pred(predictions, spec["pkg_name"])

# warn if the prediction is below some thresholds
if predictions["cpu_request"] < 0.25:
logger.warning(f"Warning: CPU request for {spec} is below 0.25 cores")
if predictions["cpu_request"] < 0.2:
logger.warning(f"Warning: CPU request for {spec} is below 0.2 cores")
predictions["cpu_request"] = DEFAULT_CPU_REQUEST
if predictions["mem_request"] < 10_000_000:
logger.warning(f"Warning: Memory request for {spec} is below 10MB")
@@ -68,7 +67,7 @@ async def predict(db: aiosqlite.Connection, spec: dict, strategy: str = None) ->
# convert predictions to k8s friendly format
for k, v in predictions.items():
if k.startswith("cpu"):
predictions[k] = str(int(v))
predictions[k] = k8s.convert_cores(v)
elif k.startswith("mem"):
predictions[k] = k8s.convert_bytes(v)

@@ -142,14 +141,17 @@ async def select_sample(query: str, filters: dict, extra_params: list = []) -> l
# iterate through all the expensive variants and create a set of conditions
# for the select query
for var in EXPENSIVE_VARIANTS:
if var in spec["pkg_variants_dict"]:
variant_value = spec["pkg_variants_dict"].get(var)

# check against specs where hdf5=none like quantum-espresso
if isinstance(variant_value, (bool, int)):
# if the client has queried for an expensive variant, we want to ensure
# that the sample has the same exact value
exp_variant_conditions.append(
f"json_extract(pkg_variants, '$.{var}')=?"
)

exp_variant_values.append(int(spec["pkg_variants_dict"].get(var, 0)))
exp_variant_values.append(int(variant_value))
else:
# if an expensive variant was not queried for,
# we want to make sure that the variant was not set within the sample
Expand Down
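
To make the new `isinstance` guard concrete, here is a standalone sketch (not the project's code) of what the loop produces for a hypothetical spec; `hdf5` is placed in the expensive set purely for illustration, since the real `EXPENSIVE_VARIANTS` is truncated in this diff:

```python
# "sycl" is known to be an expensive variant from the diff above;
# "hdf5" is assumed here only to exercise the isinstance() guard
EXPENSIVE_VARIANTS = {"sycl", "hdf5"}

spec = {"pkg_variants_dict": {"sycl": True, "hdf5": "none", "json": True}}

exp_variant_conditions = []
exp_variant_values = []

for var in EXPENSIVE_VARIANTS:
    variant_value = spec["pkg_variants_dict"].get(var)

    # exact-match only boolean/integer values; string values such as
    # hdf5="none" (and variants that were not queried at all) fall through
    # to the else branch, which the diff cuts off
    if isinstance(variant_value, (bool, int)):
        exp_variant_conditions.append(f"json_extract(pkg_variants, '$.{var}')=?")
        exp_variant_values.append(int(variant_value))

print(exp_variant_conditions)  # ["json_extract(pkg_variants, '$.sycl')=?"]
print(exp_variant_values)      # [1]
```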
gantry/tests/defs/prediction.py: 3 additions & 3 deletions
@@ -21,16 +21,16 @@
# calculated by running the baseline prediction algorithm on the sample data in gantry/tests/sql/insert_prediction.sql
NORMAL_PREDICTION = {
"variables": {
"KUBERNETES_CPU_REQUEST": "12",
"KUBERNETES_MEMORY_REQUEST": "9576M",
"KUBERNETES_CPU_REQUEST": "11779m",
"KUBERNETES_MEMORY_REQUEST": "9577M",
},
}

# this is what will get returned when there are no samples in the database
# that match what the client wants
DEFAULT_PREDICTION = {
"variables": {
"KUBERNETES_CPU_REQUEST": "1",
"KUBERNETES_CPU_REQUEST": "1000m",
"KUBERNETES_MEMORY_REQUEST": "2000M",
},
}
gantry/tests/test_prediction.py: 2 additions & 2 deletions
@@ -76,9 +76,9 @@ async def test_empty_sample(db_conn):
# Test validate_payload
def test_valid_spec():
"""Tests that a valid spec is parsed correctly."""
assert parse_alloc_spec("[email protected] +json+native+treesitter%[email protected]") == {
assert parse_alloc_spec("[email protected]-test +json+native+treesitter%[email protected]") == {
"pkg_name": "emacs",
"pkg_version": "29.2",
"pkg_version": "29.2-test",
"pkg_variants": '{"json": true, "native": true, "treesitter": true}',
"pkg_variants_dict": {"json": True, "native": True, "treesitter": True},
"compiler_name": "gcc",
Expand Down
gantry/util/k8s.py: 8 additions & 1 deletion
@@ -1,8 +1,15 @@
BYTES_TO_MEGABYTES = 1 / 1_000_000
CORES_TO_MILLICORES = 1_000

# these functions convert the predictions to k8s friendly format
# https://kubernetes.io/docs/concepts/configuration/manage-resources-containers


def convert_bytes(bytes: float) -> str:
"""bytes to megabytes"""
return str(int(bytes * BYTES_TO_MEGABYTES)) + "M"
return str(int(round(bytes * BYTES_TO_MEGABYTES))) + "M"


def convert_cores(cores: float) -> str:
"""cores to millicores"""
return str(int(round(cores * CORES_TO_MILLICORES))) + "m"
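
To show how the two helpers behave, a small standalone usage sketch (not part of this PR); the inputs are chosen to mirror the updated test fixtures above:

```python
BYTES_TO_MEGABYTES = 1 / 1_000_000
CORES_TO_MILLICORES = 1_000


def convert_bytes(bytes: float) -> str:
    """bytes to megabytes"""
    return str(int(round(bytes * BYTES_TO_MEGABYTES))) + "M"


def convert_cores(cores: float) -> str:
    """cores to millicores"""
    return str(int(round(cores * CORES_TO_MILLICORES))) + "m"


# the prediction is no longer rounded to whole cores before conversion,
# so a mean of ~11.779 cores now surfaces as "11779m" instead of "12"
print(convert_cores(11.779))         # 11779m
print(convert_cores(1))              # 1000m  (DEFAULT_CPU_REQUEST fallback)
print(convert_bytes(2_000_000_000))  # 2000M  (DEFAULT_MEM_REQUEST fallback)
```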
gantry/util/spec.py: 2 additions & 1 deletion
@@ -51,7 +51,8 @@ def parse_alloc_spec(spec: str) -> dict:
"""

# example: [email protected] +json+native+treesitter%[email protected]
spec_pattern = re.compile(r"(.+?)@([\d.]+)\s+(.+?)%([\w-]+)@([\d.]+)")
# this regex accommodates versions made up of any non-space characters
spec_pattern = re.compile(r"(.+?)@(\S+)\s+(.+?)%([\w-]+)@(\S+)")

match = spec_pattern.match(spec)
if not match:
Expand Down
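
To show what the relaxed pattern buys, here is a standalone check (not part of this PR) against a version with a non-numeric suffix, as the updated test exercises; the compiler version below is illustrative only:

```python
import re

# relaxed pattern from the change above: versions may be any run of
# non-space characters, not just digits and dots
spec_pattern = re.compile(r"(.+?)@(\S+)\s+(.+?)%([\w-]+)@(\S+)")

match = spec_pattern.match("emacs@29.2-test +json+native+treesitter%gcc@13.2.0")
print(match.groups())
# ('emacs', '29.2-test', '+json+native+treesitter', 'gcc', '13.2.0')
```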
migrations/002_spec_index.sql: 1 addition & 0 deletions
@@ -0,0 +1 @@
CREATE INDEX complete_spec on jobs(pkg_name, pkg_variants, pkg_version, compiler_name, compiler_version, end);
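
For a sense of the lookup this index serves, a hedged sketch of a sample query over the indexed columns; the real query in `select_sample` is not shown in this diff, so the exact shape (column list, ordering, limit, variant conditions) is an assumption based on the surrounding code:

```python
import aiosqlite

# hypothetical query shape; IDEAL_SAMPLE = 5 and the metric column names come
# from the prediction code above, while ORDER BY/LIMIT are assumptions. The
# expensive-variant json_extract conditions would be appended to the WHERE clause.
SAMPLE_QUERY = """
SELECT cpu_mean, cpu_max, mem_mean, mem_max
FROM jobs
WHERE pkg_name = ?
  AND pkg_version = ?
  AND compiler_name = ?
  AND compiler_version = ?
ORDER BY "end" DESC
LIMIT 5
"""


async def fetch_sample(db: aiosqlite.Connection, params: tuple) -> list:
    # the composite index on (pkg_name, pkg_variants, pkg_version,
    # compiler_name, compiler_version, end) lets SQLite satisfy this
    # filter without scanning the whole jobs table
    async with db.execute(SAMPLE_QUERY, params) as cur:
        return await cur.fetchall()
```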