diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index 61d5032439d..82bd3be2a2a 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -54,6 +54,7 @@ ForwardModelStepKeys, HistorySource, HookRuntime, + QueueSystemWithGeneric, init_forward_model_schema, init_site_config_schema, init_user_config_schema, @@ -250,6 +251,7 @@ class ErtConfig: DEFAULT_ENSPATH: ClassVar[str] = "storage" DEFAULT_RUNPATH_FILE: ClassVar[str] = ".ert_runpath_list" PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[Dict[str, ForwardModelStep]] = {} + ACTIVATE_SCRIPT: Optional[str] = None substitutions: Substitutions = field(default_factory=Substitutions) ensemble_config: EnsembleConfig = field(default_factory=EnsembleConfig) @@ -330,6 +332,7 @@ class ErtConfigWithPlugins(ErtConfig): PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[ Dict[str, ForwardModelStepPlugin] ] = preinstalled_fm_steps + ACTIVATE_SCRIPT = ErtPluginManager().activate_script() assert issubclass(ErtConfigWithPlugins, ErtConfig) return ErtConfigWithPlugins @@ -658,6 +661,12 @@ def _merge_user_and_site_config( user_config_dict[keyword] = value + original_entries elif keyword not in user_config_dict: user_config_dict[keyword] = value + if cls.ACTIVATE_SCRIPT: + if "QUEUE_OPTION" not in user_config_dict: + user_config_dict["QUEUE_OPTION"] = [] + user_config_dict["QUEUE_OPTION"].append( + [QueueSystemWithGeneric.GENERIC, "ACTIVATE_SCRIPT", cls.ACTIVATE_SCRIPT] + ) return user_config_dict @classmethod diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index 9062fbc9cd1..f9c82f59e74 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os import re import shutil from abc import abstractmethod @@ -11,7 +12,6 @@ from pydantic.dataclasses import dataclass from typing_extensions import Annotated -from ..plugins import ErtPluginManager from .parsing import ( ConfigDict, ConfigKeys, @@ -27,13 +27,29 @@ NonEmptyString = Annotated[str, pydantic.StringConstraints(min_length=1)] +def activate_script() -> str: + shell = os.environ.get("SHELL") + venv = os.environ.get("VIRTUAL_ENV") + if not venv or not shell: + return "" + script = [f"#!{shell}"] + match shell: + case x if "bash" in x or "zsh" in x: + script.append(f"source {venv}/bin/activate") + case x if "csh" in x or "tcsh" in x: + script.append(f"source {venv}/bin/activate.csh") + case x if "fish" in x: + script.append(f"source {venv}/bin/activate.fish") + return "\n".join(script) + + @pydantic.dataclasses.dataclass(config={"extra": "forbid", "validate_assignment": True}) class QueueOptions: name: str max_running: pydantic.NonNegativeInt = 0 submit_sleep: pydantic.NonNegativeFloat = 0.0 project_code: Optional[str] = None - activate_script: str = "" + activate_script: str = field(default_factory=activate_script) @staticmethod def create_queue_options( @@ -274,7 +290,6 @@ class QueueConfig: ] = pydantic.Field(default_factory=LocalQueueOptions, discriminator="name") queue_options_test_run: LocalQueueOptions = field(default_factory=LocalQueueOptions) stop_long_running: bool = False - activate_script: str = "" @no_type_check @classmethod @@ -295,10 +310,6 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig: _grouped_queue_options = _group_queue_options_by_queue_system( _raw_queue_options ) - activate_script = ErtPluginManager().activate_script() - _grouped_queue_options[selected_queue_system]["activate_script"] = ( - activate_script - ) _log_duplicated_queue_options(_raw_queue_options) _raise_for_defaulted_invalid_options(_raw_queue_options) diff --git a/src/ert/plugins/plugin_manager.py b/src/ert/plugins/plugin_manager.py index 6c0dc60eb63..0e17b8858c6 100644 --- a/src/ert/plugins/plugin_manager.py +++ b/src/ert/plugins/plugin_manager.py @@ -157,24 +157,9 @@ def _site_config_lines(self) -> List[str]: return list(chain.from_iterable(reversed(plugin_site_config_lines))) def activate_script(self) -> str: - def activate_script() -> str: - shell = os.environ.get("SHELL") - venv = os.environ.get("VIRTUAL_ENV") - if not venv or not shell: - return "" - script = [f"#!{shell}"] - match shell: - case x if "bash" in x or "zsh" in x: - script.append(f"source {venv}/bin/activate") - case x if "csh" in x or "tcsh" in x: - script.append(f"source {venv}/bin/activate.csh") - case x if "fish" in x: - script.append(f"source {venv}/bin/activate.fish") - return "\n".join(script) - plugin_responses = self.hook.activate_script() if not plugin_responses: - return activate_script() + return "" if len(plugin_responses) > 1: raise ValueError( f"Only one activate script is allowed, got {[plugin.plugin_metadata.plugin_name for plugin in plugin_responses]}" diff --git a/tests/ert/unit_tests/config/test_queue_config.py b/tests/ert/unit_tests/config/test_queue_config.py index 73e0ad7d680..00b03a79ff6 100644 --- a/tests/ert/unit_tests/config/test_queue_config.py +++ b/tests/ert/unit_tests/config/test_queue_config.py @@ -15,6 +15,7 @@ from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, + QueueOptions, SlurmQueueOptions, TorqueQueueOptions, ) @@ -509,3 +510,22 @@ def test_driver_initialization_from_defaults(queue_system): LocalDriver(**LocalQueueOptions().driver_options) if queue_system == QueueSystem.SLURM: SlurmDriver(**SlurmQueueOptions().driver_options) + + +@pytest.mark.parametrize("venv", ["my_env", None]) +@pytest.mark.parametrize( + "shell, expected", + [ + ["csh", "#!csh\nsource my_env/bin/activate.csh"], + ["bash", "#!bash\nsource my_env/bin/activate"], + ], +) +def test_default_activate_script_generation(shell, expected, monkeypatch, venv): + monkeypatch.setenv("SHELL", shell) + if venv: + monkeypatch.setenv("VIRTUAL_ENV", "my_env") + else: + monkeypatch.delenv("VIRTUAL_ENV", raising=False) + expected = "" + options = QueueOptions(name="local") + assert options.activate_script == expected diff --git a/tests/ert/unit_tests/plugins/test_plugin_manager.py b/tests/ert/unit_tests/plugins/test_plugin_manager.py index 8ae56c2b05a..f058fb410e1 100644 --- a/tests/ert/unit_tests/plugins/test_plugin_manager.py +++ b/tests/ert/unit_tests/plugins/test_plugin_manager.py @@ -171,25 +171,3 @@ def test_multiple_activate_script_hook(): pm = ErtPluginManager(plugins=[ActivatePlugin(), AnotherActivatePlugin()]) with pytest.raises(ValueError, match="one activate script is allowed"): pm.activate_script() - - -@pytest.mark.parametrize("plugins", [[], [EmptyActivatePlugin()]]) -@pytest.mark.parametrize("venv", ["my_env", None]) -@pytest.mark.parametrize( - "shell, expected", - [ - ["csh", "#!csh\nsource my_env/bin/activate.csh"], - ["bash", "#!bash\nsource my_env/bin/activate"], - ], -) -def test_default_activate_script_generation( - shell, expected, monkeypatch, venv, plugins -): - monkeypatch.setenv("SHELL", shell) - pm = ErtPluginManager(plugins=plugins) - if venv: - monkeypatch.setenv("VIRTUAL_ENV", "my_env") - assert pm.activate_script() == expected - else: - monkeypatch.delenv("VIRTUAL_ENV", raising=False) - assert not pm.activate_script()