Skip to content

Commit

Permalink
Remove the get_forward_models hook
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Nov 8, 2024
1 parent 55c07c4 commit 0c6f980
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 37 deletions.
5 changes: 0 additions & 5 deletions src/everest/plugins/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ def flow_config_path():
return None


@hookimpl
def get_forward_models():
return None


@hookimpl
def lint_forward_model():
return None
Expand Down
7 changes: 0 additions & 7 deletions src/everest/plugins/hook_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,6 @@ def flow_config_path():
"""


@hookspec
def get_forward_models():
"""
Return a list of dicts detailing the names and paths to forward models.
"""


@hookspec(firstresult=True)
def lint_forward_model(job: str, args: Sequence[str]):
"""
Expand Down
5 changes: 0 additions & 5 deletions src/everest/util/forward_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from itertools import chain
from typing import List, Type, TypeVar

from pydantic import BaseModel, ValidationError
Expand All @@ -9,10 +8,6 @@
T = TypeVar("T", bound=BaseModel)


def collect_forward_models():
return chain.from_iterable(pm.hook.get_forward_models())


def collect_forward_model_schemas():
schemas = pm.hook.get_forward_models_schemas()
if schemas:
Expand Down
10 changes: 0 additions & 10 deletions tests/everest/test_everlint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from everest import ConfigKeys
from everest.config import EverestConfig
from everest.config_file_loader import yaml_file_to_substituted_config_dict
from everest.util.forward_models import collect_forward_models
from tests.everest.test_config_validation import has_error
from tests.everest.utils import relpath

Expand Down Expand Up @@ -585,12 +584,3 @@ def test_lint_everest_models_jobs():
config = EverestConfig.load_file(config_file).to_dict()
# Check initial config file is valid
assert len(EverestConfig.lint_config_dict(config)) == 0


def test_overloading_everest_models_names():
config = yaml_file_to_substituted_config_dict(SNAKE_OIL_CONFIG)
for job in collect_forward_models():
config["install_jobs"][2]["name"] = job
config["forward_model"][1] = job
errors = EverestConfig.lint_config_dict(config)
assert len(errors) == 0, f"Failed for job {job}"
19 changes: 9 additions & 10 deletions tests/everest/test_fm_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from everest.plugins import hook_impl, hook_specs, hookimpl
from everest.simulator.everest_to_ert import everest_to_ert_config
from everest.strings import EVEREST
from everest.util.forward_models import collect_forward_models
from tests.everest.utils import relpath

SNAKE_CONFIG_PATH = relpath("test_data/snake_oil/everest/model/snake_oil.yml")
Expand Down Expand Up @@ -42,9 +41,9 @@ def register_plugin_hooks(*plugins) -> MockPluginManager:


def test_everest_models_jobs():
pytest.importorskip("everest_models")
everest_models = pytest.importorskip("everest_models")
ert_config = everest_to_ert_config(EverestConfig.load_file(SNAKE_CONFIG_PATH))
jobs = collect_forward_models()
jobs = everest_models.forward_models.get_forward_models()
assert bool(jobs)
for job in jobs:
job_class = ert_config.installed_forward_model_steps.get(job)
Expand All @@ -53,22 +52,22 @@ def test_everest_models_jobs():


def test_multiple_plugins(plugin_manager):
_JOBS = ["job1", "job2"]
_SCHEMAS = [{"job1": 1}, {"job2": 2}]

class Plugin1:
@hookimpl
def get_forward_models(self):
return [_JOBS[0]]
def get_forward_models_schemas(self):
return [_SCHEMAS[0]]

class Plugin2:
@hookimpl
def get_forward_models(self):
return [_JOBS[1]]
def get_forward_models_schemas(self):
return [_SCHEMAS[1]]

pm = plugin_manager(Plugin1(), Plugin2())

jobs = set(chain.from_iterable(pm.hook.get_forward_models()))
for value in _JOBS:
jobs = list(chain.from_iterable(pm.hook.get_forward_models_schemas()))
for value in _SCHEMAS:
assert value in jobs


Expand Down

0 comments on commit 0c6f980

Please sign in to comment.