Skip to content

Commit

Permalink
Review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Nov 26, 2024
1 parent 7faf381 commit 186a508
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 45 deletions.
9 changes: 9 additions & 0 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
ForwardModelStepKeys,
HistorySource,
HookRuntime,
QueueSystemWithGeneric,
init_forward_model_schema,
init_site_config_schema,
init_user_config_schema,
Expand Down Expand Up @@ -250,6 +251,7 @@ class ErtConfig:
DEFAULT_ENSPATH: ClassVar[str] = "storage"
DEFAULT_RUNPATH_FILE: ClassVar[str] = ".ert_runpath_list"
PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[Dict[str, ForwardModelStep]] = {}
ACTIVATE_SCRIPT: Optional[str] = None

substitutions: Substitutions = field(default_factory=Substitutions)
ensemble_config: EnsembleConfig = field(default_factory=EnsembleConfig)
Expand Down Expand Up @@ -330,6 +332,7 @@ class ErtConfigWithPlugins(ErtConfig):
PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[
Dict[str, ForwardModelStepPlugin]
] = preinstalled_fm_steps
ACTIVATE_SCRIPT = ErtPluginManager().activate_script()

assert issubclass(ErtConfigWithPlugins, ErtConfig)
return ErtConfigWithPlugins
Expand Down Expand Up @@ -658,6 +661,12 @@ def _merge_user_and_site_config(
user_config_dict[keyword] = value + original_entries
elif keyword not in user_config_dict:
user_config_dict[keyword] = value
if cls.ACTIVATE_SCRIPT:
if "QUEUE_OPTION" not in user_config_dict:
user_config_dict["QUEUE_OPTION"] = []
user_config_dict["QUEUE_OPTION"].append(
[QueueSystemWithGeneric.GENERIC, "ACTIVATE_SCRIPT", cls.ACTIVATE_SCRIPT]
)
return user_config_dict

@classmethod
Expand Down
25 changes: 18 additions & 7 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
import re
import shutil
from abc import abstractmethod
Expand All @@ -11,7 +12,6 @@
from pydantic.dataclasses import dataclass
from typing_extensions import Annotated

from ..plugins import ErtPluginManager
from .parsing import (
ConfigDict,
ConfigKeys,
Expand All @@ -27,13 +27,29 @@
NonEmptyString = Annotated[str, pydantic.StringConstraints(min_length=1)]


def activate_script() -> str:
shell = os.environ.get("SHELL")
venv = os.environ.get("VIRTUAL_ENV")
if not venv or not shell:
return ""
script = [f"#!{shell}"]
match shell:
case x if "bash" in x or "zsh" in x:
script.append(f"source {venv}/bin/activate")
case x if "csh" in x or "tcsh" in x:
script.append(f"source {venv}/bin/activate.csh")
case x if "fish" in x:
script.append(f"source {venv}/bin/activate.fish")
return "\n".join(script)


@pydantic.dataclasses.dataclass(config={"extra": "forbid", "validate_assignment": True})
class QueueOptions:
name: str
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: Optional[str] = None
activate_script: str = ""
activate_script: str = field(default_factory=activate_script)

@staticmethod
def create_queue_options(
Expand Down Expand Up @@ -274,7 +290,6 @@ class QueueConfig:
] = pydantic.Field(default_factory=LocalQueueOptions, discriminator="name")
queue_options_test_run: LocalQueueOptions = field(default_factory=LocalQueueOptions)
stop_long_running: bool = False
activate_script: str = ""

@no_type_check
@classmethod
Expand All @@ -295,10 +310,6 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig:
_grouped_queue_options = _group_queue_options_by_queue_system(
_raw_queue_options
)
activate_script = ErtPluginManager().activate_script()
_grouped_queue_options[selected_queue_system]["activate_script"] = (
activate_script
)
_log_duplicated_queue_options(_raw_queue_options)
_raise_for_defaulted_invalid_options(_raw_queue_options)

Expand Down
17 changes: 1 addition & 16 deletions src/ert/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,9 @@ def _site_config_lines(self) -> List[str]:
return list(chain.from_iterable(reversed(plugin_site_config_lines)))

def activate_script(self) -> str:
def activate_script() -> str:
shell = os.environ.get("SHELL")
venv = os.environ.get("VIRTUAL_ENV")
if not venv or not shell:
return ""
script = [f"#!{shell}"]
match shell:
case x if "bash" in x or "zsh" in x:
script.append(f"source {venv}/bin/activate")
case x if "csh" in x or "tcsh" in x:
script.append(f"source {venv}/bin/activate.csh")
case x if "fish" in x:
script.append(f"source {venv}/bin/activate.fish")
return "\n".join(script)

plugin_responses = self.hook.activate_script()
if not plugin_responses:
return activate_script()
return ""
if len(plugin_responses) > 1:
raise ValueError(
f"Only one activate script is allowed, got {[plugin.plugin_metadata.plugin_name for plugin in plugin_responses]}"
Expand Down
20 changes: 20 additions & 0 deletions tests/ert/unit_tests/config/test_queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
QueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
Expand Down Expand Up @@ -509,3 +510,22 @@ def test_driver_initialization_from_defaults(queue_system):
LocalDriver(**LocalQueueOptions().driver_options)
if queue_system == QueueSystem.SLURM:
SlurmDriver(**SlurmQueueOptions().driver_options)


@pytest.mark.parametrize("venv", ["my_env", None])
@pytest.mark.parametrize(
"shell, expected",
[
["csh", "#!csh\nsource my_env/bin/activate.csh"],
["bash", "#!bash\nsource my_env/bin/activate"],
],
)
def test_default_activate_script_generation(shell, expected, monkeypatch, venv):
monkeypatch.setenv("SHELL", shell)
if venv:
monkeypatch.setenv("VIRTUAL_ENV", "my_env")
else:
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
expected = ""
options = QueueOptions(name="local")
assert options.activate_script == expected
22 changes: 0 additions & 22 deletions tests/ert/unit_tests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,25 +171,3 @@ def test_multiple_activate_script_hook():
pm = ErtPluginManager(plugins=[ActivatePlugin(), AnotherActivatePlugin()])
with pytest.raises(ValueError, match="one activate script is allowed"):
pm.activate_script()


@pytest.mark.parametrize("plugins", [[], [EmptyActivatePlugin()]])
@pytest.mark.parametrize("venv", ["my_env", None])
@pytest.mark.parametrize(
"shell, expected",
[
["csh", "#!csh\nsource my_env/bin/activate.csh"],
["bash", "#!bash\nsource my_env/bin/activate"],
],
)
def test_default_activate_script_generation(
shell, expected, monkeypatch, venv, plugins
):
monkeypatch.setenv("SHELL", shell)
pm = ErtPluginManager(plugins=plugins)
if venv:
monkeypatch.setenv("VIRTUAL_ENV", "my_env")
assert pm.activate_script() == expected
else:
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
assert not pm.activate_script()

0 comments on commit 186a508

Please sign in to comment.