Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store consts in one file #133

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions gantry/routes/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
from gantry.clients.prometheus import PrometheusClient
from gantry.clients.prometheus.util import IncompleteData
from gantry.models import Job

MB_IN_BYTES = 1_000_000
BUILD_STAGE_REGEX = r"^stage-\d+$"
from gantry.util import const

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,7 +45,7 @@ async def fetch_job(
if (
job.status != "success"
# if the stage is not stage-NUMBER, it's not a build job
or not re.match(BUILD_STAGE_REGEX, payload["build_stage"])
or not re.match(const.BUILD_STAGE_REGEX, payload["build_stage"])
# some jobs don't have runners..?
or payload["runner"] is None
# uo runners are not in Prometheus
Expand Down Expand Up @@ -135,7 +133,7 @@ async def fetch_node(
"hostname": hostname,
"cores": node_labels["cores"],
# convert to bytes to be consistent with other resource metrics
"mem": node_labels["mem"] * MB_IN_BYTES,
"mem": node_labels["mem"] * const.MB_IN_BYTES,
"arch": node_labels["arch"],
"os": node_labels["os"],
"instance_type": node_labels["instance_type"],
Expand Down
32 changes: 9 additions & 23 deletions gantry/routes/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,10 @@

import aiosqlite

from gantry.util import k8s
from gantry.util import const, k8s

logger = logging.getLogger(__name__)

IDEAL_SAMPLE = 5
DEFAULT_CPU_REQUEST = 1
DEFAULT_MEM_REQUEST = 2 * 1_000_000_000 # 2GB in bytes
EXPENSIVE_VARIANTS = {
"sycl",
"mpi",
"rocm",
"cuda",
"python",
"fortran",
"openmp",
"hdf5",
}


async def predict(db: aiosqlite.Connection, spec: dict) -> dict:
"""
Expand All @@ -37,8 +23,8 @@ async def predict(db: aiosqlite.Connection, spec: dict) -> dict:
predictions = {}
if not sample:
predictions = {
"cpu_request": DEFAULT_CPU_REQUEST,
"mem_request": DEFAULT_MEM_REQUEST,
"cpu_request": const.DEFAULT_CPU_REQUEST,
"mem_request": const.DEFAULT_MEM_REQUEST,
}
else:
# mapping of sample: [0] cpu_mean, [1] cpu_max, [2] mem_mean, [3] mem_max
Expand All @@ -51,10 +37,10 @@ async def predict(db: aiosqlite.Connection, spec: dict) -> dict:
# warn if the prediction is below some thresholds
if predictions["cpu_request"] < 0.2:
logger.warning(f"Warning: CPU request for {spec} is below 0.2 cores")
predictions["cpu_request"] = DEFAULT_CPU_REQUEST
predictions["cpu_request"] = const.DEFAULT_CPU_REQUEST
if predictions["mem_request"] < 10_000_000:
logger.warning(f"Warning: Memory request for {spec} is below 10MB")
predictions["mem_request"] = DEFAULT_MEM_REQUEST
predictions["mem_request"] = const.DEFAULT_MEM_REQUEST

# convert predictions to k8s friendly format
for k, v in predictions.items():
Expand Down Expand Up @@ -104,7 +90,7 @@ async def select_sample(query: str, filters: dict, extra_params: list = []) -> l
async with db.execute(query, list(filters.values()) + extra_params) as cursor:
sample = await cursor.fetchall()
# we can accept the sample if it's 1 shorter
if len(sample) >= IDEAL_SAMPLE - 1:
if len(sample) >= const.TRAINING_SAMPLES - 1:
return sample
return []

Expand All @@ -116,7 +102,7 @@ async def select_sample(query: str, filters: dict, extra_params: list = []) -> l
query = f"""
SELECT cpu_mean, cpu_max, mem_mean, mem_max FROM jobs
WHERE ref='develop' AND {' AND '.join(f'{param}=?' for param in filters.keys())}
ORDER BY end DESC LIMIT {IDEAL_SAMPLE}
ORDER BY end DESC LIMIT {const.TRAINING_SAMPLES}
"""

if sample := await select_sample(query, filters):
Expand All @@ -132,7 +118,7 @@ async def select_sample(query: str, filters: dict, extra_params: list = []) -> l

# iterate through all the expensive variants and create a set of conditions
# for the select query
for var in EXPENSIVE_VARIANTS:
for var in const.EXPENSIVE_VARIANTS:
variant_value = spec["pkg_variants_dict"].get(var)

# check against specs where hdf5=none like quantum-espresso
Expand All @@ -157,7 +143,7 @@ async def select_sample(query: str, filters: dict, extra_params: list = []) -> l
SELECT cpu_mean, cpu_max, mem_mean, mem_max FROM jobs
WHERE ref='develop' AND {' AND '.join(f'{param}=?' for param in filters.keys())}
AND {' AND '.join(exp_variant_conditions)}
ORDER BY end DESC LIMIT {IDEAL_SAMPLE}
ORDER BY end DESC LIMIT {const.TRAINING_SAMPLES}
"""

if sample := await select_sample(query, filters, exp_variant_values):
Expand Down
36 changes: 36 additions & 0 deletions gantry/util/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# centralized constants for the project

# resources
MB_IN_BYTES = 1_000_000
BYTES_IN_MB = 1 / MB_IN_BYTES
MILLICORES_IN_CORES = 1_000

# spec
# example: [email protected] +json+native+treesitter arch=x86_64%[email protected]
# this regex accommodates versions made up of any non-space characters
SPACK_SPEC_PATTERN = r"(.+?)@(\S+)\s+(.+?)\s+arch=(\S+)%([\w-]+)@(\S+)"

# gitlab
# sends dates in 2021-02-23 02:41:37 UTC format
# documentation says they use iso 8601, but they don't consistently apply it
# https://docs.gitlab.com/ee/user/project/integrations/webhook_events.html#job-events
GITLAB_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S %Z"

# collection
# all build jobs will match this pattern (eg: stage-1 but not stage-index)
BUILD_STAGE_REGEX = r"^stage-\d+$"

# prediction
TRAINING_SAMPLES = 5 # number of past builds to use for prediction
DEFAULT_CPU_REQUEST = 1 # cores
DEFAULT_MEM_REQUEST = 2 * 1_000_000_000 # 2GB in bytes
EXPENSIVE_VARIANTS = {
"sycl",
"mpi",
"rocm",
"cuda",
"python",
"fortran",
"openmp",
"hdf5",
}
7 changes: 3 additions & 4 deletions gantry/util/k8s.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
BYTES_TO_MEGABYTES = 1 / 1_000_000
CORES_TO_MILLICORES = 1_000
from gantry.util.const import BYTES_IN_MB, MILLICORES_IN_CORES

# these functions convert the predictions to k8s friendly format
# https://kubernetes.io/docs/concepts/configuration/manage-resources-containers


def convert_bytes(bytes: float) -> str:
"""bytes to megabytes"""
return str(int(round(bytes * BYTES_TO_MEGABYTES))) + "M"
return str(int(round(bytes * BYTES_IN_MB))) + "M"


def convert_cores(cores: float) -> str:
"""cores to millicores"""
return str(int(round(cores * CORES_TO_MILLICORES))) + "m"
return str(int(round(cores * MILLICORES_IN_CORES))) + "m"
6 changes: 3 additions & 3 deletions gantry/util/spec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import re

from gantry.util.const import SPACK_SPEC_PATTERN


def spec_variants(spec: str) -> dict:
"""Given a spec's concrete variants, return a dict in name: value format."""
Expand Down Expand Up @@ -51,9 +53,7 @@ def parse_alloc_spec(spec: str) -> dict:
for the client.
"""

# example: [email protected] +json+native+treesitter arch=x86_64%[email protected]
# this regex accommodates versions made up of any non-space characters
spec_pattern = re.compile(r"(.+?)@(\S+)\s+(.+?)\s+arch=(\S+)%([\w-]+)@(\S+)")
spec_pattern = re.compile(SPACK_SPEC_PATTERN)

match = spec_pattern.match(spec)
if not match:
Expand Down
11 changes: 4 additions & 7 deletions gantry/util/time.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import datetime

from gantry.util.const import GITLAB_DATETIME_FORMAT


def webhook_timestamp(dt: str) -> float:
"""Converts a gitlab webhook datetime to a unix timestamp."""
# gitlab sends dates in 2021-02-23 02:41:37 UTC format
# documentation says they use iso 8601, but they don't consistently apply it
# https://docs.gitlab.com/ee/user/project/integrations/webhook_events.html#job-events
GITLAB_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S %Z"
# strptime doesn't tag with timezone by default
return (
datetime.datetime.strptime(dt, GITLAB_DATETIME_FORMAT)
.replace(tzinfo=datetime.timezone.utc)
.timestamp()
# strptime doesn't tag with timezone by default
.replace(tzinfo=datetime.timezone.utc).timestamp()
)