diff --git a/cli/medperf/account_management/account_management.py b/cli/medperf/account_management/account_management.py index f8ea4e88e..eb2d9a180 100644 --- a/cli/medperf/account_management/account_management.py +++ b/cli/medperf/account_management/account_management.py @@ -1,15 +1,15 @@ from .token_storage import TokenStore -from medperf.config_management import read_config, write_config -from medperf import config +from medperf.config_management import config +from medperf import settings from medperf.exceptions import MedperfException def read_user_account(): - config_p = read_config() - if config.credentials_keyword not in config_p.active_profile: + config_p = config.read_config() + if settings.credentials_keyword not in config_p.active_profile: return - account_info = config_p.active_profile[config.credentials_keyword] + account_info = config_p.active_profile[settings.credentials_keyword] return account_info @@ -23,7 +23,7 @@ def set_credentials( ): email = id_token_payload["email"] TokenStore().set_tokens(email, access_token, refresh_token) - config_p = read_config() + config_p = config.read_config() if login_event: # Set the time the user logged in, so that we can track the lifetime of @@ -31,7 +31,7 @@ def set_credentials( logged_in_at = token_issued_at else: # This means this is a refresh event. Preserve the logged_in_at timestamp. - logged_in_at = config_p.active_profile[config.credentials_keyword][ + logged_in_at = config_p.active_profile[settings.credentials_keyword][ "logged_in_at" ] @@ -42,8 +42,8 @@ def set_credentials( "logged_in_at": logged_in_at, } - config_p.active_profile[config.credentials_keyword] = account_info - write_config(config_p) + config_p.active_profile[settings.credentials_keyword] = account_info + config_p.write_config() def read_credentials(): @@ -61,35 +61,35 @@ def read_credentials(): def delete_credentials(): - config_p = read_config() - if config.credentials_keyword not in config_p.active_profile: + config_p = config.read_config() + if settings.credentials_keyword not in config_p.active_profile: raise MedperfException("You are not logged in") - email = config_p.active_profile[config.credentials_keyword]["email"] + email = config_p.active_profile[settings.credentials_keyword]["email"] TokenStore().delete_tokens(email) - config_p.active_profile.pop(config.credentials_keyword) - write_config(config_p) + config_p.active_profile.pop(settings.credentials_keyword) + config_p.write_config() def set_medperf_user_data(): """Get and cache user data from the MedPerf server""" - config_p = read_config() + config_p = config.read_config() medperf_user = config.comms.get_current_user() - config_p.active_profile[config.credentials_keyword]["medperf_user"] = medperf_user - write_config(config_p) + config_p.active_profile[settings.credentials_keyword]["medperf_user"] = medperf_user + config_p.write_config() return medperf_user def get_medperf_user_data(): """Return cached medperf user data. Get from the server if not found""" - config_p = read_config() - if config.credentials_keyword not in config_p.active_profile: + config_p = config.read_config() + if settings.credentials_keyword not in config_p.active_profile: raise MedperfException("You are not logged in") - medperf_user = config_p.active_profile[config.credentials_keyword].get( + medperf_user = config_p.active_profile[settings.credentials_keyword].get( "medperf_user", None ) if medperf_user is None: diff --git a/cli/medperf/account_management/token_storage/filesystem.py b/cli/medperf/account_management/token_storage/filesystem.py index 3d12960b4..0488516ac 100644 --- a/cli/medperf/account_management/token_storage/filesystem.py +++ b/cli/medperf/account_management/token_storage/filesystem.py @@ -2,12 +2,12 @@ import base64 import logging from medperf.utils import remove_path -from medperf import config +from medperf import settings class FilesystemTokenStore: def __init__(self): - self.creds_folder = config.creds_folder + self.creds_folder = settings.creds_folder os.makedirs(self.creds_folder, mode=0o700, exist_ok=True) def __get_paths(self, account_id): @@ -19,9 +19,9 @@ def __get_paths(self, account_id): account_folder = os.path.join(self.creds_folder, account_id_encoded) os.makedirs(account_folder, mode=0o700, exist_ok=True) - access_token_file = os.path.join(account_folder, config.access_token_storage_id) + access_token_file = os.path.join(account_folder, settings.access_token_storage_id) refresh_token_file = os.path.join( - account_folder, config.refresh_token_storage_id + account_folder, settings.refresh_token_storage_id ) return access_token_file, refresh_token_file diff --git a/cli/medperf/account_management/token_storage/keyring_.py b/cli/medperf/account_management/token_storage/keyring_.py index 2aa007132..b302b5b10 100644 --- a/cli/medperf/account_management/token_storage/keyring_.py +++ b/cli/medperf/account_management/token_storage/keyring_.py @@ -2,7 +2,7 @@ users who connect to remote machines through passwordless SSH faced some issues.""" import keyring -from medperf import config +from medperf import settings class KeyringTokenStore: @@ -11,33 +11,33 @@ def __init__(self): def set_tokens(self, account_id, access_token, refresh_token): keyring.set_password( - config.access_token_storage_id, + settings.access_token_storage_id, account_id, access_token, ) keyring.set_password( - config.refresh_token_storage_id, + settings.refresh_token_storage_id, account_id, refresh_token, ) def read_tokens(self, account_id): access_token = keyring.get_password( - config.access_token_storage_id, + settings.access_token_storage_id, account_id, ) refresh_token = keyring.get_password( - config.refresh_token_storage_id, + settings.refresh_token_storage_id, account_id, ) return access_token, refresh_token def delete_tokens(self, account_id): keyring.delete_password( - config.access_token_storage_id, + settings.access_token_storage_id, account_id, ) keyring.delete_password( - config.refresh_token_storage_id, + settings.refresh_token_storage_id, account_id, ) diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index bdce039f3..f1bbc4fd3 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -4,7 +4,8 @@ import logging.handlers from medperf import __version__ -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.decorators import clean_except, add_inline_parameters import medperf.commands.result.result as result from medperf.commands.result.create import BenchmarkExecution @@ -94,10 +95,10 @@ def main( # Set inline parameters inline_args = ctx.params for param in inline_args: - setattr(config, param, inline_args[param]) + setattr(settings, param, inline_args[param]) # Update logging level according to the passed inline params - loglevel = config.loglevel.upper() + loglevel = settings.loglevel.upper() logging.getLogger().setLevel(loglevel) logging.getLogger("requests").setLevel(loglevel) diff --git a/cli/medperf/commands/association/approval.py b/cli/medperf/commands/association/approval.py index 4ed343911..14df5652d 100644 --- a/cli/medperf/commands/association/approval.py +++ b/cli/medperf/commands/association/approval.py @@ -1,4 +1,4 @@ -from medperf import config +from medperf.config_management import config from medperf.exceptions import InvalidArgumentError diff --git a/cli/medperf/commands/association/association.py b/cli/medperf/commands/association/association.py index fa69682ed..a0611d11d 100644 --- a/cli/medperf/commands/association/association.py +++ b/cli/medperf/commands/association/association.py @@ -1,7 +1,7 @@ import typer from typing import Optional -import medperf.config as config +from medperf.config_management import config from medperf.decorators import clean_except from medperf.commands.association.list import ListAssociations from medperf.commands.association.approval import Approval diff --git a/cli/medperf/commands/association/list.py b/cli/medperf/commands/association/list.py index e210fbc26..fc60ea6f8 100644 --- a/cli/medperf/commands/association/list.py +++ b/cli/medperf/commands/association/list.py @@ -1,6 +1,6 @@ from tabulate import tabulate -from medperf import config +from medperf.config_management import config class ListAssociations: diff --git a/cli/medperf/commands/association/priority.py b/cli/medperf/commands/association/priority.py index c58db2450..4d419dd18 100644 --- a/cli/medperf/commands/association/priority.py +++ b/cli/medperf/commands/association/priority.py @@ -1,4 +1,4 @@ -from medperf import config +from medperf.config_management import config from medperf.exceptions import InvalidArgumentError from medperf.entities.benchmark import Benchmark diff --git a/cli/medperf/commands/auth/auth.py b/cli/medperf/commands/auth/auth.py index 8908ba2c1..1a24ddf51 100644 --- a/cli/medperf/commands/auth/auth.py +++ b/cli/medperf/commands/auth/auth.py @@ -3,7 +3,7 @@ from medperf.commands.auth.logout import Logout from medperf.commands.auth.status import Status from medperf.decorators import clean_except -import medperf.config as config +from medperf.config_management import config import typer app = typer.Typer() diff --git a/cli/medperf/commands/auth/login.py b/cli/medperf/commands/auth/login.py index 6aac5fe5f..d39e384c5 100644 --- a/cli/medperf/commands/auth/login.py +++ b/cli/medperf/commands/auth/login.py @@ -1,4 +1,4 @@ -import medperf.config as config +from medperf.config_management import config from medperf.account_management import read_user_account from medperf.exceptions import InvalidArgumentError, MedperfException from email_validator import validate_email, EmailNotValidError diff --git a/cli/medperf/commands/auth/logout.py b/cli/medperf/commands/auth/logout.py index 5ca875bea..0a289bdba 100644 --- a/cli/medperf/commands/auth/logout.py +++ b/cli/medperf/commands/auth/logout.py @@ -1,4 +1,4 @@ -import medperf.config as config +from medperf.config_management import config class Logout: diff --git a/cli/medperf/commands/auth/status.py b/cli/medperf/commands/auth/status.py index af0cda0c1..459692ec5 100644 --- a/cli/medperf/commands/auth/status.py +++ b/cli/medperf/commands/auth/status.py @@ -1,4 +1,4 @@ -import medperf.config as config +from medperf.config_management import config from medperf.account_management import read_user_account diff --git a/cli/medperf/commands/auth/synapse_login.py b/cli/medperf/commands/auth/synapse_login.py index 0a0552772..645cf6e61 100644 --- a/cli/medperf/commands/auth/synapse_login.py +++ b/cli/medperf/commands/auth/synapse_login.py @@ -1,6 +1,6 @@ import synapseclient from synapseclient.core.exceptions import SynapseAuthenticationError -from medperf import config +from medperf.config_management import config from medperf.exceptions import CommunicationAuthenticationError diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index 35d719b0d..17f9f1f90 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -1,7 +1,7 @@ import typer from typing import Optional -import medperf.config as config +from medperf.config_management import config from medperf.decorators import clean_except from medperf.entities.benchmark import Benchmark from medperf.commands.list import EntityList diff --git a/cli/medperf/commands/benchmark/submit.py b/cli/medperf/commands/benchmark/submit.py index 05d1a0d10..b1d24bfa0 100644 --- a/cli/medperf/commands/benchmark/submit.py +++ b/cli/medperf/commands/benchmark/submit.py @@ -1,6 +1,7 @@ import os -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.entities.benchmark import Benchmark from medperf.exceptions import InvalidEntityError from medperf.utils import remove_path @@ -53,7 +54,7 @@ def __init__( self.no_cache = no_cache self.skip_data_preparation_step = skip_data_preparation_step self.bmk.metadata["demo_dataset_already_prepared"] = skip_data_preparation_step - config.tmp_paths.append(self.bmk.path) + settings.tmp_paths.append(self.bmk.path) def get_extra_information(self): """Retrieves information that must be populated automatically, diff --git a/cli/medperf/commands/compatibility_test/compatibility_test.py b/cli/medperf/commands/compatibility_test/compatibility_test.py index 0bd4a4695..a9732583b 100644 --- a/cli/medperf/commands/compatibility_test/compatibility_test.py +++ b/cli/medperf/commands/compatibility_test/compatibility_test.py @@ -1,7 +1,7 @@ import typer from typing import Optional -import medperf.config as config +from medperf.config_management import config from medperf.decorators import clean_except from medperf.commands.view import EntityView from medperf.entities.report import TestReport diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index c56a57d41..4759666f2 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -5,7 +5,8 @@ from medperf.comms.entity_resources import resources from medperf.entities.cube import Cube -import medperf.config as config +from medperf import settings +from medperf.config_management import config import os import yaml from pathlib import Path @@ -26,7 +27,7 @@ def download_demo_data(dset_url, dset_hash): # It is assumed that all demo datasets contain a file # which specifies the input of the data preparation step - paths_file = os.path.join(demo_dset_path, config.demo_dset_paths_file) + paths_file = os.path.join(demo_dset_path, settings.demo_dset_paths_file) with open(paths_file, "r") as f: paths = yaml.safe_load(f) @@ -41,14 +42,14 @@ def download_demo_data(dset_url, dset_hash): def prepare_local_cube(path): temp_uid = get_folders_hash([path]) - cubes_folder = config.cubes_folder + cubes_folder = settings.cubes_folder dst = os.path.join(cubes_folder, temp_uid) os.symlink(path, dst) logging.info(f"local cube will be linked to path: {dst}") - config.tmp_paths.append(dst) - cube_metadata_file = os.path.join(path, config.cube_metadata_filename) + settings.tmp_paths.append(dst) + cube_metadata_file = os.path.join(path, settings.cube_metadata_filename) if not os.path.exists(cube_metadata_file): - mlcube_yaml_path = os.path.join(path, config.cube_filename) + mlcube_yaml_path = os.path.join(path, settings.cube_filename) mlcube_yaml_hash = get_file_hash(mlcube_yaml_path) temp_metadata = { "id": None, @@ -63,7 +64,7 @@ def prepare_local_cube(path): metadata = Cube(**temp_metadata).todict() with open(cube_metadata_file, "w") as f: yaml.dump(metadata, f) - config.tmp_paths.append(cube_metadata_file) + settings.tmp_paths.append(cube_metadata_file) return temp_uid @@ -90,7 +91,7 @@ def prepare_cube(cube_uid: str): path = path.resolve() if os.path.exists(path): - mlcube_yaml_path = os.path.join(path, config.cube_filename) + mlcube_yaml_path = os.path.join(path, settings.cube_filename) if os.path.exists(mlcube_yaml_path): logging.info("local path provided. Creating symbolic link") temp_uid = prepare_local_cube(path) @@ -137,7 +138,7 @@ def create_test_dataset( data_creation.create_dataset_object() # TODO: existing dataset could make problems # make some changes since this is a test dataset - config.tmp_paths.remove(data_creation.dataset.path) + settings.tmp_paths.remove(data_creation.dataset.path) if skip_data_preparation_step: data_creation.make_dataset_prepared() dataset = data_creation.dataset diff --git a/cli/medperf/commands/dataset/associate.py b/cli/medperf/commands/dataset/associate.py index 84359fd1d..e3eb604de 100644 --- a/cli/medperf/commands/dataset/associate.py +++ b/cli/medperf/commands/dataset/associate.py @@ -1,4 +1,4 @@ -from medperf import config +from medperf.config_management import config from medperf.entities.dataset import Dataset from medperf.entities.benchmark import Benchmark from medperf.utils import dict_pretty_print, approval_prompt diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index fc18022ac..6d0229d15 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -1,7 +1,7 @@ import typer from typing import Optional -import medperf.config as config +from medperf.config_management import config from medperf.decorators import clean_except from medperf.entities.dataset import Dataset from medperf.commands.list import EntityList diff --git a/cli/medperf/commands/dataset/prepare.py b/cli/medperf/commands/dataset/prepare.py index 32ee6def8..e2b68bd84 100644 --- a/cli/medperf/commands/dataset/prepare.py +++ b/cli/medperf/commands/dataset/prepare.py @@ -2,7 +2,8 @@ import os import pandas as pd from medperf.entities.dataset import Dataset -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.entities.cube import Cube from medperf.utils import approval_prompt, dict_pretty_print from medperf.exceptions import ( @@ -40,7 +41,7 @@ def on_modified(self, event): # the latest report contents will be sent anyway, unless # one of those three finalizing actions were interrupted. # (Note that this slight chance is not blocking/buggy). - wait = config.wait_before_sending_reports + wait = settings.wait_before_sending_reports self.timer = Timer( wait, preparation.send_report, args=(report_metadata,) ) @@ -184,7 +185,7 @@ def run_prepare(self): with self.ui.interactive(): self.cube.run( task="prepare", - timeout=config.prepare_timeout, + timeout=settings.prepare_timeout, **prepare_params, ) except Exception as e: @@ -200,7 +201,7 @@ def run_prepare(self): report_sender.stop("finished") def run_sanity_check(self): - sanity_check_timeout = config.sanity_check_timeout + sanity_check_timeout = settings.sanity_check_timeout out_datapath = self.out_datapath out_labelspath = self.out_labelspath @@ -235,7 +236,7 @@ def run_sanity_check(self): self.ui.print("> Sanity checks complete") def run_statistics(self): - statistics_timeout = config.statistics_timeout + statistics_timeout = settings.statistics_timeout out_datapath = self.out_datapath out_labelspath = self.out_labelspath diff --git a/cli/medperf/commands/dataset/set_operational.py b/cli/medperf/commands/dataset/set_operational.py index 37758ddfe..0e86aec1f 100644 --- a/cli/medperf/commands/dataset/set_operational.py +++ b/cli/medperf/commands/dataset/set_operational.py @@ -1,5 +1,5 @@ from medperf.entities.dataset import Dataset -import medperf.config as config +from medperf.config_management import config from medperf.utils import approval_prompt, dict_pretty_print, get_folders_hash from medperf.exceptions import CleanExit, InvalidArgumentError import yaml @@ -74,7 +74,7 @@ def todict(self) -> dict: "state": self.dataset.state, } - def write(self) -> str: + def write(self) -> None: """Writes the registration into disk Args: filename (str, optional): name of the file. Defaults to config.reg_file. diff --git a/cli/medperf/commands/dataset/submit.py b/cli/medperf/commands/dataset/submit.py index a72a059d8..6e2617d18 100644 --- a/cli/medperf/commands/dataset/submit.py +++ b/cli/medperf/commands/dataset/submit.py @@ -2,7 +2,8 @@ from pathlib import Path import shutil from medperf.entities.dataset import Dataset -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.entities.cube import Cube from medperf.entities.benchmark import Benchmark from medperf.utils import ( @@ -129,7 +130,7 @@ def create_dataset_object(self): for_test=self.for_test, ) dataset.write() - config.tmp_paths.append(dataset.path) + settings.tmp_paths.append(dataset.path) dataset.set_raw_paths( raw_data_path=self.data_path, raw_labels_path=self.labels_path, diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py index 85416fe96..10d64824a 100644 --- a/cli/medperf/commands/execution.py +++ b/cli/medperf/commands/execution.py @@ -4,7 +4,8 @@ from medperf.entities.cube import Cube from medperf.entities.dataset import Dataset from medperf.utils import generate_tmp_path -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.exceptions import ExecutionError import yaml @@ -52,7 +53,7 @@ def __setup_logs_path(self): data_uid = self.dataset.local_id logs_path = os.path.join( - config.experiments_logs_folder, str(model_uid), str(data_uid) + settings.experiments_logs_folder, str(model_uid), str(data_uid) ) os.makedirs(logs_path, exist_ok=True) model_logs_path = os.path.join(logs_path, "model.log") @@ -63,7 +64,7 @@ def __setup_predictions_path(self): model_uid = self.model.local_id data_uid = self.dataset.local_id preds_path = os.path.join( - config.predictions_folder, str(model_uid), str(data_uid) + settings.predictions_folder, str(model_uid), str(data_uid) ) if os.path.exists(preds_path): msg = f"Found existing predictions for model {self.model.id} on dataset " @@ -74,7 +75,7 @@ def __setup_predictions_path(self): def run_inference(self): self.ui.text = "Running model inference on dataset" - infer_timeout = config.infer_timeout + infer_timeout = settings.infer_timeout preds_path = self.preds_path data_path = self.dataset.data_path try: @@ -97,7 +98,7 @@ def run_inference(self): def run_evaluation(self): self.ui.text = "Running model evaluation on dataset" - evaluate_timeout = config.evaluate_timeout + evaluate_timeout = settings.evaluate_timeout preds_path = self.preds_path labels_path = self.dataset.labels_path results_path = self.results_path diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index fafec870c..921946c94 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -3,7 +3,7 @@ from medperf.exceptions import InvalidArgumentError from tabulate import tabulate -from medperf import config +from medperf.config_management import config from medperf.account_management import get_medperf_user_data diff --git a/cli/medperf/commands/mlcube/associate.py b/cli/medperf/commands/mlcube/associate.py index 8307caade..7fd13b75b 100644 --- a/cli/medperf/commands/mlcube/associate.py +++ b/cli/medperf/commands/mlcube/associate.py @@ -1,4 +1,4 @@ -from medperf import config +from medperf.config_management import config from medperf.entities.cube import Cube from medperf.entities.benchmark import Benchmark from medperf.utils import dict_pretty_print, approval_prompt diff --git a/cli/medperf/commands/mlcube/create.py b/cli/medperf/commands/mlcube/create.py index 3b9c12759..4bd66d9d5 100644 --- a/cli/medperf/commands/mlcube/create.py +++ b/cli/medperf/commands/mlcube/create.py @@ -2,7 +2,7 @@ from pathlib import Path from cookiecutter.main import cookiecutter -from medperf import config +from medperf import settings from medperf.exceptions import InvalidArgumentError @@ -16,7 +16,7 @@ def run(cls, template_name: str, output_path: str = ".", config_file: str = None output_path (str, Optional): The desired path for the MLCube. Defaults to current path. config_file (str, Optional): Path to a JSON configuration file. If not passed, user is prompted. """ - template_dirs = config.templates + template_dirs = settings.templates if template_name not in template_dirs: templates = list(template_dirs.keys()) raise InvalidArgumentError( diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 9256f35f2..3e6c767a8 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -1,7 +1,8 @@ import typer from typing import Optional -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.decorators import clean_except from medperf.entities.cube import Cube from medperf.commands.list import EntityList @@ -35,7 +36,7 @@ def list( def create( template: str = typer.Argument( ..., - help=f"MLCube template name. Available templates: [{' | '.join(config.templates.keys())}]", + help=f"MLCube template name. Available templates: [{' | '.join(settings.templates.keys())}]", ), output_path: str = typer.Option( ".", "--output", "-o", help="Save the generated MLCube to the specified path" diff --git a/cli/medperf/commands/mlcube/submit.py b/cli/medperf/commands/mlcube/submit.py index 346aaf97a..ea64b078b 100644 --- a/cli/medperf/commands/mlcube/submit.py +++ b/cli/medperf/commands/mlcube/submit.py @@ -1,6 +1,7 @@ import os -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.entities.cube import Cube from medperf.utils import remove_path @@ -28,7 +29,7 @@ def __init__(self, submit_info: dict): self.comms = config.comms self.ui = config.ui self.cube = Cube(**submit_info) - config.tmp_paths.append(self.cube.path) + settings.tmp_paths.append(self.cube.path) def download(self): self.cube.download_config_files() diff --git a/cli/medperf/commands/profile.py b/cli/medperf/commands/profile.py index 0325ffa5c..1fd98bec8 100644 --- a/cli/medperf/commands/profile.py +++ b/cli/medperf/commands/profile.py @@ -1,9 +1,9 @@ import typer -from medperf import config +from medperf import settings from medperf.decorators import configurable, clean_except from medperf.utils import dict_pretty_print -from medperf.config_management import read_config, write_config +from medperf.config_management import config from medperf.exceptions import InvalidArgumentError app = typer.Typer() @@ -17,13 +17,13 @@ def activate(profile: str): Args: profile (str): Name of the profile to be used. """ - config_p = read_config() + config_p = config.read_config() if profile not in config_p: raise InvalidArgumentError("The provided profile does not exists") config_p.activate(profile) - write_config(config_p) + config_p.write_config() @app.command("create") @@ -36,13 +36,13 @@ def create( """Creates a new profile for managing and customizing configuration""" args = ctx.params args.pop("name") - config_p = read_config() + config_p = config.read_config() if name in config_p: raise InvalidArgumentError("A profile with the same name already exists") - config_p[name] = args - write_config(config_p) + config_p[name] = {**config_p.active_profile, **args} + config_p.write_config() @app.command("set") @@ -51,10 +51,10 @@ def create( def set_args(ctx: typer.Context): """Assign key-value configuration pairs to the current profile.""" args = ctx.params - config_p = read_config() + config_p = config.read_config() config_p.active_profile.update(args) - write_config(config_p) + config_p.write_config() @app.command("ls") @@ -62,7 +62,7 @@ def set_args(ctx: typer.Context): def list(): """Lists all available profiles""" ui = config.ui - config_p = read_config() + config_p = config.read_config() for profile in config_p: if config_p.is_profile_active(profile): ui.print_highlight("* " + profile) @@ -78,12 +78,12 @@ def view(profile: str = typer.Argument(None)): Args: profile (str, optional): Profile to display information from. Defaults to active profile. """ - config_p = read_config() + config_p = config.read_config() profile_config = config_p.active_profile if profile: profile_config = config_p[profile] - profile_config.pop(config.credentials_keyword, None) + profile_config.pop(settings.credentials_keyword, None) profile_name = profile if profile else config_p.active_profile_name config.ui.print(f"\nProfile '{profile_name}':") dict_pretty_print(profile_config, skip_none_values=False) @@ -97,14 +97,14 @@ def delete(profile: str): Args: profile (str): Profile to delete. """ - config_p = read_config() + config_p = config.read_config() if profile not in config_p.profiles: raise InvalidArgumentError("The provided profile does not exists") if profile in [ - config.default_profile_name, - config.testauth_profile_name, - config.test_profile_name, + settings.default_profile_name, + settings.testauth_profile_name, + settings.test_profile_name, ]: raise InvalidArgumentError("Cannot delete reserved profiles") @@ -112,4 +112,4 @@ def delete(profile: str): raise InvalidArgumentError("Cannot delete a currently activated profile") del config_p[profile] - write_config(config_p) + config_p.write_config() diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 26d52fa2e..c3d86e025 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -8,7 +8,7 @@ from medperf.entities.cube import Cube from medperf.entities.dataset import Dataset from medperf.entities.benchmark import Benchmark -import medperf.config as config +from medperf.config_management import config from medperf.exceptions import ( InvalidArgumentError, ExecutionError, diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 40b65c52e..02a1ba980 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -1,7 +1,7 @@ import typer from typing import Optional -import medperf.config as config +from medperf.config_management import config from medperf.decorators import clean_except from medperf.commands.view import EntityView from medperf.entities.result import Result diff --git a/cli/medperf/commands/result/submit.py b/cli/medperf/commands/result/submit.py index b69a596ce..8b23a7057 100644 --- a/cli/medperf/commands/result/submit.py +++ b/cli/medperf/commands/result/submit.py @@ -3,7 +3,7 @@ from medperf.exceptions import CleanExit from medperf.utils import remove_path, dict_pretty_print, approval_prompt from medperf.entities.result import Result -from medperf import config +from medperf.config_management import config class ResultSubmission: diff --git a/cli/medperf/commands/storage.py b/cli/medperf/commands/storage.py index a34afc936..bde082fc1 100644 --- a/cli/medperf/commands/storage.py +++ b/cli/medperf/commands/storage.py @@ -1,6 +1,7 @@ import typer -from medperf import config +from medperf import settings +from medperf.config_management import config from medperf.decorators import clean_except from medperf.utils import cleanup from medperf.storage.utils import move_storage @@ -15,8 +16,8 @@ def ls(): """Show the location of the current medperf assets""" headers = ["Asset", "Location"] info = [] - for folder in config.storage: - info.append((folder, config.storage[folder]["base"])) + for folder in settings.storage: + info.append((folder, settings.storage[folder]["base"])) tab = tabulate(info, headers=headers) config.ui.print(tab) @@ -39,5 +40,5 @@ def clean(): """Cleans up clutter paths""" # Force cleanup to be true - config.cleanup = True + settings.cleanup = True cleanup() diff --git a/cli/medperf/commands/view.py b/cli/medperf/commands/view.py index d19aedec0..675a9daa9 100644 --- a/cli/medperf/commands/view.py +++ b/cli/medperf/commands/view.py @@ -2,7 +2,7 @@ import json from typing import Union, Type -from medperf import config +from medperf.config_management import config from medperf.account_management import get_medperf_user_data from medperf.entities.interface import Entity from medperf.exceptions import InvalidArgumentError diff --git a/cli/medperf/comms/auth/__init__.py b/cli/medperf/comms/auth/__init__.py index 9a0a2ca8f..d9feaf1c7 100644 --- a/cli/medperf/comms/auth/__init__.py +++ b/cli/medperf/comms/auth/__init__.py @@ -1,2 +1,2 @@ -from .auth0 import Auth0 # noqa -from .local import Local # noqa +# from .auth0 import Auth0 # noqa +# from .local import Local # noqa diff --git a/cli/medperf/comms/auth/auth0.py b/cli/medperf/comms/auth/auth0.py index 60a052b72..edd0988ea 100644 --- a/cli/medperf/comms/auth/auth0.py +++ b/cli/medperf/comms/auth/auth0.py @@ -6,7 +6,7 @@ from medperf.comms.auth.token_verifier import verify_token from medperf.exceptions import CommunicationError, AuthenticationError import requests -import medperf.config as config +from medperf.config_management import config, Auth0Settings from medperf.utils import log_response_error from medperf.account_management import ( set_credentials, @@ -16,10 +16,8 @@ class Auth0(Auth): - def __init__(self): - self.domain = config.auth_domain - self.client_id = config.auth_client_id - self.audience = config.auth_audience + def __init__(self, auth_config: Auth0Settings): + self.settings = auth_config self._lock = threading.Lock() def login(self, email): @@ -71,11 +69,11 @@ def login(self, email): def __request_device_code(self): """Get a device code from the auth0 backend to be used for the authorization process""" - url = f"https://{self.domain}/oauth/device/code" + url = f"https://{self.settings.domain}/oauth/device/code" headers = {"content-type": "application/x-www-form-urlencoded"} body = { - "client_id": self.client_id, - "audience": self.audience, + "client_id": self.settings.client_id, + "audience": self.settings.audience, "scope": "offline_access openid email", } res = requests.post(url=url, headers=headers, data=body) @@ -99,12 +97,12 @@ def __get_device_access_token(self, device_code, polling_interval): json_res (dict): the response of the successful request, containg the access/refresh tokens pair token_issued_at (float): the timestamp when the access token was issued """ - url = f"https://{self.domain}/oauth/token" + url = f"https://{self.settings.domain}/oauth/token" headers = {"content-type": "application/x-www-form-urlencoded"} body = { "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, - "client_id": self.client_id, + "client_id": self.settings.client_id, } while True: @@ -139,10 +137,10 @@ def logout(self): creds = read_credentials() refresh_token = creds["refresh_token"] - url = f"https://{self.domain}/oauth/revoke" + url = f"https://{self.settings.domain}/oauth/revoke" headers = {"content-type": "application/json"} body = { - "client_id": self.client_id, + "client_id": self.settings.client_id, "token": refresh_token, } res = requests.post(url=url, headers=headers, json=body) @@ -163,7 +161,7 @@ def access_token(self): # multiple threads want to access the database. with self._lock: # TODO: This is temporary. Use a cleaner solution. - db = sqlite3.connect(config.tokens_db, isolation_level=None, timeout=60) + db = sqlite3.connect(self.settings.tokens_db, isolation_level=None, timeout=60) try: db.execute("BEGIN EXCLUSIVE TRANSACTION") except sqlite3.OperationalError: @@ -196,12 +194,12 @@ def _access_token(self): # token_issued_at and expires_in are for the access token sliding_expiration_time = ( - token_issued_at + token_expires_in - config.token_expiration_leeway + token_issued_at + token_expires_in - self.settings.token_expiration_leeway ) absolute_expiration_time = ( logged_in_at - + config.token_absolute_expiry - - config.refresh_token_expiration_leeway + + self.settings.token_absolute_expiry + - self.settings.refresh_token_expiration_leeway ) current_time = time.time() @@ -233,11 +231,11 @@ def __refresh_access_token(self, refresh_token): access_token (str): the new access token """ - url = f"https://{self.domain}/oauth/token" + url = f"https://{self.settings.domain}/oauth/token" headers = {"content-type": "application/x-www-form-urlencoded"} body = { "grant_type": "refresh_token", - "client_id": self.client_id, + "client_id": self.settings.client_id, "refresh_token": refresh_token, } token_issued_at = time.time() diff --git a/cli/medperf/comms/auth/local.py b/cli/medperf/comms/auth/local.py index a597d4ce3..db022355b 100644 --- a/cli/medperf/comms/auth/local.py +++ b/cli/medperf/comms/auth/local.py @@ -1,5 +1,4 @@ from medperf.comms.auth.interface import Auth -import medperf.config as config from medperf.exceptions import InvalidArgumentError from medperf.account_management import ( set_credentials, @@ -10,8 +9,8 @@ class Local(Auth): - def __init__(self): - with open(config.local_tokens_path) as f: + def __init__(self, local_tokens_path): + with open(local_tokens_path) as f: self.tokens = json.load(f) def login(self, email): diff --git a/cli/medperf/comms/auth/token_verifier.py b/cli/medperf/comms/auth/token_verifier.py index 79ae7b34e..6eed92068 100644 --- a/cli/medperf/comms/auth/token_verifier.py +++ b/cli/medperf/comms/auth/token_verifier.py @@ -4,7 +4,7 @@ library's signature verifier to use this new `JwksFetcher`""" from typing import Any -from medperf import config +from medperf import settings import os import json from auth0.authentication.token_verifier import ( @@ -17,7 +17,7 @@ class JwksFetcherWithDiskCache(JwksFetcher): def _init_cache(self, cache_ttl: int) -> None: super()._init_cache(cache_ttl) - jwks_file = config.auth_jwks_file + jwks_file = settings.auth_jwks_file if not os.path.exists(jwks_file): return with open(jwks_file) as f: @@ -28,7 +28,7 @@ def _init_cache(self, cache_ttl: int) -> None: def _cache_jwks(self, jwks: dict[str, Any]) -> None: super()._cache_jwks(jwks) data = {"cache_date": self._cache_date, "jwks": jwks} - jwks_file = config.auth_jwks_file + jwks_file = settings.auth_jwks_file with open(jwks_file, "w") as f: json.dump(data, f) @@ -46,11 +46,11 @@ def __init__( def verify_token(token): signature_verifier = AsymmetricSignatureVerifierWithDiskCache( - config.auth_jwks_url, cache_ttl=config.auth_jwks_cache_ttl + settings.auth_jwks_url, cache_ttl=settings.auth_jwks_cache_ttl ) token_verifier = TokenVerifier( signature_verifier=signature_verifier, - issuer=config.auth_idtoken_issuer, - audience=config.auth_client_id, + issuer=settings.auth_idtoken_issuer, + audience=settings.auth_client_id, ) return token_verifier.verify(token) diff --git a/cli/medperf/comms/entity_resources/resources.py b/cli/medperf/comms/entity_resources/resources.py index 09dc7c0b8..3f210f579 100644 --- a/cli/medperf/comms/entity_resources/resources.py +++ b/cli/medperf/comms/entity_resources/resources.py @@ -16,7 +16,7 @@ import os import logging import yaml -import medperf.config as config +from medperf import settings from medperf.utils import ( generate_tmp_path, get_cube_image_name, @@ -85,13 +85,13 @@ def _get_regular_file(url: str, output_path: str, expected_hash: str = None) -> def get_cube(url: str, cube_path: str, expected_hash: str = None): """Downloads and writes a cube mlcube.yaml file""" - output_path = os.path.join(cube_path, config.cube_filename) + output_path = os.path.join(cube_path, settings.cube_filename) return _get_regular_file(url, output_path, expected_hash) def get_cube_params(url: str, cube_path: str, expected_hash: str = None): """Downloads and writes a cube parameters.yaml file""" - output_path = os.path.join(cube_path, config.workspace_path, config.params_filename) + output_path = os.path.join(cube_path, settings.workspace_path, settings.params_filename) return _get_regular_file(url, output_path, expected_hash) @@ -109,7 +109,7 @@ def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str: image_cube_file: Location where the image file is stored locally. hash_value (str): The hash of the downloaded file """ - image_path = config.image_path + image_path = settings.image_path image_name = get_cube_image_name(cube_path) image_cube_path = os.path.join(cube_path, image_path) os.makedirs(image_cube_path, exist_ok=True) @@ -118,7 +118,7 @@ def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str: # Remove existing links os.unlink(image_cube_file) - imgs_storage = config.images_folder + imgs_storage = settings.images_folder if not hash_value: # No hash provided, we need to download the file first tmp_output_path = generate_tmp_path() @@ -153,8 +153,8 @@ def get_cube_additional( Returns: tarball_hash (str): The hash of the downloaded tarball file """ - additional_files_folder = os.path.join(cube_path, config.additional_path) - mlcube_cache_file = os.path.join(cube_path, config.mlcube_cache_file) + additional_files_folder = os.path.join(cube_path, settings.additional_path) + mlcube_cache_file = os.path.join(cube_path, settings.mlcube_cache_file) if not _should_get_cube_additional( additional_files_folder, expected_tarball_hash, mlcube_cache_file ): @@ -163,7 +163,7 @@ def get_cube_additional( # Download the additional files. Make sure files are extracted in tmp storage # to avoid any clutter objects if uncompression fails for some reason. tmp_output_folder = generate_tmp_path() - output_tarball_path = os.path.join(tmp_output_folder, config.tarball_filename) + output_tarball_path = os.path.join(tmp_output_folder, settings.tarball_filename) tarball_hash = download_resource(url, output_tarball_path, expected_tarball_hash) untar(output_tarball_path) @@ -200,7 +200,7 @@ def get_benchmark_demo_dataset(url: str, expected_hash: str = None) -> str: # the compatibility test command and remove the option of directly passing # demo datasets. This would look cleaner. # Possible cons: if multiple benchmarks use the same demo dataset. - demo_storage = config.demo_datasets_folder + demo_storage = settings.demo_datasets_folder if expected_hash: # If the folder exists, return demo_dataset_folder = os.path.join(demo_storage, expected_hash) @@ -210,7 +210,7 @@ def get_benchmark_demo_dataset(url: str, expected_hash: str = None) -> str: # make sure files are uncompressed while in tmp storage, to avoid any clutter # objects if uncompression fails for some reason. tmp_output_folder = generate_tmp_path() - output_tarball_path = os.path.join(tmp_output_folder, config.tarball_filename) + output_tarball_path = os.path.join(tmp_output_folder, settings.tarball_filename) hash_value = download_resource(url, output_tarball_path, expected_hash) untar(output_tarball_path) diff --git a/cli/medperf/comms/entity_resources/sources/direct.py b/cli/medperf/comms/entity_resources/sources/direct.py index 9dff83403..eed2a5e7c 100644 --- a/cli/medperf/comms/entity_resources/sources/direct.py +++ b/cli/medperf/comms/entity_resources/sources/direct.py @@ -1,6 +1,6 @@ import requests from medperf.exceptions import CommunicationRetrievalError -from medperf import config +from medperf import settings from medperf.utils import remove_path, log_response_error from .source import BaseSource import validators @@ -48,7 +48,7 @@ def __download_once(self, resource_identifier: str, output_path: str): raise CommunicationRetrievalError(msg) with open(output_path, "wb") as f: - for chunk in res.iter_content(chunk_size=config.ddl_stream_chunk_size): + for chunk in res.iter_content(chunk_size=settings.ddl_stream_chunk_size): # NOTE: if the response is chunk-encoded, this may not work # check whether this is common. f.write(chunk) @@ -59,7 +59,7 @@ def download(self, resource_identifier: str, output_path: str): link servers.""" attempt = 0 - while attempt < config.ddl_max_redownload_attempts: + while attempt < settings.ddl_max_redownload_attempts: try: self.__download_once(resource_identifier, output_path) return diff --git a/cli/medperf/comms/factory.py b/cli/medperf/comms/factory.py index e60e16cc0..14040ccde 100644 --- a/cli/medperf/comms/factory.py +++ b/cli/medperf/comms/factory.py @@ -1,14 +1,15 @@ -from .rest import REST +from typing import Union + from .interface import Comms from medperf.exceptions import InvalidArgumentError -class CommsFactory: - @staticmethod - def create_comms(name: str, host: str) -> Comms: - name = name.lower() - if name == "rest": - return REST(host) - else: - msg = "the indicated communication interface doesn't exist" - raise InvalidArgumentError(msg) +def create_comms(name: str, host: str, cert: Union[str, bool, None]) -> Comms: + from .rest import REST + + name = name.lower() + if name == "rest": + return REST(host, cert) + else: + msg = "the indicated communication interface doesn't exist" + raise InvalidArgumentError(msg) diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 5ac236f93..ae842729a 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -1,9 +1,10 @@ -from typing import List +from typing import List, Union import requests import logging from medperf.enums import Status -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.comms.interface import Comms from medperf.utils import ( sanitize_json, @@ -19,9 +20,9 @@ class REST(Comms): - def __init__(self, source: str): + def __init__(self, source: str, cert: Union[str, bool, None]): self.server_url = self.parse_url(source) - self.cert = config.certificate + self.cert = cert if self.cert is None: # No certificate provided, default to normal verification self.cert = True @@ -38,7 +39,7 @@ def parse_url(cls, url: str) -> str: str: parsed URL with protocol and version """ url_sections = url.split("://") - api_path = f"/api/v{config.major_version}" + api_path = f"/api/v{settings.major_version}" # Remove protocol if passed if len(url_sections) > 1: url = "".join(url_sections[1:]) @@ -78,7 +79,7 @@ def __get_list( self, url, num_elements=None, - page_size=config.default_page_size, + page_size=settings.default_page_size, offset=0, binary_reduction=False, ): @@ -91,7 +92,7 @@ def __get_list( Args: url (str): The url to retrieve elements from num_elements (int, optional): The desired number of elements to be retrieved. Defaults to None. - page_size (int, optional): Starting page size. Defaults to config.default_page_size. + page_size (int, optional): Starting page size. Defaults to settings.default_page_size. start_limit (int, optional): The starting position for element retrieval. Defaults to 0. binary_reduction (bool, optional): Wether to handle errors by halfing the page size. Defaults to False. diff --git a/cli/medperf/config_management/__init__.py b/cli/medperf/config_management/__init__.py index 380d7b47e..508ec1d26 100644 --- a/cli/medperf/config_management/__init__.py +++ b/cli/medperf/config_management/__init__.py @@ -1,36 +1,36 @@ -from .config_management import ConfigManager, read_config, write_config # noqa -from medperf import config +from .config_management import config, Auth0Settings # noqa +from medperf import settings import os def _init_config(): """builds the initial configuration file""" - default_profile = { - param: getattr(config, param) for param in config.configurable_parameters - } - config_p = ConfigManager() + default_profile = settings.default_profile.copy() + # default_profile["ui"] = settings.default_ui + + config_p = config # default profile - config_p[config.default_profile_name] = default_profile + config_p[settings.default_profile_name] = default_profile # testauth profile - config_p[config.testauth_profile_name] = { + config_p[settings.testauth_profile_name] = { **default_profile, - "server": config.local_server, - "certificate": config.local_certificate, - "auth_audience": config.auth_dev_audience, - "auth_domain": config.auth_dev_domain, - "auth_jwks_url": config.auth_dev_jwks_url, - "auth_idtoken_issuer": config.auth_dev_idtoken_issuer, - "auth_client_id": config.auth_dev_client_id, + "server": settings.local_server, + "certificate": settings.local_certificate, + "auth_audience": settings.auth_dev_audience, + "auth_domain": settings.auth_dev_domain, + "auth_jwks_url": settings.auth_dev_jwks_url, + "auth_idtoken_issuer": settings.auth_dev_idtoken_issuer, + "auth_client_id": settings.auth_dev_client_id, } # local profile - config_p[config.test_profile_name] = { + config_p[settings.test_profile_name] = { **default_profile, - "server": config.local_server, - "certificate": config.local_certificate, + "server": settings.local_server, + "certificate": settings.local_certificate, "auth_class": "Local", "auth_audience": "N/A", "auth_domain": "N/A", @@ -41,24 +41,24 @@ def _init_config(): # storage config_p.storage = { - folder: config.storage[folder]["base"] for folder in config.storage + folder: settings.storage[folder]["base"] for folder in settings.storage } - config_p.activate(config.default_profile_name) + config_p.activate(settings.default_profile_name) - os.makedirs(config.config_storage, exist_ok=True) - write_config(config_p) + os.makedirs(settings.config_storage, exist_ok=True) + config_p.write_config() def setup_config(): - if not os.path.exists(config.config_path): + if not os.path.exists(settings.config_path): _init_config() # Set current active profile parameters - config_p = read_config() - for param in config_p.active_profile: - setattr(config, param, config_p.active_profile[param]) + config.read_config() + # for param in config_p.active_profile: + # setattr(settings, param, config_p.active_profile[param]) # Set storage parameters - for folder in config_p.storage: - config.storage[folder]["base"] = config_p.storage[folder] + for folder in config.storage: + settings.storage[folder]["base"] = config.storage[folder] diff --git a/cli/medperf/config_management/config_management.py b/cli/medperf/config_management/config_management.py index 749a7ef67..da5fd0e05 100644 --- a/cli/medperf/config_management/config_management.py +++ b/cli/medperf/config_management/config_management.py @@ -1,5 +1,26 @@ +from typing import Optional, Dict, Any + import yaml -from medperf import config +from pydantic import BaseSettings +from medperf import settings +from medperf.comms.auth.interface import Auth +from medperf.comms.factory import create_comms +from medperf.comms.interface import Comms +from medperf.ui.factory import create_ui +from medperf.ui.interface import UI + + +class Auth0Settings(BaseSettings): + domain: str + jwks_url: str + idtoken_issuer: str + client_id: str + audience: str + jwks_cache_ttl: int + tokens_db: str + token_expiration_leeway: int + token_absolute_expiry: int + refresh_token_expiration_leeway: int class ConfigManager: @@ -8,23 +29,92 @@ def __init__(self): self.profiles = {} self.storage = {} + self._fields_to_override: Optional[Dict[str, Any]] = None + self._profile_to_override: Optional[str] = None + + self.ui: UI = None + self.auth: Auth = None + self.comms: Comms = None + + def keep_overridden_fields(self, profile_name: Optional[str] = None, **kwargs): + """User might override some fields temporarily through the CLI params. We'd like to + use these overridden fields every time config is read, but we don't want to save them + to the yaml file. This method allows us to keep these fields in memory, and apply them. + If profile name is given, updates should be applied to that profile only. + """ + self._fields_to_override = kwargs + self._profile_to_override = profile_name + + def _override_fields(self) -> None: + if (self._profile_to_override is not None + and self._profile_to_override != self.active_profile_name): + return + + if self._fields_to_override: + self.profiles[self.active_profile_name] = {**self.active_profile, **self._fields_to_override} + @property def active_profile(self): return self.profiles[self.active_profile_name] + def _recreate_ui(self): + ui_type = self.active_profile.get("ui") or settings.default_ui + self.ui = create_ui(ui_type) + + def _recreate_comms(self): + comms_type = self.active_profile.get("comms") or settings.default_comms + server = self.active_profile.get("server") or settings.server + if "certificate" in self.active_profile: + cert = self.active_profile.get("certificate") + else: + cert = settings.certificate + self.comms = create_comms(comms_type, server, cert) + + def _recreate_auth(self): + # Setup auth class + auth_class = self.active_profile.get("auth_class") or settings.default_auth_class + if auth_class == "Auth0": + from medperf.comms.auth.auth0 import Auth0 + auth_config = Auth0Settings( + domain=self.active_profile.get("auth_domain") or settings.auth_domain, + jwks_url=self.active_profile.get("auth_jwks_url") or settings.auth_jwks_url, + idtoken_issuer=self.active_profile.get("auth_idtoken_issuer") or settings.auth_idtoken_issuer, + client_id=self.active_profile.get("auth_client_id") or settings.auth_client_id, + audience=self.active_profile.get("auth_audience") or settings.auth_audience, + jwks_cache_ttl=settings.auth_jwks_cache_ttl, + tokens_db=settings.tokens_db, + token_expiration_leeway=settings.token_expiration_leeway, + token_absolute_expiry=settings.token_absolute_expiry, + refresh_token_expiration_leeway=settings.refresh_token_expiration_leeway, + ) + + self.auth = Auth0(auth_config=auth_config) + elif auth_class == "Local": + from medperf.comms.auth.local import Local + + self.auth = Local(local_tokens_path=settings.local_tokens_path) + else: + raise ValueError(f"Unknown Auth class {auth_class}") + def activate(self, profile_name): self.active_profile_name = profile_name + self._override_fields() + # Setup UI, COMMS + self._recreate_ui() + self._recreate_auth() + self._recreate_comms() def is_profile_active(self, profile_name): return self.active_profile_name == profile_name - def read(self, path): + def _read(self, path): with open(path) as f: data = yaml.safe_load(f) - self.active_profile_name = data["active_profile_name"] self.profiles = data["profiles"] self.storage = data["storage"] + self.activate(data["active_profile_name"]) + def write(self, path): data = { "active_profile_name": self.active_profile_name, @@ -46,14 +136,14 @@ def __delitem__(self, key): def __iter__(self): return iter(self.profiles) + def read_config(self): + config_path = settings.config_path + self._read(config_path) + return self -def read_config(): - config_p = ConfigManager() - config_path = config.config_path - config_p.read(config_path) - return config_p + def write_config(self): + config_path = settings.config_path + self.write(config_path) -def write_config(config_p: ConfigManager): - config_path = config.config_path - config_p.write(config_path) +config = ConfigManager() diff --git a/cli/medperf/decorators.py b/cli/medperf/decorators.py index b69d66e23..d43e93b5f 100644 --- a/cli/medperf/decorators.py +++ b/cli/medperf/decorators.py @@ -4,10 +4,12 @@ import functools from merge_args import merge_args from collections.abc import Callable + +from medperf.config_management import config from medperf.utils import pretty_error, cleanup from medperf.logging.utils import package_logs from medperf.exceptions import MedperfException, CleanExit -import medperf.config as config +from medperf import settings def clean_except(func: Callable) -> Callable: @@ -61,77 +63,79 @@ def configurable(func: Callable) -> Callable: def wrapper( *args, server: str = typer.Option( - config.server, "--server", help="URL of a hosted MedPerf API instance" + settings.server, "--server", help="URL of a hosted MedPerf API instance" ), + # TODO: auth_class is broken (param is written back to `settings.auth_class` + # should be stored to config profile instead auth_class: str = typer.Option( - config.auth_class, + settings.default_auth_class, "--auth_class", help="Authentication interface to use [Auth0]", ), auth_domain: str = typer.Option( - config.auth_domain, "--auth_domain", help="Auth0 domain name" + settings.auth_domain, "--auth_domain", help="Auth0 domain name" ), auth_jwks_url: str = typer.Option( - config.auth_jwks_url, "--auth_jwks_url", help="Auth0 Json Web Key set URL" + settings.auth_jwks_url, "--auth_jwks_url", help="Auth0 Json Web Key set URL" ), auth_idtoken_issuer: str = typer.Option( - config.auth_idtoken_issuer, + settings.auth_idtoken_issuer, "--auth_idtoken_issuer", help="Auth0 ID token issuer", ), auth_client_id: str = typer.Option( - config.auth_client_id, "--auth_client_id", help="Auth0 client ID" + settings.auth_client_id, "--auth_client_id", help="Auth0 client ID" ), auth_audience: str = typer.Option( - config.auth_audience, + settings.auth_audience, "--auth_audience", help="Server's Auth0 API identifier", ), certificate: str = typer.Option( - config.certificate, "--certificate", help="path to a valid SSL certificate" + settings.certificate, "--certificate", help="path to a valid SSL certificate" ), loglevel: str = typer.Option( - config.loglevel, + settings.loglevel, "--loglevel", help="Logging level [debug | info | warning | error]", ), prepare_timeout: int = typer.Option( - config.prepare_timeout, + settings.prepare_timeout, "--prepare_timeout", help="Maximum time in seconds before interrupting prepare task", ), sanity_check_timeout: int = typer.Option( - config.sanity_check_timeout, + settings.sanity_check_timeout, "--sanity_check_timeout", help="Maximum time in seconds before interrupting sanity_check task", ), statistics_timeout: int = typer.Option( - config.statistics_timeout, + settings.statistics_timeout, "--statistics_timeout", help="Maximum time in seconds before interrupting statistics task", ), infer_timeout: int = typer.Option( - config.infer_timeout, + settings.infer_timeout, "--infer_timeout", help="Maximum time in seconds before interrupting infer task", ), evaluate_timeout: int = typer.Option( - config.evaluate_timeout, + settings.evaluate_timeout, "--evaluate_timeout", help="Maximum time in seconds before interrupting evaluate task", ), container_loglevel: str = typer.Option( - config.container_loglevel, + settings.container_loglevel, "--container-loglevel", help="Logging level for containers to be run [debug | info | warning | error]", ), platform: str = typer.Option( - config.platform, + settings.platform, "--platform", help="Platform to use for MLCube. [docker | singularity]", ), gpus: str = typer.Option( - config.gpus, + settings.gpus, "--gpus", help=""" What GPUs to expose to MLCube. @@ -140,7 +144,7 @@ def wrapper( Defaults to all available GPUs""", ), cleanup: bool = typer.Option( - config.cleanup, + settings.cleanup, "--cleanup/--no-cleanup", help="Wether to clean up temporary medperf storage after execution", ), @@ -162,52 +166,52 @@ def add_inline_parameters(func: Callable) -> Callable: """ # NOTE: changing parameters here should be accompanied - # by changing config.inline_parameters + # by changing settings.inline_parameters @merge_args(func) def wrapper( *args, loglevel: str = typer.Option( - config.loglevel, + settings.loglevel, "--loglevel", help="Logging level [debug | info | warning | error]", ), prepare_timeout: int = typer.Option( - config.prepare_timeout, + settings.prepare_timeout, "--prepare_timeout", help="Maximum time in seconds before interrupting prepare task", ), sanity_check_timeout: int = typer.Option( - config.sanity_check_timeout, + settings.sanity_check_timeout, "--sanity_check_timeout", help="Maximum time in seconds before interrupting sanity_check task", ), statistics_timeout: int = typer.Option( - config.statistics_timeout, + settings.statistics_timeout, "--statistics_timeout", help="Maximum time in seconds before interrupting statistics task", ), infer_timeout: int = typer.Option( - config.infer_timeout, + settings.infer_timeout, "--infer_timeout", help="Maximum time in seconds before interrupting infer task", ), evaluate_timeout: int = typer.Option( - config.evaluate_timeout, + settings.evaluate_timeout, "--evaluate_timeout", help="Maximum time in seconds before interrupting evaluate task", ), container_loglevel: str = typer.Option( - config.container_loglevel, + settings.container_loglevel, "--container-loglevel", help="Logging level for containers to be run [debug | info | warning | error]", ), platform: str = typer.Option( - config.platform, + settings.platform, "--platform", help="Platform to use for MLCube. [docker | singularity]", ), gpus: str = typer.Option( - config.gpus, + settings.gpus, "--gpus", help=""" What GPUs to expose to MLCube. @@ -220,7 +224,7 @@ def wrapper( (e.g., --gpus="device=0,2")\n""", ), cleanup: bool = typer.Option( - config.cleanup, + settings.cleanup, "--cleanup/--no-cleanup", help="Whether to clean up temporary medperf storage after execution", ), diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 841fc684d..69ff7f66c 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -1,7 +1,8 @@ from typing import List, Optional from pydantic import HttpUrl, Field -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.entities.association import Association from medperf.entities.interface import Entity from medperf.entities.schemas import ApprovableSchema, DeployableSchema @@ -37,7 +38,7 @@ def get_type(): @staticmethod def get_storage_path(): - return config.benchmarks_folder + return settings.benchmarks_folder @staticmethod def get_comms_retriever(): @@ -45,7 +46,7 @@ def get_comms_retriever(): @staticmethod def get_metadata_filename(): - return config.benchmarks_filename + return settings.benchmarks_filename @staticmethod def get_comms_uploader(): diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 62a3e84f6..569671992 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -16,7 +16,8 @@ from medperf.entities.interface import Entity from medperf.entities.schemas import DeployableSchema from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.comms.entity_resources import resources from medperf.account_management import get_medperf_user_data @@ -49,7 +50,7 @@ def get_type(): @staticmethod def get_storage_path(): - return config.cubes_folder + return settings.cubes_folder @staticmethod def get_comms_retriever(): @@ -57,7 +58,7 @@ def get_comms_retriever(): @staticmethod def get_metadata_filename(): - return config.cube_metadata_filename + return settings.cube_metadata_filename @staticmethod def get_comms_uploader(): @@ -71,10 +72,10 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.cube_path = os.path.join(self.path, config.cube_filename) + self.cube_path = os.path.join(self.path, settings.cube_filename) self.params_path = None if self.git_parameters_url: - self.params_path = os.path.join(self.path, config.params_filename) + self.params_path = os.path.join(self.path, settings.params_filename) @property def local_id(self): @@ -146,15 +147,15 @@ def download_image(self): _, local_hash = resources.get_cube_image(url, self.path, tarball_hash) self.image_tarball_hash = local_hash else: - if config.platform == "docker": + if settings.platform == "docker": # For docker, image should be pulled before calculating its hash self._get_image_from_registry() self._set_image_hash_from_registry() - elif config.platform == "singularity": + elif settings.platform == "singularity": # For singularity, we need the hash first before trying to convert self._set_image_hash_from_registry() - image_folder: str = os.path.join(config.cubes_folder, config.image_path) + image_folder = os.path.join(settings.cubes_folder, settings.image_path) if os.path.exists(image_folder): for file in os.listdir(image_folder): if file == self._converted_singularity_image_name: @@ -174,10 +175,10 @@ def _set_image_hash_from_registry(self): # Retrieve image hash from MLCube logging.debug(f"Retrieving {self.id} image hash") tmp_out_yaml = generate_tmp_path() - cmd = f"mlcube --log-level {config.loglevel} inspect --mlcube={self.cube_path} --format=yaml" - cmd += f" --platform={config.platform} --output-file {tmp_out_yaml}" + cmd = f"mlcube --log-level {settings.loglevel} inspect --mlcube={self.cube_path} --format=yaml" + cmd += f" --platform={settings.platform} --output-file {tmp_out_yaml}" logging.info(f"Running MLCube command: {cmd}") - with spawn_and_kill(cmd, timeout=config.mlcube_inspect_timeout) as proc_wrapper: + with spawn_and_kill(cmd, timeout=settings.mlcube_inspect_timeout) as proc_wrapper: proc = proc_wrapper.proc combine_proc_sp_text(proc) if proc.exitstatus != 0: @@ -195,12 +196,12 @@ def _set_image_hash_from_registry(self): def _get_image_from_registry(self): # Retrieve image from image registry logging.debug(f"Retrieving {self.id} image") - cmd = f"mlcube --log-level {config.loglevel} configure --mlcube={self.cube_path} --platform={config.platform}" - if config.platform == "singularity": + cmd = f"mlcube --log-level {settings.loglevel} configure --mlcube={self.cube_path} --platform={settings.platform}" + if settings.platform == "singularity": cmd += f" -Psingularity.image={self._converted_singularity_image_name}" logging.info(f"Running MLCube command: {cmd}") with spawn_and_kill( - cmd, timeout=config.mlcube_configure_timeout + cmd, timeout=settings.mlcube_configure_timeout ) as proc_wrapper: proc = proc_wrapper.proc combine_proc_sp_text(proc) @@ -249,21 +250,21 @@ def run( kwargs (dict): additional arguments that are passed directly to the mlcube command """ kwargs.update(string_params) - cmd = f"mlcube --log-level {config.loglevel} run" - cmd += f' --mlcube="{self.cube_path}" --task={task} --platform={config.platform} --network=none' - if config.gpus is not None: - cmd += f" --gpus={config.gpus}" + cmd = f"mlcube --log-level {settings.loglevel} run" + cmd += f' --mlcube="{self.cube_path}" --task={task} --platform={settings.platform} --network=none' + if settings.gpus is not None: + cmd += f" --gpus={settings.gpus}" if read_protected_input: cmd += " --mount=ro" for k, v in kwargs.items(): cmd_arg = f'{k}="{v}"' cmd = " ".join([cmd, cmd_arg]) - container_loglevel = config.container_loglevel + container_loglevel = settings.container_loglevel # TODO: we should override run args instead of what we are doing below # we shouldn't allow arbitrary run args unless our client allows it - if config.platform == "docker": + if settings.platform == "docker": # use current user cpu_args = self.get_config("docker.cpu_args") or "" gpu_args = self.get_config("docker.gpu_args") or "" @@ -274,7 +275,7 @@ def run( if container_loglevel: cmd += f' -Pdocker.env_args="-e MEDPERF_LOGLEVEL={container_loglevel.upper()}"' - elif config.platform == "singularity": + elif settings.platform == "singularity": # use -e to discard host env vars, -C to isolate the container (see singularity run --help) run_args = self.get_config("singularity.run_args") or "" run_args = " ".join([run_args, "-eC"]).strip() diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index cb4a7e6ef..4c719affc 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -8,7 +8,8 @@ from medperf.entities.interface import Entity from medperf.entities.schemas import DeployableSchema -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.account_management import get_medperf_user_data @@ -39,7 +40,7 @@ def get_type(): @staticmethod def get_storage_path(): - return config.datasets_folder + return settings.datasets_folder @staticmethod def get_comms_retriever(): @@ -47,7 +48,7 @@ def get_comms_retriever(): @staticmethod def get_metadata_filename(): - return config.reg_file + return settings.reg_file @staticmethod def get_comms_uploader(): @@ -65,37 +66,37 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") - self.report_path = os.path.join(self.path, config.report_file) - self.metadata_path = os.path.join(self.path, config.metadata_folder) - self.statistics_path = os.path.join(self.path, config.statistics_filename) + self.report_path = os.path.join(self.path, settings.report_file) + self.metadata_path = os.path.join(self.path, settings.metadata_folder) + self.statistics_path = os.path.join(self.path, settings.statistics_filename) @property def local_id(self): return self.generated_uid def set_raw_paths(self, raw_data_path: str, raw_labels_path: str): - raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file) + raw_paths_file = os.path.join(self.path, settings.dataset_raw_paths_file) data = {"data_path": raw_data_path, "labels_path": raw_labels_path} with open(raw_paths_file, "w") as f: yaml.dump(data, f) def get_raw_paths(self): - raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file) + raw_paths_file = os.path.join(self.path, settings.dataset_raw_paths_file) with open(raw_paths_file) as f: data = yaml.safe_load(f) return data["data_path"], data["labels_path"] def mark_as_ready(self): - flag_file = os.path.join(self.path, config.ready_flag_file) + flag_file = os.path.join(self.path, settings.ready_flag_file) with open(flag_file, "w"): pass def unmark_as_ready(self): - flag_file = os.path.join(self.path, config.ready_flag_file) + flag_file = os.path.join(self.path, settings.ready_flag_file) remove_path(flag_file) def is_ready(self): - flag_file = os.path.join(self.path, config.ready_flag_file) + flag_file = os.path.join(self.path, settings.ready_flag_file) return os.path.exists(flag_file) @staticmethod diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index 6f962d5d7..5e837815a 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -1,7 +1,7 @@ import hashlib from typing import List, Union, Optional -import medperf.config as config +from medperf import settings from medperf.entities.interface import Entity @@ -44,11 +44,11 @@ def get_type(): @staticmethod def get_storage_path(): - return config.tests_folder + return settings.tests_folder @staticmethod def get_metadata_filename(): - return config.test_report_file + return settings.test_report_file def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index 0e96d1feb..4f7ff61c6 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -1,6 +1,7 @@ from medperf.entities.interface import Entity from medperf.entities.schemas import ApprovableSchema -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.account_management import get_medperf_user_data @@ -28,7 +29,7 @@ def get_type(): @staticmethod def get_storage_path(): - return config.results_folder + return settings.results_folder @staticmethod def get_comms_retriever(): @@ -36,7 +37,7 @@ def get_comms_retriever(): @staticmethod def get_metadata_filename(): - return config.results_info_file + return settings.results_info_file @staticmethod def get_comms_uploader(): diff --git a/cli/medperf/init.py b/cli/medperf/init.py index 3dfe57638..9986c75cd 100644 --- a/cli/medperf/init.py +++ b/cli/medperf/init.py @@ -1,7 +1,6 @@ import os -from medperf import config -from medperf.comms.factory import CommsFactory +from medperf import settings from medperf.config_management import setup_config from medperf.logging import setup_logging from medperf.storage import ( @@ -9,7 +8,6 @@ init_storage, override_storage_config_paths, ) -from medperf.ui.factory import UIFactory def initialize(): @@ -24,21 +22,5 @@ def initialize(): init_storage() # Setup logging - log_file = os.path.join(config.logs_storage, config.log_file) - setup_logging(log_file, config.loglevel) - - # Setup UI, COMMS - config.ui = UIFactory.create_ui(config.ui) - config.comms = CommsFactory.create_comms(config.comms, config.server) - - # Setup auth class - if config.auth_class == "Auth0": - from .comms.auth import Auth0 - - config.auth = Auth0() - elif config.auth_class == "Local": - from .comms.auth import Local - - config.auth = Local() - else: - raise ValueError(f"Unknown Auth class {config.auth_class}") + log_file = os.path.join(settings.logs_storage, settings.log_file) + setup_logging(log_file, settings.loglevel) diff --git a/cli/medperf/logging/__init__.py b/cli/medperf/logging/__init__.py index 6f13bc315..77c833e68 100644 --- a/cli/medperf/logging/__init__.py +++ b/cli/medperf/logging/__init__.py @@ -4,7 +4,7 @@ from logging import handlers from .filters.redacting_filter import RedactingFilter from .formatters.newline_formatter import NewLineFormatter -from medperf import config +from medperf import settings def setup_logging(log_file: str, loglevel: str): @@ -13,7 +13,7 @@ def setup_logging(log_file: str, loglevel: str): os.makedirs(log_folder, exist_ok=True) log_fmt = "%(asctime)s | %(module)s.%(funcName)s | %(levelname)s: %(message)s" - handler = handlers.RotatingFileHandler(log_file, backupCount=config.logs_backup_count) + handler = handlers.RotatingFileHandler(log_file, backupCount=settings.logs_backup_count) handler.setFormatter(NewLineFormatter(log_fmt)) logging.basicConfig( level=loglevel.upper(), diff --git a/cli/medperf/logging/utils.py b/cli/medperf/logging/utils.py index 62f2f3c1c..298c6c530 100644 --- a/cli/medperf/logging/utils.py +++ b/cli/medperf/logging/utils.py @@ -12,7 +12,7 @@ import psutil import yaml -from medperf import config +from medperf import settings def get_system_information(): @@ -50,7 +50,7 @@ def get_disk_usage(): # We are intereseted in home storage, and storage where medperf assets # are saved. We currently assume all of them are together always, including # the datasets folder. - paths_of_interest = [os.environ["HOME"], str(config.datasets_folder)] + paths_of_interest = [os.environ["HOME"], str(settings.datasets_folder)] disk_usage_dict = {} for poi in paths_of_interest: try: @@ -71,9 +71,9 @@ def get_disk_usage(): def get_configuration_variables(): try: - config_vars = vars(config) + config_vars = vars(settings) config_dict = {} - for item in dir(config): + for item in dir(settings): if item.startswith("__"): continue config_dict[item] = config_vars[item] @@ -104,10 +104,10 @@ def filter_var_dict_for_yaml(unfiltered_dict): def get_storage_contents(): try: - storage_paths = config.storage.copy() + storage_paths = settings.storage.copy() storage_paths["credentials_folder"] = { - "base": os.path.dirname(config.creds_folder), - "name": os.path.basename(config.creds_folder), + "base": os.path.dirname(settings.creds_folder), + "name": os.path.basename(settings.creds_folder), } ignore_paths = {"datasets_folder", "predictions_folder", "results_folder"} contents = {} @@ -197,11 +197,11 @@ def log_machine_details(): def package_logs(): # Handle cases where the folder doesn't exist - if not os.path.exists(config.logs_storage): + if not os.path.exists(settings.logs_storage): return # Don't create a tarball if there's no logs to be packaged - files = os.listdir(config.logs_storage) + files = os.listdir(settings.logs_storage) if len(files) == 0: return @@ -211,9 +211,9 @@ def package_logs(): if is_logfile: logfiles.append(file) - package_file = os.path.join(config.logs_storage, config.log_package_file) + package_file = os.path.join(settings.logs_storage, settings.log_package_file) with tarfile.open(package_file, "w:gz") as tar: for file in logfiles: - filepath = os.path.join(config.logs_storage, file) + filepath = os.path.join(settings.logs_storage, file) tar.add(filepath, arcname=os.path.basename(filepath)) diff --git a/cli/medperf/config.py b/cli/medperf/settings.py similarity index 88% rename from cli/medperf/config.py rename to cli/medperf/settings.py index 2c5b520be..350fe56d0 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/settings.py @@ -13,11 +13,10 @@ local_server = "https://localhost:8000" local_certificate = str(BASE_DIR / "server" / "cert.crt") -comms = "REST" +default_comms = "REST" # Auth config -auth = None # This will be overwritten by the globally initialized auth class object -auth_class = "Auth0" +default_auth_class = "Auth0" auth_domain = "auth.medperf.org" auth_dev_domain = "dev-5xl8y6uuc2hig2ly.us.auth0.com" @@ -176,35 +175,35 @@ loglevel = "debug" logs_backup_count = 100 cleanup = True -ui = "CLI" default_profile_name = "default" testauth_profile_name = "testauth" test_profile_name = "local" credentials_keyword = "credentials" -inline_parameters = [ - "loglevel", - "prepare_timeout", - "sanity_check_timeout", - "statistics_timeout", - "infer_timeout", - "evaluate_timeout", - "platform", - "gpus", - "cleanup", - "container_loglevel", -] -configurable_parameters = inline_parameters + [ - "server", - "certificate", - "auth_class", - "auth_domain", - "auth_jwks_url", - "auth_idtoken_issuer", - "auth_client_id", - "auth_audience", -] +default_profile = { + "loglevel": loglevel, + "prepare_timeout": prepare_timeout, + "sanity_check_timeout": sanity_check_timeout, + "statistics_timeout": statistics_timeout, + "infer_timeout": infer_timeout, + "evaluate_timeout": evaluate_timeout, + "platform": platform, + "gpus": gpus, + "cleanup": cleanup, + "container_loglevel": container_loglevel, + "server": server, + "certificate": certificate, + "auth_class": default_auth_class, + "auth_domain": auth_domain, + "auth_jwks_url": auth_jwks_url, + "auth_idtoken_issuer": auth_idtoken_issuer, + "auth_client_id": auth_client_id, + "auth_audience": auth_audience, +} + + +default_ui = "CLI" templates = { "data_preparator": "templates/data_preparator_mlcube", diff --git a/cli/medperf/storage/__init__.py b/cli/medperf/storage/__init__.py index 74e4e7962..f52bfba3b 100644 --- a/cli/medperf/storage/__init__.py +++ b/cli/medperf/storage/__init__.py @@ -2,51 +2,51 @@ import shutil import time -from medperf import config -from medperf.config_management import read_config, write_config +from medperf import settings +from medperf.config_management import config from .utils import full_folder_path def override_storage_config_paths(): - for folder in config.storage: - setattr(config, folder, full_folder_path(folder)) + for folder in settings.storage: + setattr(settings, folder, full_folder_path(folder)) def init_storage(): """Builds the general medperf folder structure.""" - for folder in config.storage: - folder = getattr(config, folder) + for folder in settings.storage: + folder = getattr(settings, folder) os.makedirs(folder, exist_ok=True) def apply_configuration_migrations(): - if not os.path.exists(config.config_path): + if not os.path.exists(settings.config_path): return - config_p = read_config() + config_p = config.read_config() # Migration for moving the logs folder to a new location if "logs_folder" not in config_p.storage: return src_dir = os.path.join(config_p.storage["logs_folder"], "logs") - tgt_dir = config.logs_storage + tgt_dir = settings.logs_storage shutil.move(src_dir, tgt_dir) del config_p.storage["logs_folder"] # Migration for tracking the login timestamp (i.e., refresh token issuance timestamp) - if config.credentials_keyword in config_p.active_profile: + if settings.credentials_keyword in config_p.active_profile: # So the user is logged in - if "logged_in_at" not in config_p.active_profile[config.credentials_keyword]: + if "logged_in_at" not in config_p.active_profile[settings.credentials_keyword]: # Apply migration. We will set it to the current time, since this # will make sure they will not be logged out before the actual refresh # token expiration (for a better user experience). However, currently logged # in users will still face a confusing error when the refresh token expires. - config_p.active_profile[config.credentials_keyword][ + config_p.active_profile[settings.credentials_keyword][ "logged_in_at" ] = time.time() - write_config(config_p) + config_p.write_config() diff --git a/cli/medperf/storage/utils.py b/cli/medperf/storage/utils.py index 4d1a39db7..d12321f9b 100644 --- a/cli/medperf/storage/utils.py +++ b/cli/medperf/storage/utils.py @@ -1,22 +1,22 @@ -from medperf.config_management import read_config, write_config +from medperf.config_management import config from medperf.exceptions import InvalidArgumentError -from medperf import config +from medperf import settings import os import re import shutil def full_folder_path(folder, new_base=None): - server_path = config.server.split("//")[1] + server_path = settings.server.split("//")[1] server_path = re.sub(r"[.:]", "_", server_path) - if folder in config.root_folders: - info = config.storage[folder] + if folder in settings.root_folders: + info = settings.storage[folder] base = new_base or info["base"] full_path = os.path.join(base, info["name"]) - elif folder in config.server_folders: - info = config.storage[folder] + elif folder in settings.server_folders: + info = settings.storage[folder] base = new_base or info["base"] full_path = os.path.join(base, info["name"], server_path) @@ -24,7 +24,7 @@ def full_folder_path(folder, new_base=None): def move_storage(target_base_path: str): - config_p = read_config() + config_p = config.read_config() target_base_path = os.path.abspath(target_base_path) target_base_path = os.path.normpath(target_base_path) @@ -38,11 +38,11 @@ def move_storage(target_base_path: str): else: os.makedirs(target_base_path, 0o700) - for folder in config.storage: + for folder in settings.storage: folder_path = os.path.join( - config.storage[folder]["base"], config.storage[folder]["name"] + settings.storage[folder]["base"], settings.storage[folder]["name"] ) - target_path = os.path.join(target_base_path, config.storage[folder]["name"]) + target_path = os.path.join(target_base_path, settings.storage[folder]["name"]) folder_path = os.path.normpath(folder_path) target_path = os.path.normpath(target_path) @@ -58,4 +58,4 @@ def move_storage(target_base_path: str): shutil.move(folder_path, target_path) config_p.storage[folder] = target_base_path - write_config(config_p) + config_p.write_config() diff --git a/cli/medperf/tests/commands/benchmark/test_submit.py b/cli/medperf/tests/commands/benchmark/test_submit.py index 7e2d5b23b..90de75e70 100644 --- a/cli/medperf/tests/commands/benchmark/test_submit.py +++ b/cli/medperf/tests/commands/benchmark/test_submit.py @@ -3,7 +3,7 @@ from medperf.entities.result import Result from medperf.commands.benchmark.submit import SubmitBenchmark -from medperf import config +from medperf import settings PATCH_BENCHMARK = "medperf.commands.benchmark.submit.{}" NAME_MAX_LEN = 20 @@ -34,7 +34,7 @@ def test_submit_prepares_tmp_path_for_cleanup(): submission = SubmitBenchmark(BENCHMARK_INFO) # Assert - assert submission.bmk.path in config.tmp_paths + assert submission.bmk.path in settings.tmp_paths def test_submit_uploads_benchmark_data(mocker, result, comms, ui): diff --git a/cli/medperf/tests/commands/compatibility_test/test_utils.py b/cli/medperf/tests/commands/compatibility_test/test_utils.py index e1fc2b317..97a5768e7 100644 --- a/cli/medperf/tests/commands/compatibility_test/test_utils.py +++ b/cli/medperf/tests/commands/compatibility_test/test_utils.py @@ -3,7 +3,7 @@ import medperf.commands.compatibility_test.utils as utils import os -import medperf.config as config +from medperf import settings PATCH_UTILS = "medperf.commands.compatibility_test.utils.{}" @@ -13,7 +13,7 @@ class TestPrepareCube: @pytest.fixture(autouse=True) def setup(self, fs): cube_path = "/path/to/cube" - cube_path_config = os.path.join(cube_path, config.cube_filename) + cube_path_config = os.path.join(cube_path, settings.cube_filename) fs.create_file(cube_path_config, contents="cube mlcube.yaml contents") self.cube_path = cube_path @@ -33,7 +33,7 @@ def test_local_cube_symlink_is_created_properly(self, path_attr): new_uid = utils.prepare_cube(getattr(self, path_attr)) # Assert - cube_path = os.path.join(config.cubes_folder, new_uid) + cube_path = os.path.join(settings.cubes_folder, new_uid) assert os.path.islink(cube_path) assert os.path.realpath(cube_path) == os.path.realpath(self.cube_path) @@ -43,16 +43,16 @@ def test_local_cube_metadata_is_created(self): # Assert metadata_file = os.path.join( - config.cubes_folder, + settings.cubes_folder, new_uid, - config.cube_metadata_filename, + settings.cube_metadata_filename, ) assert os.path.exists(metadata_file) def test_local_cube_metadata_is_not_created_if_found(self, fs): # Arrange - metadata_file = os.path.join(self.cube_path, config.cube_metadata_filename) + metadata_file = os.path.join(self.cube_path, settings.cube_metadata_filename) metadata_contents = "meta contents before execution" @@ -63,9 +63,9 @@ def test_local_cube_metadata_is_not_created_if_found(self, fs): # Assert metadata_file = os.path.join( - config.cubes_folder, + settings.cubes_folder, new_uid, - config.cube_metadata_filename, + settings.cube_metadata_filename, ) assert open(metadata_file).read() == metadata_contents @@ -77,10 +77,10 @@ def test_exception_is_raised_for_nonexisting_path(self): def test_cleanup_is_set_up_correctly(self): # Act uid = utils.prepare_cube(self.cube_path) - symlinked_path = os.path.join(config.cubes_folder, uid) + symlinked_path = os.path.join(settings.cubes_folder, uid) metadata_file = os.path.join( self.cube_path, - config.cube_metadata_filename, + settings.cube_metadata_filename, ) # Assert - assert set([symlinked_path, metadata_file]).issubset(config.tmp_paths) + assert set([symlinked_path, metadata_file]).issubset(settings.tmp_paths) diff --git a/cli/medperf/tests/commands/mlcube/test_create.py b/cli/medperf/tests/commands/mlcube/test_create.py index da52d72ab..23240bae8 100644 --- a/cli/medperf/tests/commands/mlcube/test_create.py +++ b/cli/medperf/tests/commands/mlcube/test_create.py @@ -1,6 +1,6 @@ import pytest -from medperf import config +from medperf import settings from medperf.commands.mlcube.create import CreateCube from medperf.exceptions import InvalidArgumentError @@ -14,7 +14,7 @@ def setup(mocker): class TestTemplate: - @pytest.mark.parametrize("template,dir", list(config.templates.items())) + @pytest.mark.parametrize("template,dir", list(settings.templates.items())) def test_valid_template_is_used(mocker, setup, template, dir): # Arrange spy = setup @@ -39,7 +39,7 @@ def test_current_path_is_used_by_default(mocker, setup): # Arrange path = "." spy = setup - template = list(config.templates.keys())[0] + template = list(settings.templates.keys())[0] # Act CreateCube.run(template) @@ -53,7 +53,7 @@ def test_current_path_is_used_by_default(mocker, setup): def test_output_path_is_used_for_template_creation(mocker, setup, output_path): # Arrange spy = setup - template = list(config.templates.keys())[0] + template = list(settings.templates.keys())[0] # Act CreateCube.run(template, output_path=output_path) @@ -68,7 +68,7 @@ class TestConfigFile: def test_config_file_is_disabled_by_default(mocker, setup): # Arrange spy = setup - template = list(config.templates.keys())[0] + template = list(settings.templates.keys())[0] # Act CreateCube.run(template) @@ -82,7 +82,7 @@ def test_config_file_is_disabled_by_default(mocker, setup): def test_config_file_is_used_when_passed(mocker, setup, config_file): # Arrange spy = setup - template = list(config.templates.keys())[0] + template = list(settings.templates.keys())[0] # Act CreateCube.run(template, config_file=config_file) @@ -97,7 +97,7 @@ def test_passing_config_file_disables_input(mocker, setup, config_file): # Arrange spy = setup should_not_input = config_file is not None - template = list(config.templates.keys())[0] + template = list(settings.templates.keys())[0] # Act CreateCube.run(template, config_file=config_file) diff --git a/cli/medperf/tests/commands/mlcube/test_submit.py b/cli/medperf/tests/commands/mlcube/test_submit.py index a946c1fef..2491ef38f 100644 --- a/cli/medperf/tests/commands/mlcube/test_submit.py +++ b/cli/medperf/tests/commands/mlcube/test_submit.py @@ -1,7 +1,7 @@ import os import pytest -import medperf.config as config +from medperf import settings from medperf.tests.mocks.cube import TestCube from medperf.commands.mlcube.submit import SubmitCube @@ -25,7 +25,7 @@ def test_submit_prepares_tmp_path_for_cleanup(): submission = SubmitCube(cube.todict()) # Assert - assert submission.cube.path in config.tmp_paths + assert submission.cube.path in settings.tmp_paths def test_run_runs_expected_flow(mocker, comms, ui, cube): @@ -57,8 +57,8 @@ def test_to_permanent_path_renames_correctly(mocker, comms, ui, cube, uid): submission.cube = cube spy = mocker.patch("os.rename") mocker.patch("os.path.exists", return_value=False) - old_path = os.path.join(config.cubes_folder, cube.local_id) - new_path = os.path.join(config.cubes_folder, str(uid)) + old_path = os.path.join(settings.cubes_folder, cube.local_id) + new_path = os.path.join(settings.cubes_folder, str(uid)) # Act submission.to_permanent_path({**cube.todict(), "id": uid}) diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py index c69544781..5f0e2dc76 100644 --- a/cli/medperf/tests/commands/result/test_create.py +++ b/cli/medperf/tests/commands/result/test_create.py @@ -1,6 +1,6 @@ import os from unittest.mock import ANY, call -from medperf import config +from medperf import settings from medperf.exceptions import ExecutionError, InvalidArgumentError, InvalidEntityError from medperf.tests.mocks.benchmark import TestBenchmark from medperf.tests.mocks.cube import TestCube @@ -335,9 +335,9 @@ def test_execution_of_one_model_writes_result(self, mocker, setup): dset_uid = 2 bmk_uid = 1 expected_file = os.path.join( - config.results_folder, + settings.results_folder, f"b{bmk_uid}m{model_uid}d{dset_uid}", - config.results_info_file, + settings.results_info_file, ) # Act BenchmarkExecution.run(bmk_uid, dset_uid, models_uids=[model_uid]) diff --git a/cli/medperf/tests/commands/test_execution.py b/cli/medperf/tests/commands/test_execution.py index d50ca5d31..34c1b804b 100644 --- a/cli/medperf/tests/commands/test_execution.py +++ b/cli/medperf/tests/commands/test_execution.py @@ -5,7 +5,7 @@ from medperf.tests.mocks.cube import TestCube from medperf.tests.mocks.dataset import TestDataset import pytest -from medperf import config +from medperf import settings import yaml @@ -101,7 +101,7 @@ def test_no_failure_with_ignore_error(mocker, setup): def test_failure_with_existing_predictions(mocker, setup, ignore_model_errors, fs): # Arrange preds_path = os.path.join( - config.predictions_folder, + settings.predictions_folder, INPUT_MODEL.local_id, INPUT_DATASET.local_id, ) @@ -148,20 +148,20 @@ def test_results_are_returned(mocker, setup): def test_cube_run_are_called_properly(mocker, setup): # Arrange exp_preds_path = os.path.join( - config.predictions_folder, + settings.predictions_folder, INPUT_MODEL.local_id, INPUT_DATASET.local_id, ) exp_model_logs_path = os.path.join( - config.experiments_logs_folder, + settings.experiments_logs_folder, INPUT_MODEL.local_id, INPUT_DATASET.local_id, "model.log", ) exp_metrics_logs_path = os.path.join( - config.experiments_logs_folder, + settings.experiments_logs_folder, INPUT_MODEL.local_id, INPUT_DATASET.local_id, f"metrics_{INPUT_EVALUATOR.local_id}.log", @@ -170,14 +170,14 @@ def test_cube_run_are_called_properly(mocker, setup): exp_model_call = call( task="infer", output_logs=exp_model_logs_path, - timeout=config.infer_timeout, + timeout=settings.infer_timeout, data_path=INPUT_DATASET.data_path, output_path=exp_preds_path, ) exp_eval_call = call( task="evaluate", output_logs=exp_metrics_logs_path, - timeout=config.evaluate_timeout, + timeout=settings.evaluate_timeout, predictions=exp_preds_path, labels=INPUT_DATASET.labels_path, output_path=ANY, diff --git a/cli/medperf/tests/commands/test_profile.py b/cli/medperf/tests/commands/test_profile.py index b07480550..5f752264d 100644 --- a/cli/medperf/tests/commands/test_profile.py +++ b/cli/medperf/tests/commands/test_profile.py @@ -2,7 +2,7 @@ from unittest.mock import call from typer.testing import CliRunner -from medperf.config_management import read_config +from medperf.config_management import config from medperf.commands.profile import app runner = CliRunner() @@ -15,7 +15,7 @@ def test_activate_updates_active_profile(mocker, profile): runner.invoke(app, ["activate", profile]) # Assert - config_p = read_config() + config_p = config.read_config() assert config_p.is_profile_active(profile) @@ -39,7 +39,7 @@ def test_create_adds_new_profile(mocker, name, args): runner.invoke(app, ["create", "-n", name] + in_args) # Assert - config_p = read_config() + config_p = config.read_config() assert config_p[name] == {**config_p.profiles["default"], **out_cfg} @@ -65,7 +65,7 @@ def test_set_updates_profile_parameters(mocker, args): runner.invoke(app, ["set"] + in_args) # Assert - config_p = read_config() + config_p = config.read_config() assert config_p.active_profile == {**config_p.profiles["default"], **out_cfg} @@ -73,12 +73,11 @@ def test_ls_prints_profile_names(mocker, ui): # Arrange spy = mocker.patch.object(ui, "print") green_spy = mocker.patch.object(ui, "print_highlight") - config_p = read_config() calls = [ call(" " + profile) - for profile in config_p - if not config_p.is_profile_active(profile) + for profile in config + if not config.is_profile_active(profile) ] # Act @@ -86,14 +85,14 @@ def test_ls_prints_profile_names(mocker, ui): # Assert spy.assert_has_calls(calls) - green_spy.assert_called_once_with("* " + config_p.active_profile_name) + green_spy.assert_called_once_with("* " + config.active_profile_name) @pytest.mark.parametrize("profile", ["default", "local"]) def test_view_prints_profile_contents(mocker, profile): # Arrange spy = mocker.patch(PATCH_PROFILE.format("dict_pretty_print")) - config_p = read_config() + config_p = config.read_config() cfg = config_p[profile] # Act diff --git a/cli/medperf/tests/comms/entity_resources/sources/test_direct.py b/cli/medperf/tests/comms/entity_resources/sources/test_direct.py index d71e4163a..32a5e418e 100644 --- a/cli/medperf/tests/comms/entity_resources/sources/test_direct.py +++ b/cli/medperf/tests/comms/entity_resources/sources/test_direct.py @@ -1,6 +1,6 @@ from medperf.tests.mocks import MockResponse from medperf.comms.entity_resources.sources.direct import DirectLinkSource -import medperf.config as config +from medperf import settings import pytest from medperf.exceptions import CommunicationRetrievalError @@ -21,7 +21,7 @@ def test_download_works_as_expected(mocker, fs): # Assert assert open(filename).read() == "sometext" get_spy.assert_called_once_with(url, stream=True) - iter_spy.assert_called_once_with(chunk_size=config.ddl_stream_chunk_size) + iter_spy.assert_called_once_with(chunk_size=settings.ddl_stream_chunk_size) def test_download_raises_for_failed_request_after_multiple_attempts(mocker): @@ -34,4 +34,4 @@ def test_download_raises_for_failed_request_after_multiple_attempts(mocker): with pytest.raises(CommunicationRetrievalError): DirectLinkSource().download(url, filename) - assert spy.call_count == config.ddl_max_redownload_attempts + assert spy.call_count == settings.ddl_max_redownload_attempts diff --git a/cli/medperf/tests/comms/entity_resources/test_resources.py b/cli/medperf/tests/comms/entity_resources/test_resources.py index 437e518be..6e1a62f7c 100644 --- a/cli/medperf/tests/comms/entity_resources/test_resources.py +++ b/cli/medperf/tests/comms/entity_resources/test_resources.py @@ -1,7 +1,7 @@ import os from medperf.utils import get_file_hash import pytest -import medperf.config as config +from medperf import settings from medperf.comms.entity_resources import resources import yaml @@ -28,12 +28,12 @@ def test_get_cube_image_retrieves_image_if_not_local(self, mocker, url, fs): # Arrange cube_path = "cube/1" image_name = "some_name" - cube_yaml_path = os.path.join(cube_path, config.cube_filename) + cube_yaml_path = os.path.join(cube_path, settings.cube_filename) fs.create_file( cube_yaml_path, contents=yaml.dump({"singularity": {"image": image_name}}) ) - exp_file = os.path.join(cube_path, config.image_path, image_name) - os.makedirs(config.images_folder, exist_ok=True) + exp_file = os.path.join(cube_path, settings.image_path, image_name) + os.makedirs(settings.images_folder, exist_ok=True) # Act resources.get_cube_image(url, cube_path) @@ -50,11 +50,11 @@ def test_get_cube_image_uses_cache_if_available(self, mocker, url, fs): spy = mocker.spy(resources, "download_resource") cube_path = "cube/1" image_name = "some_name" - cube_yaml_path = os.path.join(cube_path, config.cube_filename) + cube_yaml_path = os.path.join(cube_path, settings.cube_filename) fs.create_file( cube_yaml_path, contents=yaml.dump({"singularity": {"image": image_name}}) ) - img_path = os.path.join(config.images_folder, "hash") + img_path = os.path.join(settings.images_folder, "hash") fs.create_file(img_path, contents="img") # Act @@ -74,7 +74,7 @@ def test_get_additional_files_does_not_download_if_folder_exists_and_hash_valid( ): # Arrange cube_path = "cube/1" - additional_files_folder = os.path.join(cube_path, config.additional_path) + additional_files_folder = os.path.join(cube_path, settings.additional_path) fs.create_dir(additional_files_folder) spy = mocker.spy(resources, "download_resource") exp_hash = resources.get_cube_additional(url, cube_path) @@ -90,7 +90,7 @@ def test_get_additional_files_will_download_if_folder_exists_and_hash_outdated( ): # Arrange cube_path = "cube/1" - additional_files_folder = os.path.join(cube_path, config.additional_path) + additional_files_folder = os.path.join(cube_path, settings.additional_path) fs.create_dir(additional_files_folder) spy = mocker.spy(resources, "download_resource") resources.get_cube_additional(url, cube_path) @@ -106,11 +106,11 @@ def test_get_additional_files_will_download_if_folder_exists_and_hash_valid_but_ ): # a test for existing installation before this feature # Arrange cube_path = "cube/1" - additional_files_folder = os.path.join(cube_path, config.additional_path) + additional_files_folder = os.path.join(cube_path, settings.additional_path) fs.create_dir(additional_files_folder) spy = mocker.spy(resources, "download_resource") exp_hash = resources.get_cube_additional(url, cube_path) - hash_cache_file = os.path.join(cube_path, config.mlcube_cache_file) + hash_cache_file = os.path.join(cube_path, settings.mlcube_cache_file) os.remove(hash_cache_file) # Act diff --git a/cli/medperf/tests/comms/test_auth0.py b/cli/medperf/tests/comms/test_auth0.py index 9eadd9662..76af2e7f6 100644 --- a/cli/medperf/tests/comms/test_auth0.py +++ b/cli/medperf/tests/comms/test_auth0.py @@ -1,15 +1,29 @@ import time from unittest.mock import ANY + +from medperf.config_management import Auth0Settings from medperf.tests.mocks import MockResponse from medperf.comms.auth.auth0 import Auth0 -from medperf import config +from medperf import settings from medperf.exceptions import AuthenticationError import sqlite3 import pytest - PATCH_AUTH = "medperf.comms.auth.auth0.{}" +test_auth_config = Auth0Settings( + domain=settings.auth_domain, + jwks_url=settings.auth_jwks_url, + idtoken_issuer=settings.auth_idtoken_issuer, + client_id=settings.auth_client_id, + audience=settings.auth_audience, + jwks_cache_ttl=settings.auth_jwks_cache_ttl, + tokens_db=settings.tokens_db, + token_expiration_leeway=settings.token_expiration_leeway, + token_absolute_expiry=settings.token_absolute_expiry, + refresh_token_expiration_leeway=settings.refresh_token_expiration_leeway, +) + @pytest.fixture def setup(mocker): @@ -25,7 +39,7 @@ def test_logout_removes_credentials(mocker, setup): spy = mocker.patch(PATCH_AUTH.format("delete_credentials")) # Act - Auth0().logout() + Auth0(test_auth_config).logout() # Assert spy.assert_called_once() @@ -44,7 +58,7 @@ def test_token_is_not_refreshed_if_not_expired(mocker, setup): spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token")) # Act - Auth0().access_token + Auth0(test_auth_config).access_token # Assert spy.assert_not_called() @@ -65,7 +79,7 @@ def test_token_is_refreshed_if_expired(mocker, setup): spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token")) # Act - Auth0().access_token + _ = Auth0(test_auth_config).access_token # Assert spy.assert_called_once() @@ -74,7 +88,7 @@ def test_token_is_refreshed_if_expired(mocker, setup): def test_logs_out_if_session_reaches_token_absolute_expiration_time(mocker, setup): # Arrange expiration_time = 900 - absolute_expiration_time = config.token_absolute_expiry + absolute_expiration_time = settings.token_absolute_expiry mocked_logged_in_at = time.time() - absolute_expiration_time mocked_issued_at = time.time() - expiration_time creds = { @@ -89,7 +103,7 @@ def test_logs_out_if_session_reaches_token_absolute_expiration_time(mocker, setu # Act with pytest.raises(AuthenticationError): - Auth0().access_token + _ = Auth0(test_auth_config).access_token # Assert spy.assert_called_once() @@ -116,7 +130,7 @@ def test_refresh_token_sets_new_tokens(mocker, setup): spy = mocker.patch(PATCH_AUTH.format("set_credentials")) # Act - Auth0()._Auth0__refresh_access_token("") + Auth0(test_auth_config)._Auth0__refresh_access_token("") # Assert spy.assert_called_once_with( diff --git a/cli/medperf/tests/comms/test_rest.py b/cli/medperf/tests/comms/test_rest.py index fb3596c98..afca275a3 100644 --- a/cli/medperf/tests/comms/test_rest.py +++ b/cli/medperf/tests/comms/test_rest.py @@ -3,7 +3,8 @@ import requests from unittest.mock import ANY, call -from medperf import config +from medperf import settings +from medperf.config_management import config from medperf.enums import Status from medperf.comms.rest import REST from medperf.tests.mocks import MockResponse @@ -15,7 +16,7 @@ @pytest.fixture def server(mocker, ui): - server = REST(url) + server = REST(url, cert=None) return server @@ -160,8 +161,10 @@ def test_auth_post_calls_authorized_request(mocker, server): @pytest.mark.parametrize("req_type", ["get", "post"]) @pytest.mark.parametrize("token", ["test", "token", "auth_token"]) def test_auth_get_adds_token_to_request(mocker, server, token, req_type, auth): + config.read_config() # Arrange auth.access_token = token + config.auth = auth if req_type == "get": spy = mocker.patch("requests.get") @@ -171,7 +174,7 @@ def test_auth_get_adds_token_to_request(mocker, server, token, req_type, auth): func = requests.post exp_headers = {"Authorization": f"Bearer {token}"} - cert_verify = config.certificate or True + cert_verify = settings.certificate or True # Act server._REST__auth_req(url, func) @@ -196,7 +199,7 @@ def test__req_sanitizes_json(mocker, server): def test__get_list_uses_default_page_size(mocker, server): # Arrange - exp_page_size = config.default_page_size + exp_page_size = settings.default_page_size exp_url = f"{full_url}?limit={exp_page_size}&offset=0" ret_body = MockResponse({"count": 1, "next": None, "results": []}, 200) spy = mocker.patch.object(server, "_REST__auth_get", return_value=ret_body) diff --git a/cli/medperf/tests/conftest.py b/cli/medperf/tests/conftest.py index d8b07ef6f..98da77d06 100644 --- a/cli/medperf/tests/conftest.py +++ b/cli/medperf/tests/conftest.py @@ -3,13 +3,16 @@ import builtins import os from copy import deepcopy -from medperf import config +from medperf import settings +from medperf.config_management import config_management +from medperf.config_management import config from medperf.ui.interface import UI from medperf.comms.interface import Comms from medperf.comms.auth.interface import Auth from medperf.init import initialize import importlib + # from copy import deepcopy @@ -69,22 +72,30 @@ def stunted_listdir(): @pytest.fixture(autouse=True) -def package_init(fs): +def package_init(fs, monkeypatch): # TODO: this might not be enough. Fixtures that don't depend on # ui, auth, or comms may still run before this fixture # all of this should hacky test setup be changed anyway - orig_config_as_dict = {} + orig_settings_as_dict = {} + try: + orig_settings = importlib.reload(settings) + except ImportError: + orig_settings = importlib.import_module("medperf.settings", "medperf") + try: - orig_config = importlib.reload(config) + config_mgmt = importlib.reload(config_management) except ImportError: - orig_config = importlib.import_module("medperf.config", "medperf") - for attr in dir(orig_config): + config_mgmt = importlib.import_module("medperf.config_management.config_management", + "medperf.config_management") + monkeypatch.setattr('medperf.config_management.config', config_mgmt.config) + + for attr in dir(orig_settings): if not attr.startswith("__"): - orig_config_as_dict[attr] = deepcopy(getattr(orig_config, attr)) + orig_settings_as_dict[attr] = deepcopy(getattr(orig_settings, attr)) initialize() yield - for attr in orig_config_as_dict: - setattr(config, attr, orig_config_as_dict[attr]) + for attr in orig_settings_as_dict: + setattr(settings, attr, orig_settings_as_dict[attr]) @pytest.fixture @@ -104,5 +115,14 @@ def comms(mocker, package_init): @pytest.fixture def auth(mocker, package_init): auth = mocker.create_autospec(spec=Auth) - config.auth = auth + settings.auth = auth return auth + + +@pytest.fixture +def fs(fs): + fs.add_real_file( + settings.local_tokens_path, + target_path=settings.local_tokens_path + ) + yield fs diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index 89e7cc5a9..ae23ecb16 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -4,7 +4,7 @@ from unittest.mock import call import medperf -import medperf.config as config +from medperf import settings from medperf.entities.cube import Cube from medperf.tests.entities.utils import ( setup_cube_fs, @@ -59,15 +59,15 @@ def set_common_attributes(self, setup): self.id = setup["remote"][0]["id"] # Specify expected path for all downloaded files - self.cube_path = os.path.join(config.cubes_folder, str(self.id)) - self.manifest_path = os.path.join(self.cube_path, config.cube_filename) + self.cube_path = os.path.join(settings.cubes_folder, str(self.id)) + self.manifest_path = os.path.join(self.cube_path, settings.cube_filename) self.params_path = os.path.join( - self.cube_path, config.workspace_path, config.params_filename + self.cube_path, settings.workspace_path, settings.params_filename ) self.add_path = os.path.join( - self.cube_path, config.additional_path, config.tarball_filename + self.cube_path, settings.additional_path, settings.tarball_filename ) - self.img_path = os.path.join(self.cube_path, config.image_path, "img.tar.gz") + self.img_path = os.path.join(self.cube_path, settings.image_path, "img.tar.gz") self.config_files_paths = [self.manifest_path, self.params_path] self.run_files_paths = [self.add_path, self.img_path] @@ -105,9 +105,9 @@ def test_download_run_files_without_image_configures_mlcube( ) spy = mocker.spy(medperf.entities.cube.spawn_and_kill, "spawn") expected_cmds = [ - f"mlcube --log-level debug configure --mlcube={self.manifest_path} --platform={config.platform}", + f"mlcube --log-level debug configure --mlcube={self.manifest_path} --platform={settings.platform}", f"mlcube --log-level debug inspect --mlcube={self.manifest_path}" - f" --format=yaml --platform={config.platform} --output-file {tmp_path}", + f" --format=yaml --platform={settings.platform} --output-file {tmp_path}", ] expected_cmds = [call(cmd, timeout=None) for cmd in expected_cmds] @@ -173,12 +173,12 @@ class TestRun: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): self.id = setup["remote"][0]["id"] - self.platform = config.platform - self.gpus = config.gpus + self.platform = settings.platform + self.gpus = settings.gpus # Specify expected path for the manifest files - self.cube_path = os.path.join(config.cubes_folder, str(self.id)) - self.manifest_path = os.path.join(self.cube_path, config.cube_filename) + self.cube_path = os.path.join(settings.cubes_folder, str(self.id)) + self.manifest_path = os.path.join(self.cube_path, settings.cube_filename) @pytest.mark.parametrize("timeout", [847, None]) def test_cube_runs_command(self, mocker, timeout, setup, task): @@ -306,15 +306,15 @@ def set_common_attributes(self, fs, setup, task, out_key, out_value): self.cube_contents = { "tasks": {task: {"parameters": {"outputs": {out_key: out_value}}}} } - self.cube_path = os.path.join(config.cubes_folder, str(self.id)) - self.manifest_path = os.path.join(self.cube_path, config.cube_filename) + self.cube_path = os.path.join(settings.cubes_folder, str(self.id)) + self.manifest_path = os.path.join(self.cube_path, settings.cube_filename) fs.create_file(self.manifest_path, contents=yaml.dump(self.cube_contents)) # Construct the expected output path out_val_path = out_value if isinstance(out_value, dict): out_val_path = out_value["default"] - self.output = os.path.join(self.cube_path, config.workspace_path, out_val_path) + self.output = os.path.join(self.cube_path, settings.workspace_path, out_val_path) def test_default_output_returns_expected_path(self, task, out_key): # Arrange @@ -334,7 +334,7 @@ def test_default_output_returns_path_with_params( # Create a params file with minimal content params_contents = {param_key: param_val} params_path = os.path.join( - self.cube_path, config.workspace_path, config.params_filename + self.cube_path, settings.workspace_path, settings.params_filename ) fs.create_file(params_path, contents=yaml.dump(params_contents)) diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index c3bde6feb..e48afaba0 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -1,5 +1,5 @@ import os -from medperf import config +from medperf import settings import yaml from medperf.utils import get_file_hash @@ -24,7 +24,7 @@ def setup_benchmark_fs(ents, fs): else: bmk_contents = TestBenchmark(id=None, name=ent) - bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) + bmk_filepath = os.path.join(bmk_contents.path, settings.benchmarks_filename) cubes_ids = [] cubes_ids.append(bmk_contents.data_preparation_mlcube) cubes_ids.append(bmk_contents.reference_model_mlcube) @@ -62,7 +62,7 @@ def setup_cube_fs(ents, fs): else: cube = TestCube(id=None, name=ent) - meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) + meta_cube_file = os.path.join(cube.path, settings.cube_metadata_filename) meta = cube.todict() try: fs.create_file(meta_cube_file, contents=yaml.dump(meta)) @@ -96,7 +96,7 @@ def cubefile_fn(url, cube_path, *args): pass hash = get_file_hash(filepath) # special case: tarball file - if filename == config.tarball_filename: + if filename == settings.tarball_filename: return hash return filepath, hash @@ -105,12 +105,12 @@ def cubefile_fn(url, cube_path, *args): def setup_cube_comms_downloads(mocker, fs): cube_path = "" - cube_file = config.cube_filename - params_path = config.workspace_path - params_file = config.params_filename - add_path = config.additional_path - add_file = config.tarball_filename - img_path = config.image_path + cube_file = settings.cube_filename + params_path = settings.workspace_path + params_file = settings.params_filename + add_path = settings.additional_path + add_file = settings.tarball_filename + img_path = settings.image_path img_file = "img.tar.gz" get_cube_fn = generate_cubefile_fn(fs, cube_path, cube_file) @@ -135,7 +135,7 @@ def setup_dset_fs(ents, fs): else: dset_contents = TestDataset(id=None, generated_uid=ent) - reg_dset_file = os.path.join(dset_contents.path, config.reg_file) + reg_dset_file = os.path.join(dset_contents.path, settings.reg_file) cube_id = dset_contents.data_preparation_mlcube setup_cube_fs([cube_id], fs) try: @@ -168,7 +168,7 @@ def setup_result_fs(ents, fs): else: result_contents = TestResult(id=None, name=ent) - result_file = os.path.join(result_contents.path, config.results_info_file) + result_file = os.path.join(result_contents.path, settings.results_info_file) bmk_id = result_contents.benchmark cube_id = result_contents.model dataset_id = result_contents.dataset diff --git a/cli/medperf/tests/test_account_management.py b/cli/medperf/tests/test_account_management.py index 41beebf22..c0fb0e6a5 100644 --- a/cli/medperf/tests/test_account_management.py +++ b/cli/medperf/tests/test_account_management.py @@ -1,16 +1,21 @@ import pytest -import medperf.config as config +from medperf import settings from medperf.exceptions import MedperfException from medperf import account_management - -PATCH_ACC = "medperf.account_management.account_management.{}" +PATCH_ACC = "medperf.account_management.account_management.config" class MockConfig: def __init__(self, **kwargs): self.active_profile = kwargs + def read_config(self): + return self + + def write_config(self): + pass + @pytest.fixture def mock_config(mocker, request): @@ -19,8 +24,7 @@ def mock_config(mocker, request): except AttributeError: param = {} config_p = MockConfig(**param) - mocker.patch(PATCH_ACC.format("read_config"), return_value=config_p) - mocker.patch(PATCH_ACC.format("write_config")) + mocker.patch(PATCH_ACC, new=config_p) return config_p @@ -30,26 +34,26 @@ def test_get_medperf_user_data_fail_for_not_logged_in_user(mock_config): @pytest.mark.parametrize( - "mock_config", [{config.credentials_keyword: {}}], indirect=True + "mock_config", [{settings.credentials_keyword: {}}], indirect=True ) def test_get_medperf_user_data_gets_data_from_comms(mocker, mock_config, comms): # Arrange medperf_user = "medperf_user" mocker.patch.object(comms, "get_current_user", return_value=medperf_user) - + mock_config.comms = comms # Act account_management.get_medperf_user_data() # Assert assert ( - mock_config.active_profile[config.credentials_keyword]["medperf_user"] + mock_config.active_profile[settings.credentials_keyword]["medperf_user"] == medperf_user ) @pytest.mark.parametrize( "mock_config", - [{config.credentials_keyword: {"medperf_user": "some data"}}], + [{settings.credentials_keyword: {"medperf_user": "some data"}}], indirect=True, ) def test_get_medperf_user_data_gets_data_from_cache(mocker, mock_config, comms): diff --git a/cli/medperf/tests/test_utils.py b/cli/medperf/tests/test_utils.py index abc0650ef..d2ab23e9a 100644 --- a/cli/medperf/tests/test_utils.py +++ b/cli/medperf/tests/test_utils.py @@ -6,7 +6,7 @@ from unittest.mock import mock_open, call, ANY from medperf import utils -import medperf.config as config +from medperf import settings from medperf.tests.mocks import MockTar from medperf.exceptions import MedperfException import yaml @@ -59,7 +59,7 @@ def filesystem(): def test_setup_logging_filters_sensitive_data(text, exp_output): # Arrange logging.getLogger().setLevel("DEBUG") - log_file = os.path.join(config.logs_storage, config.log_file) + log_file = os.path.join(settings.logs_storage, settings.log_file) # Act logging.debug(text) @@ -112,7 +112,7 @@ def test_cleanup_removes_files(mocker, ui, fs): # Arrange path = "/path/to/garbage.html" fs.create_file(path, contents="garbage") - config.tmp_paths = [path] + settings.tmp_paths = [path] # Act utils.cleanup() @@ -125,13 +125,13 @@ def test_cleanup_moves_files_to_trash_on_failure(mocker, ui, fs): # Arrange path = "/path/to/garbage.html" fs.create_file(path, contents="garbage") - config.tmp_paths = [path] + settings.tmp_paths = [path] def side_effect(*args, **kwargs): raise PermissionError mocker.patch("os.remove", side_effect=side_effect) - trash_folder = config.trash_folder + trash_folder = settings.trash_folder # Act utils.cleanup() @@ -146,7 +146,7 @@ def side_effect(*args, **kwargs): @pytest.mark.parametrize("datasets", [4, 287], indirect=True) def test_get_uids_returns_uids_of_datasets(mocker, datasets, path): # Arrange - mock_walk_return = iter([(config.datasets_folder, datasets, ())]) + mock_walk_return = iter([(settings.datasets_folder, datasets, ())]) spy = mocker.patch("os.walk", return_value=mock_walk_return) # Act @@ -368,7 +368,7 @@ def test_get_cube_image_name_retrieves_name(mocker, fs): cube_path = "path" mock_content = {"singularity": {"image": exp_image_name}} - target_file = os.path.join(cube_path, config.cube_filename) + target_file = os.path.join(cube_path, settings.cube_filename) fs.create_file(target_file, contents=yaml.dump(mock_content)) # Act @@ -384,7 +384,7 @@ def test_get_cube_image_name_fails_if_cube_not_configured(mocker, fs): cube_path = "path" mock_content = {"not singularity": {"image": exp_image_name}} - target_file = os.path.join(cube_path, config.cube_filename) + target_file = os.path.join(cube_path, settings.cube_filename) fs.create_file(target_file, contents=yaml.dump(mock_content)) # Act & Assert diff --git a/cli/medperf/ui/factory.py b/cli/medperf/ui/factory.py index eb70acb0b..bec715fd4 100644 --- a/cli/medperf/ui/factory.py +++ b/cli/medperf/ui/factory.py @@ -4,14 +4,12 @@ from medperf.exceptions import InvalidArgumentError -class UIFactory: - @staticmethod - def create_ui(name: str) -> UI: - name = name.lower() - if name == "cli": - return CLI() - elif name == "stdin": - return StdIn() - else: - msg = f"{name}: the indicated UI interface doesn't exist" - raise InvalidArgumentError(msg) +def create_ui(name: str) -> UI: + name = name.lower() + if name == "cli": + return CLI() + elif name == "stdin": + return StdIn() + else: + msg = f"{name}: the indicated UI interface doesn't exist" + raise InvalidArgumentError(msg) diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 35aa697d6..283b60e03 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -20,7 +20,8 @@ from colorama import Fore, Style from pexpect.exceptions import TIMEOUT from git import Repo, GitCommandError -import medperf.config as config +from medperf import settings +from medperf.config_management import config from medperf.exceptions import ExecutionError, MedperfException @@ -75,7 +76,7 @@ def remove_path(path): def move_to_trash(path): - trash_folder = config.trash_folder + trash_folder = settings.trash_folder unique_path = os.path.join(trash_folder, generate_tmp_uid()) os.makedirs(unique_path) shutil.move(path, unique_path) @@ -83,14 +84,14 @@ def move_to_trash(path): def cleanup(): """Removes clutter and unused files from the medperf folder structure.""" - if not config.cleanup: + if not settings.cleanup: logging.info("Cleanup disabled") return - for path in config.tmp_paths: + for path in settings.tmp_paths: remove_path(path) - trash_folder = config.trash_folder + trash_folder = settings.trash_folder if os.path.exists(trash_folder) and os.listdir(trash_folder): msg = "WARNING: Failed to premanently cleanup some files. Consider deleting" msg += f" '{trash_folder}' manually to avoid unnecessary storage." @@ -146,8 +147,8 @@ def generate_tmp_path() -> str: Returns: str: generated temporary path """ - tmp_path = os.path.join(config.tmp_folder, generate_tmp_uid()) - config.tmp_paths.append(tmp_path) + tmp_path = os.path.join(settings.tmp_folder, generate_tmp_uid()) + settings.tmp_paths.append(tmp_path) return tmp_path @@ -333,8 +334,8 @@ def list_files(startpath): def log_storage(): - for folder in config.storage: - folder = getattr(config, folder) + for folder in settings.storage: + folder = getattr(settings, folder) logging.debug(list_files(folder)) @@ -381,7 +382,7 @@ def format_errors_dict(errors_dict: dict): error_msg += errors elif len(errors) == 1: # If a single error for a field is given, don't create a sublist - error_msg += errors[0] + error_msg += str(errors[0]) else: # Create a sublist otherwise for e_msg in errors: @@ -392,7 +393,7 @@ def format_errors_dict(errors_dict: dict): def get_cube_image_name(cube_path: str) -> str: """Retrieves the singularity image name of the mlcube by reading its mlcube.yaml file""" - cube_config_path = os.path.join(cube_path, config.cube_filename) + cube_config_path = os.path.join(cube_path, settings.cube_filename) with open(cube_config_path, "r") as f: cube_config = yaml.safe_load(f) @@ -430,7 +431,7 @@ def filter_latest_associations(associations, entity_key): def check_for_updates() -> None: """Check if the current branch is up-to-date with its remote counterpart using GitPython.""" - repo = Repo(config.BASE_DIR) + repo = Repo(settings.BASE_DIR) if repo.bare: logging.debug("Repo is bare") return diff --git a/cli/medperf/web_ui/app.py b/cli/medperf/web_ui/app.py index 3c8cca87d..e9fe66872 100644 --- a/cli/medperf/web_ui/app.py +++ b/cli/medperf/web_ui/app.py @@ -2,17 +2,18 @@ from pathlib import Path import typer -from fastapi import FastAPI -from fastapi.responses import RedirectResponse +from fastapi import FastAPI, APIRouter, Request, Form +from fastapi.responses import RedirectResponse, HTMLResponse from fastapi.staticfiles import StaticFiles -from medperf import config +from medperf import settings from medperf.decorators import clean_except from medperf.web_ui.common import custom_exception_handler from medperf.web_ui.datasets.routes import router as datasets_router from medperf.web_ui.benchmarks.routes import router as benchmarks_router from medperf.web_ui.mlcubes.routes import router as mlcubes_router from medperf.web_ui.yaml_fetch.routes import router as yaml_fetch_router +from medperf.web_ui.profile.routes import router as profile_router web_app = FastAPI() @@ -20,6 +21,7 @@ web_app.include_router(benchmarks_router, prefix="/benchmarks") web_app.include_router(mlcubes_router, prefix="/mlcubes") web_app.include_router(yaml_fetch_router) +web_app.include_router(profile_router) static_folder_path = Path(resources.files("medperf.web_ui")) / "static" # noqa web_app.mount( @@ -47,4 +49,4 @@ def run( ): """Runs a local web UI""" import uvicorn - uvicorn.run(web_app, host="127.0.0.1", port=port, log_level=config.loglevel) + uvicorn.run(web_app, host="127.0.0.1", port=port, log_level=settings.loglevel) diff --git a/cli/medperf/web_ui/benchmarks/routes.py b/cli/medperf/web_ui/benchmarks/routes.py index 6221a9d20..040aff480 100644 --- a/cli/medperf/web_ui/benchmarks/routes.py +++ b/cli/medperf/web_ui/benchmarks/routes.py @@ -8,7 +8,7 @@ from medperf.entities.dataset import Dataset from medperf.entities.cube import Cube from medperf.account_management import get_medperf_user_data -from medperf.web_ui.common import templates, sort_associations_display +from medperf.web_ui.common import templates, sort_associations_display, get_profiles_context router = APIRouter() logger = logging.getLogger(__name__) @@ -30,7 +30,8 @@ def benchmarks_ui(request: Request, mine_only: bool = False): mine_benchmarks = [d for d in benchmarks if d.owner == my_user_id] other_benchmarks = [d for d in benchmarks if d.owner != my_user_id] benchmarks = mine_benchmarks + other_benchmarks - return templates.TemplateResponse("benchmarks.html", {"request": request, "benchmarks": benchmarks}) + profile_context = get_profiles_context() + return templates.TemplateResponse("benchmarks.html", {"request": request, "benchmarks": benchmarks, **profile_context}) @router.get("/ui/{benchmark_id}", response_class=HTMLResponse) @@ -48,6 +49,8 @@ def benchmark_detail_ui(request: Request, benchmark_id: int): datasets = {assoc.dataset: Dataset.get(assoc.dataset) for assoc in datasets_associations if assoc.dataset} models = {assoc.model_mlcube: Cube.get(assoc.model_mlcube) for assoc in models_associations if assoc.model_mlcube} + profile_context = get_profiles_context() + return templates.TemplateResponse( "benchmark_detail.html", { @@ -60,6 +63,7 @@ def benchmark_detail_ui(request: Request, benchmark_id: int): "datasets_associations": datasets_associations, "models_associations": models_associations, "datasets": datasets, - "models": models + "models": models, + **profile_context } ) diff --git a/cli/medperf/web_ui/common.py b/cli/medperf/web_ui/common.py index 0274a1c75..c3d7d7fa6 100644 --- a/cli/medperf/web_ui/common.py +++ b/cli/medperf/web_ui/common.py @@ -6,6 +6,7 @@ from fastapi.requests import Request +from medperf.config_management import config from medperf.entities.association import Association from medperf.enums import Status @@ -50,3 +51,21 @@ def assoc_sorting_key(assoc): return status_order, date_order return sorted(associations, key=assoc_sorting_key) + + +def list_profiles() -> list[str]: + return list(config.profiles) + + +def get_active_profile() -> str: + return config.active_profile_name + + +def get_profiles_context(): + profiles = list_profiles() + active_profile = get_active_profile() + context = { + "profiles": profiles, + "active_profile": active_profile, + } + return context diff --git a/cli/medperf/web_ui/datasets/routes.py b/cli/medperf/web_ui/datasets/routes.py index 2d8eef00b..f9f1a80bd 100644 --- a/cli/medperf/web_ui/datasets/routes.py +++ b/cli/medperf/web_ui/datasets/routes.py @@ -8,7 +8,7 @@ from medperf.entities.cube import Cube from medperf.entities.dataset import Dataset from medperf.entities.benchmark import Benchmark -from medperf.web_ui.common import templates, sort_associations_display +from medperf.web_ui.common import templates, sort_associations_display, get_profiles_context router = APIRouter() logger = logging.getLogger(__name__) @@ -29,7 +29,8 @@ def datasets_ui(request: Request, mine_only: bool = False): mine_datasets = [d for d in datasets if d.owner == my_user_id] other_datasets = [d for d in datasets if d.owner != my_user_id] datasets = mine_datasets + other_datasets - return templates.TemplateResponse("datasets.html", {"request": request, "datasets": datasets}) + profile_context = get_profiles_context() + return templates.TemplateResponse("datasets.html", {"request": request, "datasets": datasets, **profile_context}) @router.get("/ui/{dataset_id}", response_class=HTMLResponse) @@ -44,6 +45,8 @@ def dataset_detail_ui(request: Request, dataset_id: int): benchmarks = {assoc.benchmark: Benchmark.get(assoc.benchmark) for assoc in benchmark_associations if assoc.benchmark} + profile_context = get_profiles_context() + return templates.TemplateResponse("dataset_detail.html", { "request": request, @@ -51,5 +54,6 @@ def dataset_detail_ui(request: Request, dataset_id: int): "entity_name": dataset.name, "prep_cube": prep_cube, "benchmark_associations": benchmark_associations, - "benchmarks": benchmarks + "benchmarks": benchmarks, + **profile_context }) diff --git a/cli/medperf/web_ui/mlcubes/routes.py b/cli/medperf/web_ui/mlcubes/routes.py index b057319f5..33975068a 100644 --- a/cli/medperf/web_ui/mlcubes/routes.py +++ b/cli/medperf/web_ui/mlcubes/routes.py @@ -8,7 +8,7 @@ from medperf.account_management import get_medperf_user_data from medperf.entities.cube import Cube from medperf.entities.benchmark import Benchmark -from medperf.web_ui.common import templates, sort_associations_display +from medperf.web_ui.common import templates, sort_associations_display, get_profiles_context router = APIRouter() logger = logging.getLogger(__name__) @@ -29,7 +29,8 @@ def mlcubes_ui(request: Request, mine_only: bool = False): mine_cubes = [c for c in mlcubes if c.owner == my_user_id] other_cubes = [c for c in mlcubes if c.owner != my_user_id] mlcubes = mine_cubes + other_cubes - return templates.TemplateResponse("mlcubes.html", {"request": request, "mlcubes": mlcubes}) + profile_context = get_profiles_context() + return templates.TemplateResponse("mlcubes.html", {"request": request, "mlcubes": mlcubes, **profile_context}) @router.get("/ui/{mlcube_id}", response_class=HTMLResponse) @@ -42,6 +43,8 @@ def mlcube_detail_ui(request: Request, mlcube_id: int): benchmarks = {assoc.benchmark: Benchmark.get(assoc.benchmark) for assoc in benchmarks_associations if assoc.benchmark} + profile_context = get_profiles_context() + return templates.TemplateResponse( "mlcube_detail.html", { @@ -49,6 +52,7 @@ def mlcube_detail_ui(request: Request, mlcube_id: int): "entity": mlcube, "entity_name": mlcube.name, "benchmarks_associations": benchmarks_associations, - "benchmarks": benchmarks + "benchmarks": benchmarks, + **profile_context, } ) diff --git a/cli/medperf/web_ui/profile/__init__.py b/cli/medperf/web_ui/profile/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cli/medperf/web_ui/profile/routes.py b/cli/medperf/web_ui/profile/routes.py new file mode 100644 index 000000000..8b2228670 --- /dev/null +++ b/cli/medperf/web_ui/profile/routes.py @@ -0,0 +1,26 @@ +import logging + +from fastapi import APIRouter, Form + +from medperf import settings +from medperf.exceptions import InvalidArgumentError +from medperf.config_management import config + +router = APIRouter() + + +def activate_profile(profile: str) -> None: + config.read_config() + if profile not in config: + raise InvalidArgumentError("The provided profile does not exists") + config.activate(profile) + config.write_config() + + logging.debug("new profile activated") + logging.debug(f"new config creds: {config.active_profile[settings.credentials_keyword]}") + + +@router.post("/change-profile") +async def change_profile(profile: str = Form(...)): + activate_profile(profile) + return {"message": "Profile changed"} diff --git a/cli/medperf/web_ui/templates/base.html b/cli/medperf/web_ui/templates/base.html index fb731167b..6c18af08e 100644 --- a/cli/medperf/web_ui/templates/base.html +++ b/cli/medperf/web_ui/templates/base.html @@ -105,6 +105,22 @@ Datasets +