Skip to content

Commit

Permalink
Stop using cached seba data
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Dec 13, 2024
1 parent 4301b63 commit 104118d
Showing 1 changed file with 20 additions and 39 deletions.
59 changes: 20 additions & 39 deletions tests/everest/entry_points/test_config_branch_entry.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,27 @@
import difflib
from os.path import exists
from unittest.mock import PropertyMock, patch
from pathlib import Path

from seba_sqlite.snapshot import SebaSnapshot

from everest.bin.config_branch_script import config_branch_entry
from everest.config import EverestConfig
from everest.config_file_loader import load_yaml
from everest.config_keys import ConfigKeys as CK
from tests.everest.utils import relpath

CONFIG_FILE = "config_advanced.yml"
CACHED_SEBA_FOLDER = relpath("test_data", "cached_results_config_advanced")

def test_config_branch_entry(cached_example):
path, _, _ = cached_example("math_func/config_advanced.yml")

# @patch.object(EverestConfig, "optimization_output_dir", new_callable=PropertyMock)
@patch.object(
EverestConfig,
"optimization_output_dir",
new_callable=PropertyMock,
return_value=CACHED_SEBA_FOLDER,
)
def test_config_branch_entry(get_opt_output_dir_mock, copy_math_func_test_data_to_tmp):
new_config_file_name = "new_restart_config.yml"
batch_id = 1
config_branch_entry(["config_advanced.yml", "new_restart_config.yml", "-b", "1"])

config_branch_entry([CONFIG_FILE, new_config_file_name, "-b", str(batch_id)])
assert exists("new_restart_config.yml")

get_opt_output_dir_mock.assert_called_once()
assert exists(new_config_file_name)

old_config = load_yaml(CONFIG_FILE)
old_config = load_yaml("config_advanced.yml")
old_controls = old_config[CK.CONTROLS]

assert CK.INITIAL_GUESS in old_controls[0]

new_config = load_yaml(new_config_file_name)
new_config = load_yaml("new_restart_config.yml")
new_controls = new_config[CK.CONTROLS]

assert CK.INITIAL_GUESS not in new_controls[0]
Expand All @@ -44,9 +30,9 @@ def test_config_branch_entry(get_opt_output_dir_mock, copy_math_func_test_data_t

opt_controls = {}

snapshot = SebaSnapshot(CACHED_SEBA_FOLDER)
snapshot = SebaSnapshot(Path(path) / "everest_output" / "optimization_output")
for opt_data in snapshot._optimization_data():
if opt_data.batch_id == batch_id:
if opt_data.batch_id == 1:
opt_controls = opt_data.controls

new_controls_initial_guesses = {
Expand All @@ -57,36 +43,31 @@ def test_config_branch_entry(get_opt_output_dir_mock, copy_math_func_test_data_t
assert new_controls_initial_guesses == opt_control_val_for_batch_id


@patch.object(
EverestConfig,
"optimization_output_dir",
new_callable=PropertyMock,
return_value=CACHED_SEBA_FOLDER,
)
def test_config_branch_preserves_config_section_order(
get_opt_output_dir_mock, copy_math_func_test_data_to_tmp
):
new_config_file_name = "new_restart_config.yml"
batch_id = 1
def test_config_branch_preserves_config_section_order(cached_example):
path, _, _ = cached_example("math_func/config_advanced.yml")

config_branch_entry([CONFIG_FILE, new_config_file_name, "-b", str(batch_id)])
config_branch_entry(["config_advanced.yml", "new_restart_config.yml", "-b", "1"])

get_opt_output_dir_mock.assert_called_once()
assert exists(new_config_file_name)
assert exists("new_restart_config.yml")

opt_controls = {}

snapshot = SebaSnapshot(CACHED_SEBA_FOLDER)
snapshot = SebaSnapshot(Path(path) / "everest_output" / "optimization_output")
for opt_data in snapshot._optimization_data():
if opt_data.batch_id == batch_id:
if opt_data.batch_id == 1:
opt_controls = opt_data.controls

opt_control_val_for_batch_id = {v for k, v in opt_controls.items()}

diff_lines = []
with (
<<<<<<< HEAD
open(CONFIG_FILE, encoding="utf-8") as initial_config,
open(new_config_file_name, encoding="utf-8") as branch_config,
=======
open("config_advanced.yml", "r", encoding="utf-8") as initial_config,
open("new_restart_config.yml", "r", encoding="utf-8") as branch_config,
>>>>>>> 36897088f (Stop using cached seba data)
):
diff = difflib.unified_diff(
initial_config.readlines(),
Expand Down

0 comments on commit 104118d

Please sign in to comment.