diff --git a/src/ert/gui/plottery/plots/ensemble.py b/src/ert/gui/plottery/plots/ensemble.py index 57b308d663f..9d24e270952 100644 --- a/src/ert/gui/plottery/plots/ensemble.py +++ b/src/ert/gui/plottery/plots/ensemble.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Optional import numpy as np import pandas as pd from ert.gui.plottery.plots.history import plotHistory from ert.gui.tools.plot.plot_api import EnsembleObject +from ert.shared.storage.summary_key_utils import is_rate from .observations import plotObservations from .plot_tools import PlotTools @@ -45,11 +46,13 @@ def plot( plot_context.deactivateDateSupport() plot_context.x_axis = plot_context.INDEX_AXIS + draw_style = "steps-pre" if is_rate(plot_context.key()) else None self._plotLines( axes, config, data, f"{ensemble.experiment_name} : {ensemble.name}", + draw_style, ) config.nextColor() @@ -71,6 +74,7 @@ def _plotLines( plot_config: PlotConfig, data: pd.DataFrame, ensemble_label: str, + draw_style: Optional[str] = None, ) -> None: style = plot_config.defaultStyle() @@ -86,6 +90,7 @@ def _plotLines( linewidth=style.width, linestyle=style.line_style, markersize=style.size, + drawstyle=draw_style, ) if len(lines) > 0: diff --git a/src/ert/shared/storage/summary_key_utils.py b/src/ert/shared/storage/summary_key_utils.py new file mode 100644 index 00000000000..9f0a4630b8d --- /dev/null +++ b/src/ert/shared/storage/summary_key_utils.py @@ -0,0 +1,200 @@ +from enum import Enum, auto +from typing import List + +special_keys = [ + "NAIMFRAC", + "NBAKFL", + "NBYTOT", + "NCPRLINS", + "NEWTFL", + "NEWTON", + "NLINEARP", + "NLINEARS", + "NLINSMAX", + "NLINSMIN", + "NLRESMAX", + "NLRESSUM", + "NMESSAGE", + "NNUMFL", + "NNUMST", + "NTS", + "NTSECL", + "NTSMCL", + "NTSPCL", + "ELAPSED", + "MAXDPR", + "MAXDSO", + "MAXDSG", + "MAXDSW", + "STEPTYPE", + "WNEWTON", +] +rate_keys = [ + "OPR", + "OIR", + "OVPR", + "OVIR", + "OFR", + "OPP", + "OPI", + "OMR", + "GPR", + "GIR", + "GVPR", + "GVIR", + "GFR", + "GPP", + "GPI", + "GMR", + "WGPR", + "WGIR", + "WPR", + "WIR", + "WVPR", + "WVIR", + "WFR", + "WPP", + "WPI", + "WMR", + "LPR", + "LFR", + "VPR", + "VIR", + "VFR", + "GLIR", + "RGR", + "EGR", + "EXGR", + "SGR", + "GSR", + "FGR", + "GIMR", + "GCR", + "NPR", + "NIR", + "CPR", + "CIR", + "SIR", + "SPR", + "TIR", + "TPR", + "GOR", + "WCT", + "OGR", + "WGR", + "GLR", +] + +seg_rate_keys = [ + "OFR", + "GFR", + "WFR", + "CFR", + "SFR", + "TFR", + "CVPR", + "WCT", + "GOR", + "OGR", + "WGR", +] + + +class SummaryKeyType(Enum): + INVALID = auto() + FIELD = auto() + REGION = auto() + GROUP = auto() + WELL = auto() + SEGMENT = auto() + BLOCK = auto() + AQUIFER = auto() + COMPLETION = auto() + NETWORK = auto() + REGION_2_REGION = auto() + LOCAL_BLOCK = auto() + LOCAL_COMPLETION = auto() + LOCAL_WELL = auto() + MISC = auto() + + @staticmethod + def determine_key_type(key: str) -> "SummaryKeyType": + if key in special_keys: + return SummaryKeyType.MISC + + if key.startswith("L"): + secondary = key[1] if len(key) > 1 else "" + return { + "B": SummaryKeyType.LOCAL_BLOCK, + "C": SummaryKeyType.LOCAL_COMPLETION, + "W": SummaryKeyType.LOCAL_WELL, + }.get(secondary, SummaryKeyType.MISC) + + if key.startswith("R"): + if len(key) == 3 and key[2] == "F": + return SummaryKeyType.REGION_2_REGION + if key == "RNLF": + return SummaryKeyType.REGION_2_REGION + if key == "RORFR": + return SummaryKeyType.REGION + if len(key) >= 4 and key[2] == "F" and key[3] in {"T", "R"}: + return SummaryKeyType.REGION_2_REGION + if len(key) >= 5 and key[3] == "F" and key[4] in {"T", "R"}: + return SummaryKeyType.REGION_2_REGION + return SummaryKeyType.REGION + + # default cases or miscellaneous if not matched + return { + "A": SummaryKeyType.AQUIFER, + "B": SummaryKeyType.BLOCK, + "C": SummaryKeyType.COMPLETION, + "F": SummaryKeyType.FIELD, + "G": SummaryKeyType.GROUP, + "N": SummaryKeyType.NETWORK, + "S": SummaryKeyType.SEGMENT, + "W": SummaryKeyType.WELL, + }.get(key[0], SummaryKeyType.MISC) + + +def match_keyword_vector(start: int, rate_keys: List[str], keyword: str) -> bool: + if len(keyword) < start: + return False + return any(keyword[start:].startswith(key) for key in rate_keys) + + +def match_keyword_string(start: int, rate_string: str, keyword: str) -> bool: + if len(keyword) < start: + return False + return keyword[start:].startswith(rate_string) + + +def is_rate(key: str) -> bool: + key_type = SummaryKeyType.determine_key_type(key) + if key_type in { + SummaryKeyType.WELL, + SummaryKeyType.GROUP, + SummaryKeyType.FIELD, + SummaryKeyType.REGION, + SummaryKeyType.COMPLETION, + SummaryKeyType.LOCAL_WELL, + SummaryKeyType.LOCAL_COMPLETION, + SummaryKeyType.NETWORK, + }: + if key_type in { + SummaryKeyType.LOCAL_WELL, + SummaryKeyType.LOCAL_COMPLETION, + SummaryKeyType.NETWORK, + }: + return match_keyword_vector(2, rate_keys, key) + return match_keyword_vector(1, rate_keys, key) + + if key_type == SummaryKeyType.SEGMENT: + return match_keyword_vector(1, seg_rate_keys, key) + + if key_type == SummaryKeyType.REGION_2_REGION: + # Region to region rates are identified by R*FR or R**FR + if match_keyword_string(2, "FR", key): + return True + return match_keyword_string(3, "FR", key) + + return False diff --git a/tests/unit_tests/shared/test_rate_keys.py b/tests/unit_tests/shared/test_rate_keys.py new file mode 100644 index 00000000000..aa8e3316208 --- /dev/null +++ b/tests/unit_tests/shared/test_rate_keys.py @@ -0,0 +1,62 @@ +import hypothesis.strategies as st +import pytest +from hypothesis import given +from resdata.summary import Summary + +from ert.shared.storage.summary_key_utils import is_rate +from tests.unit_tests.config.summary_generator import summary_variables + + +def nonempty_string_without_whitespace(): + return st.text( + st.characters(whitelist_categories=("Lu", "Ll", "Nd", "P")), min_size=1 + ) + + +@given(key=nonempty_string_without_whitespace()) +def test_is_rate_does_not_raise_error(key): + is_rate_bool = is_rate(key) + assert isinstance(is_rate_bool, bool) + + +examples = [ + ("OPR", False), + ("WOPR:OP_4", True), + ("WGIR", True), + ("FOPT", False), + ("GGPT", False), + ("RWPT", False), + ("COPR", True), + ("LPR", False), + ("LWPR", False), + ("LCOPR", True), + ("RWGIR", True), + ("RTPR", True), + ("RXFR", True), + ("XXX", False), + ("YYYY", False), + ("ZZT", False), + ("SGPR", False), + ("AAPR", False), + ("JOPR", False), + ("ROPRT", True), + ("RNFT", False), + ("RFR", False), + ("RRFRT", True), + ("ROC", False), + ("BPR:123", False), + ("FWIR", True), +] + + +@pytest.mark.parametrize("key, rate", examples) +def test_is_rate_determines_rate_key_correctly(key, rate): + is_rate_bool = is_rate(key) + assert is_rate_bool == rate + + +@given(key=summary_variables()) +def test_rate_determination_is_consistent(key): + # Here we verify that the determination of rate keys is the same + # as provided by resdata api + assert Summary.is_rate(key) == is_rate(key)