Skip to content

Commit

Permalink
Allow required_variable to only enforce on certain modes
Browse files Browse the repository at this point in the history
gcp-metadata mod can use this, as it does not need `hostlist` when run
in local mode.
  • Loading branch information
linsword13 committed Oct 31, 2024
1 parent 533121e commit dbfd0c4
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lib/ramble/ramble/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ def build_modifier_instances(self):
self._modifier_instances.append(mod_inst)

# Add this modifiers required variables for validation
self.keywords.update_keys(mod_inst.required_vars)
self.keywords.update_keys(mod_inst.get_required_variables())

# Ensure no expand vars are set correctly for modifiers
for mod_inst in self._modifier_instances:
Expand Down
6 changes: 5 additions & 1 deletion lib/ramble/ramble/language/modifier_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,19 +256,23 @@ def _env_var_modification(mod):


@modifier_directive("required_vars")
def required_variable(var: str, results_level="variable"):
def required_variable(var: str, results_level="variable", modes=None):
"""Mark a variable as being required by this modifier
Args:
var (str): Variable name to mark as required
results_level (str): 'variable' or 'key'. If 'key', variable is promoted to
a key within JSON or YAML formatted results.
modes (list[str] | None): modes that the required check should be applied. The
default None means apply to all modes.
"""

def _mark_required_var(mod):
mod.required_vars[var] = {
"type": ramble.keywords.key_type.required,
"level": ramble.keywords.output_level.variable,
# Extra prop that's only used for filtering
"modes": set(modes) if modes is not None else None,
}

return _mark_required_var
Expand Down
16 changes: 16 additions & 0 deletions lib/ramble/ramble/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,22 @@ def _prepare_analysis(self, workspace):
"""
pass

def get_required_variables(self):
"""Get all the required variables based on the mode."""
required_vars = self.required_vars
filtered_vars = {}
if required_vars:
mode = self._usage_mode
for var_name, var_props in required_vars.items():
modes = var_props["modes"]
if modes is None or mode in modes:
filtered_vars[var_name] = {
# Exclude the extra modes prop
k: var_props[k]
for k in var_props.keys() - {"modes"}
}
return filtered_vars


class ModifierError(RambleError):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2022-2024 The Ramble Authors
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
# <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
# option. This file may not be copied, modified, or distributed
# except according to those terms.

import os

import pytest

import ramble.workspace
import ramble.experiment_set
from ramble.main import RambleCommand

workspace = RambleCommand("workspace")


@pytest.mark.parametrize(
"test_name,mode,expect_error",
[
(
"standard_require_hostlist",
"standard",
ramble.experiment_set.RambleVariableDefinitionError,
),
("local_no_require_hostlist", "local", None),
],
)
def test_required_variables(
test_name, mode, expect_error, mutable_mock_workspace_path, mutable_applications
):
workspace_name = test_name

test_config = f"""
ramble:
variables:
mpi_command: ''
batch_submit: 'batch_submit {{execute_experiment}}'
processes_per_node: 1
applications:
hostname:
workloads:
local:
experiments:
test:
variables:
n_nodes: 1
modifiers:
- name: gcp-metadata
mode: {mode}
"""

with ramble.workspace.create(workspace_name) as ws:
ws.write()

config_path = os.path.join(ws.config_dir, ramble.workspace.config_file_name)

with open(config_path, "w+") as f:
f.write(test_config)

ws._re_read()

if expect_error:
with pytest.raises(expect_error):
workspace("setup", "--dry-run", global_args=["-D", ws.root])
else:
workspace("setup", "--dry-run", global_args=["-D", ws.root])
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class GcpMetadata(BasicModifier):

software_spec("pdsh", pkg_spec="pdsh", package_manager="spack*")

required_variable("hostlist")
required_variable("hostlist", modes=["standard"])

modifier_variable(
"metadata_parallel_prefix",
Expand Down

0 comments on commit dbfd0c4

Please sign in to comment.