Skip to content

Commit

Permalink
Improve caching of job details and fetch more job details
Browse files Browse the repository at this point in the history
  • Loading branch information
harshthakkar01 committed Nov 6, 2024
1 parent b857af0 commit b59abc6
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@
)


# Job states whose placement policies / resources must be preserved.
# NOTE(review): PascalCase kept for backward compatibility with existing
# references; module-level constants are conventionally UPPER_SNAKE_CASE.
KeepStates = frozenset({
    "RUNNING",
    "CONFIGURING",
    "STOPPED",
    "SUSPENDED",
    "COMPLETING",
    "PENDING",
})


def start_instance_op(inst):
return lookup().compute.instances().start(
project=lookup().project,
Expand Down Expand Up @@ -327,20 +339,11 @@ def ignore_err(e) -> bool:

def sync_placement_groups():
"""Delete placement policies that are for jobs that have completed/terminated"""
keep_states = frozenset(
[
"RUNNING",
"CONFIGURING",
"STOPPED",
"SUSPENDED",
"COMPLETING",
]
)

keep_jobs = {
str(job["job_id"])
for job in json.loads(run(f"{lookup().scontrol} show jobs --json").stdout)["jobs"]
if "job_state" in job and set(job["job_state"]) & keep_states
str(job.id)
for job in lookup().get_jobs()
if job.job_state and job.job_state in KeepStates
}
keep_jobs.add("0") # Job 0 is a placeholder for static node placement

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def _fill_cfg_defaults(cfg: NSDict) -> NSDict:
"mount_options": "defaults,hard,intr,_netdev",
}
)

network_storage_iter = filter(
None,
(
Expand Down Expand Up @@ -474,8 +474,8 @@ def _download(bs) -> List[Any]:
), hash

def _assemble_config(
core: Any,
partitions: List[Any],
core: Any,
partitions: List[Any],
nodesets: List[Any],
nodesets_dyn: List[Any],
nodesets_tpu: List[Any],
Expand Down Expand Up @@ -510,17 +510,17 @@ def _add_nodesets(yamls: List[Any], target: dict):
for ns_name in chain(p.partition_nodeset, p.partition_nodeset_dyn, p.partition_nodeset_tpu):
if ns_name not in ns_names:
raise DeffetiveStoredConfigError(f"nodeset {ns_name} not defined in config")

return _fill_cfg_defaults(cfg)

def fetch_config() -> Tuple[bool, NSDict]:
"""
Fetches config from bucket and saves it locally
Fetches config from bucket and saves it locally
Returns True if new (updated) config was fetched
"""
hash_file = Path("/slurm/scripts/.config.hash")
old_hash = hash_file.read_text() if hash_file.exists() else None

cfg_and_hash = _fetch_config(old_hash=old_hash)
if not cfg_and_hash:
return False, _load_config()
Expand Down Expand Up @@ -1460,8 +1460,12 @@ class ReservationDetails:
@dataclass
class Job:
    """Subset of a Slurm job's attributes, parsed from `scontrol show job` output.

    All fields except `id` are optional: they are filled only when the
    corresponding token (JobName=, ReqNodeList=, JobState=, TimeLimit=)
    is found in the scontrol record.
    """

    # Numeric Slurm JobId.
    id: int
    # JobName= token from the scontrol record.
    name: Optional[str] = None
    # ReqNodeList= token (requested node list), if any.
    required_nodes: Optional[str] = None
    # JobState= token, e.g. "RUNNING" or "PENDING".
    job_state: Optional[str] = None
    # Parsed TimeLimit= value ("[days-]HH:MM:SS") as a timedelta.
    duration: Optional[timedelta] = None


class Lookup:
"""Wrapper class for cached data access"""

Expand Down Expand Up @@ -1757,11 +1761,11 @@ def _get_reservation(self, project: str, zone: str, name: str) -> object:
"""See https://cloud.google.com/compute/docs/reference/rest/v1/reservations"""
return self.compute.reservations().get(
project=project, zone=zone, reservation=name).execute()

def nodeset_reservation(self, nodeset: object) -> Optional[ReservationDetails]:
if not nodeset.reservation_name:
return None

zones = list(nodeset.zone_policy_allow or [])
assert len(zones) == 1, "Only single zone is supported if using a reservation"
zone = zones[0]
Expand All @@ -1771,7 +1775,7 @@ def nodeset_reservation(self, nodeset: object) -> Optional[ReservationDetails]:
raise ValueError(
f"Invalid reservation name: '{nodeset.reservation_name}', expected format is 'projects/PROJECT/reservations/NAME'"
)

project, name = match.group("project", "reservation")
reservation = self._get_reservation(project, zone, name)

Expand Down Expand Up @@ -1928,26 +1932,58 @@ def nodeset_map(self, hostnames: list):
nodeset_map[self.node_nodeset_name(node)].append(node)
return nodeset_map

def _get_job_info(self, job_info, job_id=None) -> Optional[Job]:
    """Parse a single `scontrol show job` record into a Job.

    Args:
        job_info: raw scontrol text for one job record.
        job_id: known job id; when None it is parsed from the record's
            "JobId=" token.

    Returns:
        Job populated with whichever fields could be extracted; fields whose
        tokens are absent from the record are left as None.
    """
    if job_id is None:
        if match := re.search(r"JobId=(\d+)", job_info):
            # Fix: coerce to int so Job.id has the declared type regardless
            # of whether the id was passed in or parsed from the record.
            job_id = int(match.group(1))

    duration = None
    # TimeLimit format is "[days-]HH:MM:SS"; the days prefix is optional.
    if match := re.search(r"TimeLimit=(?:(\d+)-)?(\d{2}):(\d{2}):(\d{2})", job_info):
        days, hours, minutes, seconds = match.groups()
        duration = timedelta(
            days=int(days) if days else 0,
            hours=int(hours),
            minutes=int(minutes),
            seconds=int(seconds),
        )

    # NOTE(review): \w+ will not match names/node lists containing
    # '.', '-', '[', ',' etc. — TODO confirm this matches real scontrol output.
    name = match.group(1) if (match := re.search(r"JobName=(\w+)", job_info)) else None
    job_state = match.group(1) if (match := re.search(r"JobState=(\w+)", job_info)) else None
    required_nodes = match.group(1) if (match := re.search(r"ReqNodeList=(\w+)", job_info)) else None

    return Job(
        id=job_id,
        duration=duration,
        name=name,
        job_state=job_state,
        required_nodes=required_nodes,
    )

def get_jobs(self) -> List[Job]:
    """Return parsed details for every job reported by `scontrol show jobs`.

    scontrol separates job records with blank lines and ends its output with
    a trailing separator, so the last chunk of the split is empty and is
    dropped with the [:-1] slice.
    """
    res = run(f"{self.scontrol} show jobs", timeout=30)
    # Build the result directly with a comprehension instead of an
    # append-in-loop over an intermediate list.
    return [
        self._get_job_info(job_info=record)
        for record in res.stdout.split("\n\n")[:-1]
    ]
@lru_cache
def job(self, job_id: int) -> Optional[Job]:
    """Fetch and parse a single job's scontrol record; None when not found.

    NOTE(review): lru_cache on an instance method keeps `self` alive for the
    cache's lifetime — confirm Lookup instances are long-lived singletons.
    """
    output = run(f"{self.scontrol} show jobid {job_id}", check=False).stdout
    details = output.rstrip()
    return self._get_job_info(job_id=job_id, job_info=details) if details else None

@property
def etc_dir(self) -> Path:
Expand Down

0 comments on commit b59abc6

Please sign in to comment.