Skip to content

Commit

Permalink
Fix more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Dec 10, 2024
1 parent 07c6de1 commit 43eedff
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 85 deletions.
7 changes: 2 additions & 5 deletions src/everest/bin/everest_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import threading
from functools import partial

from ert.run_models.everest_run_model import EverestRunModel
from everest.config import EverestConfig, ServerConfig
from everest.detached import (
ServerStatus,
Expand All @@ -24,7 +23,6 @@
from everest.util import (
makedirs_if_needed,
version_info,
warn_user_that_runpath_is_nonempty,
)

from .utils import (
Expand Down Expand Up @@ -115,8 +113,8 @@ async def run_everest(options):
except ValueError as exc:
raise SystemExit(f"Config validation error: {exc}") from exc

if EverestRunModel.create(options.config).check_if_runpath_exists():
warn_user_that_runpath_is_nonempty()
# if EverestRunModel.create(options.config).check_if_runpath_exists():
# warn_user_that_runpath_is_nonempty()

try:
output_dir = options.config.output_dir
Expand All @@ -135,7 +133,6 @@ async def run_everest(options):
start_experiment(
server_context=ServerConfig.get_server_context(options.config.output_dir),
config=options.config,
debug=options.debug,
)

run_detached_monitor(
Expand Down
2 changes: 1 addition & 1 deletion src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
def start_experiment(
server_context: Tuple[str, str, Tuple[str, str]],
config: EverestConfig,
debug: bool = False,
) -> None:
try:
url, cert, auth = server_context
Expand All @@ -65,6 +64,7 @@ def start_experiment(
verify=cert,
auth=auth,
proxies=PROXY, # type: ignore
json=config.to_dict(),
)
response.raise_for_status()
except:
Expand Down
22 changes: 16 additions & 6 deletions src/everest/detached/jobs/everest_server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,12 @@ def get_exit_code(self) -> Optional[ExitCode]:


class EverestServerAPI(threading.Thread):
def __init__(self, everest_config: EverestConfig):
def __init__(self, output_dir: str, optimization_output_dir: str):
super().__init__()

self.output_dir = output_dir
self.optimization_output_dir = optimization_output_dir

self.app = FastAPI()

self.router = APIRouter()
Expand Down Expand Up @@ -239,10 +242,6 @@ def __init__(self, everest_config: EverestConfig):

self.runner: Optional[ExperimentRunner] = None

self.everest_config = everest_config
self.output_dir = everest_config.output_dir
self.optimization_output_dir = everest_config.optimization_output_dir

# same code is in ensemble evaluator
self.authentication = _generate_authentication()

Expand Down Expand Up @@ -312,8 +311,18 @@ def get_exit_code(
self._log(request)
self._check_user(credentials)

if self.state[STOP_ENDPOINT] == True:
return JSONResponse(
jsonable_encoder(
ExitCode(
message="Everest server stopped",
)
)
)

if not self.runner:
return JSONResponse(jsonable_encoder({}))

return JSONResponse(
jsonable_encoder(
self.runner.get_exit_code() if self.runner.get_exit_code() else {}
Expand All @@ -330,13 +339,14 @@ def get_opt_progress(

def start_experiment(
self,
config: EverestConfig,
request: Request,
credentials: HTTPBasicCredentials = Depends(security),
) -> Response:
self._log(request)
self._check_user(credentials)

self.runner = ExperimentRunner(self.everest_config, self.state)
self.runner = ExperimentRunner(config, self.state)
self.runner.start()

return Response("Everest experiment started", 200)
Expand Down
35 changes: 22 additions & 13 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def main():
STOP_ENDPOINT: False,
}

everest_server_api = EverestServerAPI(config)
everest_server_api = EverestServerAPI(
output_dir=config.output_dir,
optimization_output_dir=config.optimization_output_dir,
)
everest_server_api.daemon = True
everest_server_api.start()

Expand All @@ -172,40 +175,46 @@ def main():
is_running = False
while not is_running:
try:
requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY)
requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY) # type: ignore
is_running = True
except:
time.sleep(1)

update_everserver_status(status_path, ServerStatus.running)

# response = requests.post(
# url + "/" + START_ENDPOINT, verify=cert, auth=auth, proxies=PROXY
# )

is_done = False
while not is_done:
response = requests.get(
url + "/" + EXIT_CODE_ENDPOINT, verify=cert, auth=auth, proxies=PROXY
resp: requests.Response = requests.get(
url + "/" + EXIT_CODE_ENDPOINT,
verify=cert,
auth=auth,
proxies=PROXY, # type: ignore
)
exit_code = ExitCode.model_validate_json(
resp.text if hasattr(resp, "text") else resp.body
)
exit_code = ExitCode.model_validate_json(response.body)
if exit_code.exit_code or exit_code.message:
is_done = True
else:
time.sleep(1)

if exit_code.message:
if exit_code.message and exit_code.message != "Everest server stopped":
update_everserver_status(
status_path,
ServerStatus.failed,
message=exit_code.message,
)
return

response = requests.get(
url + "/" + SHARED_DATA_ENDPOINT, verify=cert, auth=auth, proxies=PROXY
response: requests.Response = requests.get(
url + "/" + SHARED_DATA_ENDPOINT,
verify=cert,
auth=auth,
proxies=PROXY, # type: ignore
)
if json_body := json.loads(response.body):
if json_body := json.loads(
response.text if hasattr(response, "text") else response.body
):
shared_data = json_body

status, message = _get_optimization_status(exit_code.exit_code, shared_data)
Expand Down
2 changes: 1 addition & 1 deletion tests/everest/test_detached.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_https_requests(copy_math_func_test_data_to_tmp):
raise e

server_status = everserver_status(status_path)
assert ServerStatus.running == server_status["status"]
assert server_status["status"] in [ServerStatus.running, ServerStatus.starting]

url, cert, auth = ServerConfig.get_server_context(everest_config.output_dir)
result = requests.get(url, verify=cert, auth=auth, proxies=PROXY) # noqa: ASYNC210
Expand Down
3 changes: 2 additions & 1 deletion tests/everest/test_everest_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ def useless_cb(*args, **kwargs):
@patch("everest.bin.everest_script.server_is_running", return_value=False)
@patch("everest.bin.everest_script.run_detached_monitor")
@patch("everest.bin.everest_script.wait_for_server")
@patch("everest.bin.everest_script.start_experiment")
@patch("everest.bin.everest_script.start_server")
@patch(
"everest.bin.everest_script.everserver_status",
return_value={"status": ServerStatus.never_run, "message": None},
)
def test_save_running_config(_, _1, _2, _3, _4, copy_math_func_test_data_to_tmp):
def test_save_running_config(_, _1, _2, _3, _4, _5, copy_math_func_test_data_to_tmp):
"""Test everest detached, when an optimization has already run"""
# optimization already run, notify the user
file_name = "config_minimal.yml"
Expand Down
61 changes: 3 additions & 58 deletions tests/everest/test_everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, PlainTextResponse
from ropt.enums import OptimizerExitCode
from seba_sqlite.snapshot import SebaSnapshot

from everest.config import EverestConfig, ServerConfig
from everest.detached import PROXY, ServerStatus, everserver_status
Expand Down Expand Up @@ -142,15 +141,13 @@ def mocked_server(url, verify, auth, proxies):
}
)
)

return PlainTextResponse("Everest is running")

mocked_get.side_effect = mocked_server

everserver.main()

url, cert, auth = ServerConfig.get_server_context(config.output_dir)
requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY)

status = everserver_status(
ServerConfig.get_everserver_status_path(config.output_dir)
)
Expand Down Expand Up @@ -208,7 +205,7 @@ def mocked_server(url, verify, auth, proxies):
everserver.main()

url, cert, auth = ServerConfig.get_server_context(config.output_dir)
requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY)
requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY) # type: ignore

status = everserver_status(
ServerConfig.get_everserver_status_path(config.output_dir)
Expand Down Expand Up @@ -259,7 +256,7 @@ def mocked_server(url, verify, auth, proxies):
everserver.main()

url, cert, auth = ServerConfig.get_server_context(config.output_dir)
requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY)
requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY) # type: ignore

status = everserver_status(
ServerConfig.get_everserver_status_path(config.output_dir)
Expand All @@ -269,55 +266,3 @@ def mocked_server(url, verify, auth, proxies):
# start_optimization raised.
assert status["status"] == ServerStatus.failed
assert "Exception: Failed optimization" in status["message"]


@pytest.mark.integration_test
@patch("sys.argv", ["name", "--config-file", "config_one_batch.yml"])
@patch("everest.detached.jobs.everserver._configure_loggers")
@patch("requests.get")
def test_everserver_status_max_batch_num(
mocked_get, mocked_logger, copy_math_func_test_data_to_tmp
):
config_file = "config_one_batch.yml"
config = EverestConfig.load_file(config_file)

def mocked_server(url, verify, auth, proxies):
if "/exit_code" in url:
return JSONResponse(
jsonable_encoder(
ExitCode(exit_code=OptimizerExitCode.OPTIMIZER_STEP_FINISHED)
)
)
if "/shared_data" in url:
return JSONResponse(
jsonable_encoder(
{
SIM_PROGRESS_ENDPOINT: {
"status": {},
"progress": [],
},
STOP_ENDPOINT: False,
}
)
)
return PlainTextResponse("Everest is running")

mocked_get.side_effect = mocked_server

everserver.main()

url, cert, auth = ServerConfig.get_server_context(config.output_dir)
requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY)

status = everserver_status(
ServerConfig.get_everserver_status_path(config.output_dir)
)

# The server should complete without error.
assert status["status"] == ServerStatus.completed

# Check that there is only one batch.
snapshot = SebaSnapshot(config.optimization_output_dir).get_snapshot(
filter_out_gradient=False, batches=None
)
assert {data.batch for data in snapshot.simulation_data} == {0}
5 changes: 5 additions & 0 deletions tests/everest/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ert.scheduler.event import FinishedEvent
from everest.config import EverestConfig, ServerConfig
from everest.detached import (
start_experiment,
start_server,
wait_for_server,
)
Expand Down Expand Up @@ -35,6 +36,10 @@ async def server_running():
driver = await start_server(everest_config, debug=True)
try:
wait_for_server(everest_config.output_dir, 60)
start_experiment(
server_context=ServerConfig.get_server_context(everest_config.output_dir),
config=everest_config,
)
except (SystemExit, RuntimeError) as e:
raise e
await server_running()
Expand Down

0 comments on commit 43eedff

Please sign in to comment.