diff --git a/.github/workflows/train-ci.yml b/.github/workflows/train-ci.yml new file mode 100644 index 000000000..473e7460b --- /dev/null +++ b/.github/workflows/train-ci.yml @@ -0,0 +1,41 @@ +name: FL Integration workflow + +on: pull_request + +jobs: + setup: + name: fl-integration-test + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Install dependencies + working-directory: . + run: | + python -m pip install --upgrade pip + pip install -e cli/ + pip install -r cli/test-requirements.txt + pip install -r server/requirements.txt + pip install -r server/test-requirements.txt + + - name: Set server environment vars + working-directory: ./server + run: cp .env.local.local-auth .env + + - name: Run django server in background with generated certs + working-directory: ./server + run: sh setup-dev-server.sh & sleep 6 + + - name: Run server integration tests + working-directory: ./server + run: python seed.py --cert cert.crt + + - name: Run client integration tests + working-directory: . + run: sh cli/cli_tests_training.sh -f \ No newline at end of file diff --git a/.gitignore b/.gitignore index 212d8e13e..048016fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,10 @@ cython_debug/ # Dev Environment Specific .vscode .venv +server/keys + +# exclude fl example +!examples/fl/mock_cert/project/ca/root.key +!examples/fl/mock_cert/project/ca/cert/root.crt +!flca/dev_assets/intermediate_ca.crt +!flca/dev_assets/root_ca.crt diff --git a/cli/cli_tests.sh b/cli/cli_tests.sh index 697c40105..4640fcca0 100755 --- a/cli/cli_tests.sh +++ b/cli/cli_tests.sh @@ -5,7 +5,6 @@ ################### Start Testing ######################## ########################################################## - ########################################################## echo "==========================================" echo "Printing MedPerf version" @@ -195,7 +194,7 @@ echo "Running data submission step" echo "=====================================" print_eval "medperf dataset submit -p $PREP_UID -d $DIRECTORY/dataset_a -l $DIRECTORY/dataset_a --name='dataset_a' --description='mock dataset a' --location='mock location a' -y" checkFailed "Data submission step failed" -DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | cut -d ' ' -f 1) +DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) echo "DSET_A_UID=$DSET_A_UID" ########################################################## diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh new file mode 100644 index 000000000..9e84ec05d --- /dev/null +++ b/cli/cli_tests_training.sh @@ -0,0 +1,559 @@ +# import setup +. 
"$(dirname $(realpath "$0"))/tests_setup.sh" + +########################################################## +################### Start Testing ######################## +########################################################## + +########################################################## +echo "==========================================" +echo "Creating test profiles for each user" +echo "==========================================" +print_eval medperf profile activate local +checkFailed "local profile creation failed" + +print_eval medperf profile create -n testmodel +checkFailed "testmodel profile creation failed" +print_eval medperf profile create -n testagg +checkFailed "testagg profile creation failed" +print_eval medperf profile create -n testdata1 +checkFailed "testdata1 profile creation failed" +print_eval medperf profile create -n testdata2 +checkFailed "testdata2 profile creation failed" +print_eval medperf profile create -n fladmin +checkFailed "fladmin profile creation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Retrieving mock datasets" +echo "=====================================" +echo "downloading files to $DIRECTORY" + +wget -P $DIRECTORY https://storage.googleapis.com/medperf-storage/testfl/data/col1.tar.gz +tar -xf $DIRECTORY/col1.tar.gz -C $DIRECTORY +wget -P $DIRECTORY https://storage.googleapis.com/medperf-storage/testfl/data/col2.tar.gz +tar -xf $DIRECTORY/col2.tar.gz -C $DIRECTORY +wget -P $DIRECTORY https://storage.googleapis.com/medperf-storage/testfl/data/test.tar.gz +tar -xf $DIRECTORY/test.tar.gz -C $DIRECTORY +rm $DIRECTORY/col1.tar.gz +rm $DIRECTORY/col2.tar.gz +rm $DIRECTORY/test.tar.gz + +########################################################## + +echo "\n" + +########################################################## +echo "==========================================" +echo "Login each user" +echo "==========================================" +print_eval medperf profile activate testmodel +checkFailed "testmodel profile activation failed" + +print_eval medperf auth login -e $MODELOWNER +checkFailed "testmodel login failed" + +print_eval medperf profile activate testagg +checkFailed "testagg profile activation failed" + +print_eval medperf auth login -e $AGGOWNER +checkFailed "testagg login failed" + +print_eval medperf profile activate testdata1 +checkFailed "testdata1 profile activation failed" + +print_eval medperf auth login -e $DATAOWNER +checkFailed "testdata1 login failed" + +print_eval medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" + +print_eval medperf auth login -e $DATAOWNER2 +checkFailed "testdata2 login failed" + +print_eval medperf profile activate fladmin +checkFailed "fladmin profile activation failed" + +print_eval medperf auth login -e $FLADMIN +checkFailed "fladmin login failed" + +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate modelowner profile" +echo "=====================================" +print_eval medperf profile activate testmodel +checkFailed "testmodel profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Submit cubes" +echo 
"=====================================" + +print_eval medperf mlcube submit --name trainprep -m $PREP_TRAINING_MLCUBE --operational +checkFailed "Train prep submission failed" +PREP_UID=$(medperf mlcube ls | grep trainprep | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) + +print_eval medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -a $TRAIN_WEIGHTS --operational +checkFailed "traincube submission failed" +TRAINCUBE_UID=$(medperf mlcube ls | grep traincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) + +print_eval medperf mlcube submit --name fladmincube -m $FLADMIN_MLCUBE --operational +checkFailed "fladmincube submission failed" +FLADMINCUBE_UID=$(medperf mlcube ls | grep fladmincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) + +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Submit Training Experiment" +echo "=====================================" +print_eval medperf training submit -n trainexp -d trainexp -p $PREP_UID -m $TRAINCUBE_UID -a $FLADMINCUBE_UID +checkFailed "Training exp submission failed" +TRAINING_UID=$(medperf training ls | grep trainexp | tail -n 1 | tr -s ' ' | cut -d ' ' -f 2) + +# Approve benchmark +ADMIN_TOKEN=$(jq -r --arg ADMIN $ADMIN '.[$ADMIN]' $MOCK_TOKENS_FILE) +checkFailed "Retrieving admin token failed" +curl -sk -X PUT $SERVER_URL$VERSION_PREFIX/training/$TRAINING_UID/ -d '{"approval_status": "APPROVED"}' -H 'Content-Type: application/json' -H "Authorization: Bearer $ADMIN_TOKEN" +checkFailed "training exp approval failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Associate with ca" +echo "=====================================" +CA_UID=$(medperf ca ls | grep "MedPerf CA" | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +print_eval medperf ca associate -t $TRAINING_UID -c $CA_UID -y +checkFailed "ca association failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate aggowner profile" +echo "=====================================" +print_eval medperf profile activate testagg +checkFailed "testagg profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running aggregator submission step" +echo "=====================================" +HOSTNAME_=$(hostname -I | cut -d " " -f 1) +# HOSTNAME_=$(hostname -A | cut -d " " -f 1) # fqdn on github CI runner doesn't resolve from inside containers +print_eval medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID +checkFailed "aggregator submission step failed" +AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running aggregator association step" +echo "=====================================" +print_eval medperf aggregator associate -a $AGG_UID -t $TRAINING_UID -y +checkFailed "aggregator association step failed" +########################################################## + +echo "\n" 
+ +########################################################## +echo "=====================================" +echo "Activate modelowner profile" +echo "=====================================" +print_eval medperf profile activate testmodel +checkFailed "testmodel profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Approve aggregator association" +echo "=====================================" +print_eval medperf association approve -t $TRAINING_UID -a $AGG_UID +checkFailed "agg association approval failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "submit plan" +echo "=====================================" +print_eval medperf training set_plan -t $TRAINING_UID -c $TRAINING_CONFIG -y +checkFailed "submit plan failed" + +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "start event" +echo "=====================================" +echo "testdo@example.com: testdo@example.com" >>./testcols.yaml +echo "testdo2@example.com: testdo2@example.com" >>./testcols.yaml +print_eval medperf training start_event -n event1 -t $TRAINING_UID -p ./testcols.yaml -y +checkFailed "start event failed" +rm ./testcols.yaml + +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate aggowner profile" +echo "=====================================" +print_eval medperf profile activate testagg +checkFailed "testagg profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get aggregator cert" +echo "=====================================" +print_eval medperf certificate get_server_certificate -t $TRAINING_UID +checkFailed "Get aggregator cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting aggregator" +echo "=====================================" +print_eval medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ >> agg.log 2>&1 & +AGG_PID=$! + +# sleep so that the mlcube is run before we change profiles +sleep 7 + +# Check if the command is still running. +if [ ! 
-d "/proc/$AGG_PID" ]; then + checkFailed "agg doesn't seem to be running" 1 +fi +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner profile" +echo "=====================================" +print_eval medperf profile activate testdata1 +checkFailed "testdata1 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data1 submission step" +echo "=====================================" +print_eval medperf dataset submit -p $PREP_UID -d $DIRECTORY/col1 -l $DIRECTORY/col1 --name="col1" --description="col1data" --location="col1location" -y +checkFailed "Data1 submission step failed" +DSET_1_UID=$(medperf dataset ls | grep col1 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data1 preparation step" +echo "=====================================" +print_eval medperf dataset prepare -d $DSET_1_UID +checkFailed "Data1 preparation step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data1 set_operational step" +echo "=====================================" +print_eval medperf dataset set_operational -d $DSET_1_UID -y +checkFailed "Data1 set_operational step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data1 association step" +echo "=====================================" +print_eval medperf dataset associate -d $DSET_1_UID -t $TRAINING_UID -y +checkFailed "Data1 association step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get data1owner cert" +echo "=====================================" +print_eval medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get data1owner cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting training with data1" +echo "=====================================" +print_eval medperf dataset train -d $DSET_1_UID -t $TRAINING_UID -y col1.log 2>&1 & +COL1_PID=$! + +# sleep so that the mlcube is run before we change profiles +sleep 7 + +# Check if the command is still running. +if [ ! 
-d "/proc/$COL1_PID" ]; then + checkFailed "data1 training doesn't seem to be running" 1 +fi +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner2 profile" +echo "=====================================" +print_eval medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data2 submission step" +echo "=====================================" +print_eval medperf dataset submit -p $PREP_UID -d $DIRECTORY/col2 -l $DIRECTORY/col2 --name="col2" --description="col2data" --location="col2location" -y +checkFailed "Data2 submission step failed" +DSET_2_UID=$(medperf dataset ls | grep col2 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data2 preparation step" +echo "=====================================" +print_eval medperf dataset prepare -d $DSET_2_UID +checkFailed "Data2 preparation step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data2 set_operational step" +echo "=====================================" +print_eval medperf dataset set_operational -d $DSET_2_UID -y +checkFailed "Data2 set_operational step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data2 association step" +echo "=====================================" +print_eval medperf dataset associate -d $DSET_2_UID -t $TRAINING_UID -y +checkFailed "Data2 association step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get data2owner cert" +echo "=====================================" +print_eval medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get data2owner cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting training with data2" +echo "=====================================" +print_eval medperf dataset train -d $DSET_2_UID -t $TRAINING_UID -y col2.log 2>&1 & +COL2_PID=$! + +# sleep so that the mlcube is run before we change profiles +sleep 7 + +# Check if the command is still running. +if [ ! 
-d "/proc/$COL2_PID" ]; then + checkFailed "data2 training doesn't seem to be running" 1 +fi +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate fladmin profile" +echo "=====================================" +print_eval medperf profile activate fladmin +checkFailed "fladmin profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get fladmin certificate" +echo "=====================================" +print_eval medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get fladmin cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Check experiment status" +echo "=====================================" +print_eval medperf training get_experiment_status -t $TRAINING_UID +checkFailed "Get experiment status failed" + +sleep 3 # sleep some time then get status again + +print_eval medperf training get_experiment_status -t $TRAINING_UID +checkFailed "Get experiment status failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Update plan parameter" +echo "=====================================" +print_eval medperf training update_plan -t $TRAINING_UID -f "straggler_handling_policy.settings.straggler_cutoff_time" -v 1200 +checkFailed "Update plan failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Waiting for other prcocesses to exit successfully" +echo "=====================================" +# NOTE: on systems with small process ID table or very short-lived processes, +# there is a probability that PIDs are reused and hence the +# code below may be inaccurate. Perhaps grep processes according to command +# string is the most efficient way to reduce that probability further. 
+# Followup NOTE: not sure, but the "wait" command may fail if it is waiting for +# a process that is not a child of the current shell +wait $COL1_PID +checkFailed "data1 training didn't exit successfully" +wait $COL2_PID +checkFailed "data2 training didn't exit successfully" +wait $AGG_PID +checkFailed "agg didn't exit successfully" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate aggowner profile" +echo "=====================================" +print_eval medperf profile activate testagg +checkFailed "testagg profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "close event" +echo "=====================================" +print_eval medperf training close_event -t $TRAINING_UID -y +checkFailed "close event failed" + +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Logout users" +echo "=====================================" +print_eval medperf profile activate testmodel +checkFailed "testmodel profile activation failed" + +print_eval medperf auth logout +checkFailed "logout failed" + +print_eval medperf profile activate testagg +checkFailed "testagg profile activation failed" + +print_eval medperf auth logout +checkFailed "logout failed" + +print_eval medperf profile activate testdata1 +checkFailed "testdata1 profile activation failed" + +print_eval medperf auth logout +checkFailed "logout failed" + +print_eval medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" + +print_eval medperf auth logout +checkFailed "logout failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Delete test profiles" +echo "=====================================" +print_eval medperf profile activate default +checkFailed "default profile activation failed" + +print_eval medperf profile delete testmodel +checkFailed "Profile deletion failed" + +print_eval medperf profile delete testagg +checkFailed "Profile deletion failed" + +print_eval medperf profile delete testdata1 +checkFailed "Profile deletion failed" + +print_eval medperf profile delete testdata2 +checkFailed "Profile deletion failed" +########################################################## + +if ${CLEANUP}; then + clean +fi diff --git a/cli/medperf/_version.py b/cli/medperf/_version.py index ae7362549..bbab0242f 100644 --- a/cli/medperf/_version.py +++ b/cli/medperf/_version.py @@ -1 +1 @@ -__version__ = "0.1.3" +__version__ = "0.1.4" diff --git a/cli/medperf/certificates.py b/cli/medperf/certificates.py new file mode 100644 index 000000000..5c595a378 --- /dev/null +++ b/cli/medperf/certificates.py @@ -0,0 +1,50 @@ +from medperf.entities.ca import CA +from medperf.entities.cube import Cube + + +def get_client_cert(ca: CA, email: str, output_path: str): + """Responsible for getting a user cert""" + common_name = email + ca.prepare_config() + params = { + "ca_config": ca.config_path, + "pki_assets": output_path, + } + env = {"MEDPERF_INPUT_CN": common_name} + + mlcube = Cube.get(ca.client_mlcube) + mlcube.download_run_files() + mlube_task = 
"get_client_cert" + mlcube.run(task=mlube_task, env_dict=env, **params) + + +def get_server_cert(ca: CA, address: str, output_path: str): + """Responsible for getting a server cert""" + common_name = address + ca.prepare_config() + params = { + "ca_config": ca.config_path, + "pki_assets": output_path, + } + env = {"MEDPERF_INPUT_CN": common_name} + + mlcube = Cube.get(ca.server_mlcube) + mlcube.download_run_files() + mlube_task = "get_server_cert" + mlcube.run(task=mlube_task, env_dict=env, port=80, **params) + + +def trust(ca: CA): + """Verifies the CA cert fingerprint and writes it to the MedPerf storage. + This is needed when running a workload, either by the users or + by the aggregator + """ + ca.prepare_config() + params = { + "ca_config": ca.config_path, + "pki_assets": ca.pki_assets, + } + mlcube = Cube.get(ca.ca_mlcube) + mlcube.download_run_files() + mlube_task = "trust" + mlcube.run(task=mlube_task, **params) diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index 4fc7102c4..06e2608a4 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -16,7 +16,12 @@ import medperf.commands.profile as profile import medperf.commands.association.association as association import medperf.commands.compatibility_test.compatibility_test as compatibility_test +import medperf.commands.training.training as training +import medperf.commands.aggregator.aggregator as aggregator +import medperf.commands.ca.ca as ca +import medperf.commands.certificate.certificate as certificate import medperf.commands.storage as storage + from medperf.utils import check_for_updates from medperf.logging.utils import log_machine_details @@ -30,6 +35,10 @@ app.add_typer(compatibility_test.app, name="test", help="Manage compatibility tests") app.add_typer(auth.app, name="auth", help="Authentication") app.add_typer(storage.app, name="storage", help="Storage management") +app.add_typer(training.app, name="training", help="Manage training experiments") +app.add_typer(aggregator.app, name="aggregator", help="Manage aggregators") +app.add_typer(ca.app, name="ca", help="Manage CAs") +app.add_typer(certificate.app, name="certificate", help="Manage certificates") @app.command("run") diff --git a/cli/medperf/commands/aggregator/aggregator.py b/cli/medperf/commands/aggregator/aggregator.py new file mode 100644 index 000000000..54775f627 --- /dev/null +++ b/cli/medperf/commands/aggregator/aggregator.py @@ -0,0 +1,121 @@ +from typing import Optional +from medperf.entities.aggregator import Aggregator +import typer + +import medperf.config as config +from medperf.decorators import clean_except +from medperf.commands.aggregator.submit import SubmitAggregator +from medperf.commands.aggregator.associate import AssociateAggregator +from medperf.commands.aggregator.run import StartAggregator + +from medperf.commands.list import EntityList +from medperf.commands.view import EntityView + +app = typer.Typer() + + +@app.command("submit") +@clean_except +def submit( + name: str = typer.Option(..., "--name", "-n", help="Name of the aggregator"), + address: str = typer.Option( + ..., "--address", "-a", help="Address/domain of the aggregator" + ), + port: int = typer.Option( + ..., "--port", "-p", help="The port which the aggregator will use" + ), + aggregation_mlcube: int = typer.Option( + ..., "--aggregation-mlcube", "-m", help="Aggregation MLCube UID" + ), +): + """Submits an aggregator""" + SubmitAggregator.run(name, address, port, aggregation_mlcube) + config.ui.print("✅ Done!") + + +@app.command("associate") +@clean_except +def 
associate( + aggregator_id: int = typer.Option( + ..., "--aggregator_id", "-a", help="UID of the aggregator to associate" + ), + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the training experiment to associate with" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Associates an aggregator with a training experiment""" + AssociateAggregator.run(aggregator_id, training_exp_id, approved=approval) + config.ui.print("✅ Done!") + + +@app.command("start") +@clean_except +def run( + training_exp_id: int = typer.Option( + ..., + "--training_exp_id", + "-t", + help="UID of the training experiment whose aggregator is to be run", + ), + publish_on: str = typer.Option( + "127.0.0.1", + "--publish_on", + "-p", + help="Host network interface on which the aggregator will listen", + ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite outputs if present" + ), +): + """Starts the aggregation server of a training experiment""" + StartAggregator.run(training_exp_id, publish_on, overwrite) + config.ui.print("✅ Done!") + + +@app.command("ls") +@clean_except +def list( + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered aggregators" + ), + mine: bool = typer.Option(False, "--mine", help="Get current-user aggregators"), +): + """List aggregators""" + EntityList.run( + Aggregator, + fields=["UID", "Name", "Address", "Port"], + unregistered=unregistered, + mine_only=mine, + ) + + +@app.command("view") +@clean_except +def view( + entity_id: Optional[int] = typer.Argument(None, help="Aggregator ID"), + format: str = typer.Option( + "yaml", + "-f", + "--format", + help="Format to display contents. Available formats: [yaml, json]", + ), + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered aggregators if aggregator ID is not provided", + ), + mine: bool = typer.Option( + False, + "--mine", + help="Display current-user aggregators if aggregator ID is not provided", + ), + output: str = typer.Option( + None, + "--output", + "-o", + help="Output file to store contents. If not provided, the output will be displayed", + ), +): + """Displays the information of one or more aggregators""" + EntityView.run(entity_id, Aggregator, format, unregistered, mine, output) diff --git a/cli/medperf/commands/aggregator/associate.py b/cli/medperf/commands/aggregator/associate.py new file mode 100644 index 000000000..0222c3aef --- /dev/null +++ b/cli/medperf/commands/aggregator/associate.py @@ -0,0 +1,34 @@ +from medperf import config +from medperf.entities.aggregator import Aggregator +from medperf.entities.training_exp import TrainingExp +from medperf.utils import approval_prompt +from medperf.exceptions import InvalidArgumentError + + +class AssociateAggregator: + @staticmethod + def run(agg_uid: int, training_exp_id: int, approved=False): + """Associates an aggregator with a training experiment + + Args: + agg_uid (int): UID of the registered aggregator to associate + training_exp_id (int): UID of the training experiment to associate with + """ + comms = config.comms + ui = config.ui + agg = Aggregator.get(agg_uid) + if agg.id is None: + msg = "The provided aggregator is not registered." + raise InvalidArgumentError(msg) + + training_exp = TrainingExp.get(training_exp_id) + msg = "Please confirm that you would like to associate" + msg += f" the aggregator {agg.name} with the training exp {training_exp.name}." 
+ msg += " [Y/n]" + + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating aggregator training association") + comms.associate_training_aggregator(agg.id, training_exp_id) + else: + ui.print("Aggregator association operation cancelled.") diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py new file mode 100644 index 000000000..0ba96f3a0 --- /dev/null +++ b/cli/medperf/commands/aggregator/run.py @@ -0,0 +1,107 @@ +import os +from medperf import config +from medperf.entities.ca import CA +from medperf.entities.event import TrainingEvent +from medperf.exceptions import InvalidArgumentError, MedperfException +from medperf.entities.training_exp import TrainingExp +from medperf.entities.aggregator import Aggregator +from medperf.entities.cube import Cube +from medperf.utils import get_pki_assets_path, remove_path +from medperf.certificates import trust + + +class StartAggregator: + @classmethod + def run(cls, training_exp_id: int, publish_on: str, overwrite: bool = False): + """Starts the aggregation server of a training experiment + + Args: + training_exp_id (int): Training experiment UID. + """ + execution = cls(training_exp_id, publish_on, overwrite) + execution.prepare() + execution.validate() + execution.check_existing_outputs() + execution.prepare_aggregator() + execution.prepare_participants_list() + execution.prepare_plan() + execution.prepare_pki_assets() + with config.ui.interactive(): + execution.run_experiment() + + def __init__(self, training_exp_id, publish_on, overwrite) -> None: + self.training_exp_id = training_exp_id + self.overwrite = overwrite + self.publish_on = publish_on + self.ui = config.ui + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Execution: {self.training_exp.name}") + self.event = TrainingEvent.from_experiment(self.training_exp_id) + + def validate(self): + if self.event.finished: + msg = "The provided training experiment has to start a training event." + raise InvalidArgumentError(msg) + if self.publish_on == "127.0.0.1": + pass + + def check_existing_outputs(self): + msg = ( + "Outputs still exist from previous runs. 
Overwrite" + " them by rerunning the command with --overwrite" + ) + paths = [ + self.event.agg_out_logs, + self.event.out_weights, + self.event.report_path, + ] + for path in paths: + if os.path.exists(path): + if not self.overwrite: + raise MedperfException(msg) + remove_path(path) + + def prepare_aggregator(self): + self.aggregator = Aggregator.from_experiment(self.training_exp_id) + self.cube = self.__get_cube(self.aggregator.aggregation_mlcube, "aggregation") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = f"Retrieving {name} cube" + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def prepare_participants_list(self): + self.event.prepare_participants_list() + + def prepare_plan(self): + self.training_exp.prepare_plan() + + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + agg_address = self.aggregator.address + self.aggregator_pki_assets = get_pki_assets_path(agg_address, ca.name) + self.ca = ca + + def run_experiment(self): + params = { + "node_cert_folder": self.aggregator_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "collaborators": self.event.participants_list_path, + "output_logs": self.event.agg_out_logs, + "output_weights": self.event.out_weights, + "report_path": self.event.report_path, + } + + self.ui.text = "Running Aggregator" + self.cube.run( + task="start_aggregator", + port=self.aggregator.port, + publish_on=self.publish_on, + **params, + ) diff --git a/cli/medperf/commands/aggregator/submit.py b/cli/medperf/commands/aggregator/submit.py new file mode 100644 index 000000000..3fd653fcb --- /dev/null +++ b/cli/medperf/commands/aggregator/submit.py @@ -0,0 +1,45 @@ +import medperf.config as config +from medperf.entities.aggregator import Aggregator +from medperf.utils import remove_path +from medperf.entities.cube import Cube + + +class SubmitAggregator: + @classmethod + def run(cls, name: str, address: str, port: int, aggregation_mlcube: int): + """Submits a new aggregator to the medperf platform + Args: + name (str): aggregator name + address (str): aggregator address/domain + port (int): port which the aggregator will use + aggregation_mlcube (int): aggregation mlcube uid + """ + ui = config.ui + submission = cls(name, address, port, aggregation_mlcube) + + with ui.interactive(): + ui.text = "Submitting Aggregator to MedPerf" + submission.validate_agg_cube() + updated_benchmark_body = submission.submit() + ui.print("Uploaded") + submission.write(updated_benchmark_body) + + def __init__(self, name: str, address: str, port: int, aggregation_mlcube: int): + self.ui = config.ui + agg_config = {"address": address, "port": port} + self.aggregator = Aggregator( + name=name, config=agg_config, aggregation_mlcube=aggregation_mlcube + ) + config.tmp_paths.append(self.aggregator.path) + + def validate_agg_cube(self): + Cube.get(self.aggregator.aggregation_mlcube) + + def submit(self): + updated_body = self.aggregator.upload() + return updated_body + + def write(self, updated_body): + remove_path(self.aggregator.path) + aggregator = Aggregator(**updated_body) + aggregator.write() diff --git a/cli/medperf/commands/association/approval.py b/cli/medperf/commands/association/approval.py index 4ed343911..ec7fe7999 100644 --- a/cli/medperf/commands/association/approval.py +++ b/cli/medperf/commands/association/approval.py @@ -1,14 +1,17 @@ from medperf import config -from medperf.exceptions import 
InvalidArgumentError +from medperf.commands.association.utils import validate_args class Approval: @staticmethod def run( - benchmark_uid: int, approval_status: str, + benchmark_uid: int = None, + training_exp_uid: int = None, dataset_uid: int = None, mlcube_uid: int = None, + aggregator_uid: int = None, + ca_uid: int = None, ): """Sets approval status for an association between a benchmark and a dataset or mlcube @@ -21,17 +24,34 @@ def run( mlcube_uid (int, optional): MLCube UID. Defaults to None. """ comms = config.comms - too_many_resources = dataset_uid and mlcube_uid - no_resource = dataset_uid is None and mlcube_uid is None - if no_resource or too_many_resources: - raise InvalidArgumentError("Must provide either a dataset or mlcube") + validate_args( + benchmark_uid, + training_exp_uid, + dataset_uid, + mlcube_uid, + aggregator_uid, + ca_uid, + approval_status.value, + ) + update = {"approval_status": approval_status.value} + if benchmark_uid: + if dataset_uid: + comms.update_benchmark_dataset_association( + benchmark_uid, dataset_uid, update + ) - if dataset_uid: - comms.set_dataset_association_approval( - benchmark_uid, dataset_uid, approval_status.value - ) - - if mlcube_uid: - comms.set_mlcube_association_approval( - benchmark_uid, mlcube_uid, approval_status.value - ) + if mlcube_uid: + comms.update_benchmark_model_association( + benchmark_uid, mlcube_uid, update + ) + if training_exp_uid: + if dataset_uid: + comms.update_training_dataset_association( + training_exp_uid, dataset_uid, update + ) + if aggregator_uid: + comms.update_training_aggregator_association( + training_exp_uid, aggregator_uid, update + ) + if ca_uid: + comms.update_training_ca_association(training_exp_uid, ca_uid, update) diff --git a/cli/medperf/commands/association/association.py b/cli/medperf/commands/association/association.py index fa69682ed..b97255c72 100644 --- a/cli/medperf/commands/association/association.py +++ b/cli/medperf/commands/association/association.py @@ -1,5 +1,4 @@ import typer -from typing import Optional import medperf.config as config from medperf.decorators import clean_except @@ -13,22 +12,47 @@ @app.command("ls") @clean_except -def list(filter: Optional[str] = typer.Argument(None)): +def list( + benchmark: bool = typer.Option(False, "-b", help="list benchmark associations"), + training_exp: bool = typer.Option(False, "-t", help="list training associations"), + dataset: bool = typer.Option(False, "-d", help="list dataset associations"), + mlcube: bool = typer.Option(False, "-m", help="list mlcube associations"), + aggregator: bool = typer.Option(False, "-a", help="list aggregator associations"), + ca: bool = typer.Option(False, "-c", help="list ca associations"), + approval_status: str = typer.Option( + None, "--approval-status", help="Approval status" + ), +): """Display all associations related to the current user. Args: filter (str, optional): Filter associations by approval status. Defaults to displaying all user associations. 
""" - ListAssociations.run(filter) + ListAssociations.run( + benchmark, + training_exp, + dataset, + mlcube, + aggregator, + ca, + approval_status, + ) @app.command("approve") @clean_except def approve( - benchmark_uid: int = typer.Option(..., "--benchmark", "-b", help="Benchmark UID"), + benchmark_uid: int = typer.Option(None, "--benchmark", "-b", help="Benchmark UID"), + training_exp_uid: int = typer.Option( + None, "--training_exp", "-t", help="Training exp UID" + ), dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"), mlcube_uid: int = typer.Option(None, "--mlcube", "-m", help="MLCube UID"), + aggregator_uid: int = typer.Option( + None, "--aggregator", "-a", help="Aggregator UID" + ), + ca_uid: int = typer.Option(None, "--ca", "-c", help="CA UID"), ): """Approves an association between a benchmark and a dataset or model mlcube @@ -37,16 +61,31 @@ def approve( dataset_uid (int, optional): Dataset UID. mlcube_uid (int, optional): Model MLCube UID. """ - Approval.run(benchmark_uid, Status.APPROVED, dataset_uid, mlcube_uid) + Approval.run( + Status.APPROVED, + benchmark_uid, + training_exp_uid, + dataset_uid, + mlcube_uid, + aggregator_uid, + ca_uid, + ) config.ui.print("✅ Done!") @app.command("reject") @clean_except def reject( - benchmark_uid: int = typer.Option(..., "--benchmark", "-b", help="Benchmark UID"), + benchmark_uid: int = typer.Option(None, "--benchmark", "-b", help="Benchmark UID"), + training_exp_uid: int = typer.Option( + None, "--training_exp", "-t", help="Training exp UID" + ), dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"), mlcube_uid: int = typer.Option(None, "--mlcube", "-m", help="MLCube UID"), + aggregator_uid: int = typer.Option( + None, "--aggregator", "-a", help="Aggregator UID" + ), + ca_uid: int = typer.Option(None, "--ca", "-c", help="CA UID"), ): """Rejects an association between a benchmark and a dataset or model mlcube @@ -55,7 +94,15 @@ def reject( dataset_uid (int, optional): Dataset UID. mlcube_uid (int, optional): Model MLCube UID. 
""" - Approval.run(benchmark_uid, Status.REJECTED, dataset_uid, mlcube_uid) + Approval.run( + Status.REJECTED, + benchmark_uid, + training_exp_uid, + dataset_uid, + mlcube_uid, + aggregator_uid, + ca_uid, + ) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/association/list.py b/cli/medperf/commands/association/list.py index e210fbc26..fbe77539a 100644 --- a/cli/medperf/commands/association/list.py +++ b/cli/medperf/commands/association/list.py @@ -1,47 +1,55 @@ from tabulate import tabulate from medperf import config +from medperf.commands.association.utils import validate_args, get_associations_list class ListAssociations: @staticmethod - def run(filter: str = None): - """Get Pending association requests""" - comms = config.comms - ui = config.ui - dset_assocs = comms.get_datasets_associations() - cube_assocs = comms.get_cubes_associations() + def run( + benchmark, + training_exp, + dataset, + mlcube, + aggregator, + ca, + approval_status, + ): + """Get user association requests""" + validate_args( + benchmark, training_exp, dataset, mlcube, aggregator, ca, approval_status + ) + if training_exp: + experiment_key = "training_exp" + elif benchmark: + experiment_key = "benchmark" - # Might be worth seeing if creating an association class that encapsulates - # most of the logic here is useful - assocs = dset_assocs + cube_assocs - if filter: - filter = filter.upper() - assocs = [assoc for assoc in assocs if assoc["approval_status"] == filter] + if mlcube: + component_key = "model_mlcube" + elif dataset: + component_key = "dataset" + elif aggregator: + component_key = "aggregator" + elif ca: + component_key = "ca" + + assocs = get_associations_list(experiment_key, component_key, approval_status) assocs_info = [] for assoc in assocs: assoc_info = ( - assoc.get("dataset", None), - assoc.get("model_mlcube", None), - assoc["benchmark"], + assoc[component_key], + assoc[experiment_key], assoc["initiated_by"], assoc["approval_status"], - assoc.get("priority", None), - # NOTE: We should find a better way to show priorities, since a priority - # is better shown when listing cube associations only, of a specific - # benchmark. Maybe this is resolved after we add a general filtering - # feature to list commands. 
) assocs_info.append(assoc_info) headers = [ - "Dataset UID", - "MLCube UID", - "Benchmark UID", + f"{component_key.replace('_', ' ').title()} UID", + f"{experiment_key.replace('_', ' ').title()} UID", "Initiated by", "Status", - "Priority", ] tab = tabulate(assocs_info, headers=headers) - ui.print(tab) + config.ui.print(tab) diff --git a/cli/medperf/commands/association/priority.py b/cli/medperf/commands/association/priority.py index c58db2450..760b0f4c2 100644 --- a/cli/medperf/commands/association/priority.py +++ b/cli/medperf/commands/association/priority.py @@ -19,6 +19,6 @@ def run(benchmark_uid: int, mlcube_uid: int, priority: int): raise InvalidArgumentError( "The given mlcube doesn't exist or is not associated with the benchmark" ) - config.comms.set_mlcube_association_priority( - benchmark_uid, mlcube_uid, priority + config.comms.update_benchmark_model_association( + benchmark_uid, mlcube_uid, {"priority": priority} ) diff --git a/cli/medperf/commands/association/utils.py b/cli/medperf/commands/association/utils.py new file mode 100644 index 000000000..78ff742d7 --- /dev/null +++ b/cli/medperf/commands/association/utils.py @@ -0,0 +1,125 @@ +from medperf.exceptions import InvalidArgumentError +from medperf import config +from pydantic.datetime_parse import parse_datetime + + +def validate_args( + benchmark, training_exp, dataset, model_mlcube, aggregator, ca, approval_status +): + training_exp = bool(training_exp) + benchmark = bool(benchmark) + dataset = bool(dataset) + model_mlcube = bool(model_mlcube) + aggregator = bool(aggregator) + ca = bool(ca) + + if approval_status is not None: + if approval_status.lower() not in ["pending", "approved", "rejected"]: + raise InvalidArgumentError( + "If provided, approval status must be one of pending, approved, or rejected" + ) + if sum([benchmark, training_exp]) != 1: + raise InvalidArgumentError( + "One training experiment or a benchmark flag must be provided" + ) + if sum([dataset, model_mlcube, aggregator, ca]) != 1: + raise InvalidArgumentError( + "One dataset, mlcube, aggregator, or ca flag must be provided" + ) + if training_exp and model_mlcube: + raise InvalidArgumentError( + "Invalid combination of arguments. There are no associations" + " between training experiments and models mlcubes" + ) + if benchmark and (ca or aggregator): + raise InvalidArgumentError( + "Invalid combination of arguments. There are no associations" + " between benchmarks and CAs or aggregators" + ) + + +def filter_latest_associations(associations, experiment_key, component_key): + """Given a list of component-experiment associations, this function + retrieves a list containing the latest association of each + experiment-component instance. + + Args: + associations (list[dict]): the list of associations + experiment_key (str): experiment identifier field in the association + component_key (str): component identifier field in the association + + Returns: + list[dict]: the list containing the latest association of each + entity instance. 
+ """ + + associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"])) + latest_associations = {} + for assoc in associations: + component_id = assoc[component_key] + experiment_id = assoc[experiment_key] + latest_associations[(component_id, experiment_id)] = assoc + + latest_associations = list(latest_associations.values()) + return latest_associations + + +def get_last_component(associations, experiment_key): + associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"])) + experiments_component = {} + for assoc in associations: + experiment_id = assoc[experiment_key] + experiments_component[experiment_id] = assoc + + experiments_component = list(experiments_component.values()) + return experiments_component + + +def get_associations_list( + experiment_key: str, + component_key: str, + approval_status: str = None, + experiment_id: int = None, +): + comms_functions = { + "training_exp": { + "dataset": { + "user": config.comms.get_user_training_datasets_associations, + "experiment": config.comms.get_training_datasets_associations, + }, + "aggregator": { + "user": config.comms.get_user_training_aggregators_associations, + }, + "ca": { + "user": config.comms.get_user_training_cas_associations, + }, + }, + "benchmark": { + "dataset": { + "user": config.comms.get_user_benchmarks_datasets_associations, + }, + "mode_mlcube": { + "user": config.comms.get_user_benchmarks_models_associations, + "experiment": config.comms.get_benchmark_models_associations, + }, + }, + } + if experiment_id: + comms_func = comms_functions[experiment_key][component_key]["experiment"] + assocs = comms_func(experiment_id) + else: + comms_func = comms_functions[experiment_key][component_key]["user"] + assocs = comms_func() + + assocs = filter_latest_associations(assocs, experiment_key, component_key) + if component_key in ["aggregator", "ca"]: + # an experiment should only have one aggregator and/or one CA + assocs = get_last_component(assocs, experiment_key) + + if approval_status: + approval_status = approval_status.upper() + assocs = [ + assoc for assoc in assocs if assoc["approval_status"] == approval_status + ] + + return assocs diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index f02d67cb4..35d719b0d 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -16,14 +16,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local benchmarks"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered benchmarks" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), ): - """List benchmarks stored locally and remotely from the user""" + """List benchmarks""" EntityList.run( Benchmark, fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -162,10 +164,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( + unregistered: bool = typer.Option( False, - "--local", - help="Display local benchmarks if benchmark ID is not provided", + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", ), mine: bool = typer.Option( False, @@ -180,4 +182,4 @@ def view( ), ): """Displays the information of one or more benchmarks""" - EntityView.run(entity_id, Benchmark, format, local, mine, output) + EntityView.run(entity_id, Benchmark, format, unregistered, mine, output) diff --git a/cli/medperf/commands/ca/associate.py b/cli/medperf/commands/ca/associate.py new file mode 100644 index 000000000..9bf439b6b --- /dev/null +++ b/cli/medperf/commands/ca/associate.py @@ -0,0 +1,34 @@ +from medperf import config +from medperf.entities.ca import CA +from medperf.entities.training_exp import TrainingExp +from medperf.utils import approval_prompt +from medperf.exceptions import InvalidArgumentError + + +class AssociateCA: + @staticmethod + def run(training_exp_id: int, ca_uid: int, approved=False): + """Associates an ca with a training experiment + + Args: + ca_uid (int): UID of the registered ca to associate + benchmark_uid (int): UID of the benchmark to associate with + """ + comms = config.comms + ui = config.ui + ca = CA.get(ca_uid) + if ca.id is None: + msg = "The provided ca is not registered." + raise InvalidArgumentError(msg) + + training_exp = TrainingExp.get(training_exp_id) + msg = "Please confirm that you would like to associate" + msg += f" the ca {ca.name} with the training exp {training_exp.name}." + msg += " [Y/n]" + + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating ca training association") + comms.associate_training_ca(ca.id, training_exp_id) + else: + ui.print("CA association operation cancelled.") diff --git a/cli/medperf/commands/ca/ca.py b/cli/medperf/commands/ca/ca.py new file mode 100644 index 000000000..580f00822 --- /dev/null +++ b/cli/medperf/commands/ca/ca.py @@ -0,0 +1,102 @@ +from typing import Optional +from medperf.entities.ca import CA +import typer + +import medperf.config as config +from medperf.decorators import clean_except +from medperf.commands.ca.submit import SubmitCA +from medperf.commands.ca.associate import AssociateCA + +from medperf.commands.list import EntityList +from medperf.commands.view import EntityView + +app = typer.Typer() + + +@app.command("submit") +@clean_except +def submit( + name: str = typer.Option(..., "--name", "-n", help="Name of the ca"), + config_path: str = typer.Option( + ..., + "--config-path", + "-c", + help="Path to the configuration file (JSON) of the CA", + ), + ca_mlcube: int = typer.Option(..., "--ca-mlcube", help="CA MLCube UID"), + client_mlcube: int = typer.Option( + ..., + "--client-mlcube", + help="MLCube UID to be used by clients to get a cert", + ), + server_mlcube: int = typer.Option( + ..., + "--server-mlcube", + help="MLCube UID to be used by servers to get a cert", + ), +): + """Submits a ca""" + SubmitCA.run(name, config_path, ca_mlcube, client_mlcube, server_mlcube) + config.ui.print("✅ Done!") + + +@app.command("associate") +@clean_except +def associate( + ca_id: int = typer.Option(..., "--ca_id", "-c", help="UID of CA to associate with"), + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of training exp to associate with" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Associates a CA with a training experiment.""" + 
AssociateCA.run(ca_id, training_exp_id, approved=approval) + config.ui.print("✅ Done!") + + +@app.command("ls") +@clean_except +def list( + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered CAs" + ), + mine: bool = typer.Option(False, "--mine", help="Get current-user CAs"), +): + """List CAs""" + EntityList.run( + CA, + fields=["UID", "Name", "Address", "Port"], + unregistered=unregistered, + mine_only=mine, + ) + + +@app.command("view") +@clean_except +def view( + entity_id: Optional[int] = typer.Argument(None, help="Benchmark ID"), + format: str = typer.Option( + "yaml", + "-f", + "--format", + help="Format to display contents. Available formats: [yaml, json]", + ), + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", + ), + mine: bool = typer.Option( + False, + "--mine", + help="Display current-user benchmarks if benchmark ID is not provided", + ), + output: str = typer.Option( + None, + "--output", + "-o", + help="Output file to store contents. If not provided, the output will be displayed", + ), +): + """Displays the information of one or more CAs""" + EntityView.run(entity_id, CA, format, unregistered, mine, output) diff --git a/cli/medperf/commands/ca/submit.py b/cli/medperf/commands/ca/submit.py new file mode 100644 index 000000000..cbb3f0348 --- /dev/null +++ b/cli/medperf/commands/ca/submit.py @@ -0,0 +1,65 @@ +import medperf.config as config +from medperf.entities.ca import CA +from medperf.utils import remove_path +from medperf.entities.cube import Cube + + +class SubmitCA: + @classmethod + def run( + cls, + name: str, + config_path: str, + ca_mlcube: int, + client_mlcube: int, + server_mlcube: int, + ): + """Submits a new ca to the medperf platform + Args: + name (str): ca name + config_path (dict): ca config + ca_mlcube (int): ca_mlcube mlcube uid + client_mlcube (int): client_mlcube mlcube uid + server_mlcube (int): server_mlcube mlcube uid + """ + ui = config.ui + submission = cls(name, config_path, ca_mlcube, client_mlcube, server_mlcube) + + with ui.interactive(): + ui.text = "Submitting CA to MedPerf" + submission.validate_ca_cubes() + updated_benchmark_body = submission.submit() + ui.print("Uploaded") + submission.write(updated_benchmark_body) + + def __init__( + self, + name: str, + config_path: str, + ca_mlcube: int, + client_mlcube: int, + server_mlcube: int, + ): + self.ui = config.ui + self.ca = CA( + name=name, + config=config_path, + ca_mlcube=ca_mlcube, + client_mlcube=client_mlcube, + server_mlcube=server_mlcube, + ) + config.tmp_paths.append(self.ca.path) + + def validate_ca_cubes(self): + Cube.get(self.ca.ca_mlcube) + Cube.get(self.ca.client_mlcube) + Cube.get(self.ca.server_mlcube) + + def submit(self): + updated_body = self.ca.upload() + return updated_body + + def write(self, updated_body): + remove_path(self.ca.path) + ca = CA(**updated_body) + ca.write() diff --git a/cli/medperf/commands/certificate/certificate.py b/cli/medperf/commands/certificate/certificate.py new file mode 100644 index 000000000..7125fbcc8 --- /dev/null +++ b/cli/medperf/commands/certificate/certificate.py @@ -0,0 +1,44 @@ +import typer + +import medperf.config as config +from medperf.decorators import clean_except +from medperf.commands.certificate.client_certificate import GetUserCertificate +from medperf.commands.certificate.server_certificate import GetServerCertificate + +app = typer.Typer() + + +@app.command("get_client_certificate") +@clean_except +def 
get_client_certificate( + training_exp_id: int = typer.Option( + ..., + "--training_exp_id", + "-t", + help="UID of training exp which you intend to be a part of", + ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite cert and key if present" + ), +): + """get a client certificate""" + GetUserCertificate.run(training_exp_id, overwrite) + config.ui.print("✅ Done!") + + +@app.command("get_server_certificate") +@clean_except +def get_server_certificate( + training_exp_id: int = typer.Option( + ..., + "--training_exp_id", + "-t", + help="UID of training exp which you intend to be a part of", + ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite cert and key if present" + ), +): + """get a server certificate""" + GetServerCertificate.run(training_exp_id, overwrite) + config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/certificate/client_certificate.py b/cli/medperf/commands/certificate/client_certificate.py new file mode 100644 index 000000000..8eace42c8 --- /dev/null +++ b/cli/medperf/commands/certificate/client_certificate.py @@ -0,0 +1,22 @@ +from medperf.entities.ca import CA +from medperf.account_management import get_medperf_user_data +from medperf.certificates import get_client_cert +from medperf.exceptions import MedperfException +from medperf.utils import get_pki_assets_path, remove_path +import os + + +class GetUserCertificate: + @staticmethod + def run(training_exp_id: int, overwrite: bool = False): + """get user cert""" + ca = CA.from_experiment(training_exp_id) + email = get_medperf_user_data()["email"] + output_path = get_pki_assets_path(email, ca.name) + if os.path.exists(output_path): + if not overwrite: + raise MedperfException( + "Cert and key already present. Rerun the command with --overwrite" + ) + remove_path(output_path) + get_client_cert(ca, email, output_path) diff --git a/cli/medperf/commands/certificate/server_certificate.py b/cli/medperf/commands/certificate/server_certificate.py new file mode 100644 index 000000000..1e7a25db8 --- /dev/null +++ b/cli/medperf/commands/certificate/server_certificate.py @@ -0,0 +1,23 @@ +from medperf.entities.ca import CA +from medperf.entities.aggregator import Aggregator +from medperf.certificates import get_server_cert +from medperf.exceptions import MedperfException +from medperf.utils import get_pki_assets_path, remove_path +import os + + +class GetServerCertificate: + @staticmethod + def run(training_exp_id: int, overwrite: bool = False): + """get server cert""" + ca = CA.from_experiment(training_exp_id) + aggregator = Aggregator.from_experiment(training_exp_id) + address = aggregator.address + output_path = get_pki_assets_path(address, ca.name) + if os.path.exists(output_path): + if not overwrite: + raise MedperfException( + "Cert and key already present. 
Rerun the command with --overwrite" + ) + remove_path(output_path) + get_server_cert(ca, address, output_path) diff --git a/cli/medperf/commands/compatibility_test/compatibility_test.py b/cli/medperf/commands/compatibility_test/compatibility_test.py index a3b25ac78..0bd4a4695 100644 --- a/cli/medperf/commands/compatibility_test/compatibility_test.py +++ b/cli/medperf/commands/compatibility_test/compatibility_test.py @@ -95,7 +95,11 @@ def run( @clean_except def list(): """List previously executed tests reports.""" - EntityList.run(TestReport, fields=["UID", "Data Source", "Model", "Evaluator"]) + EntityList.run( + TestReport, + fields=["UID", "Data Source", "Model", "Evaluator"], + unregistered=True, + ) @app.command("view") @@ -116,4 +120,4 @@ def view( ), ): """Displays the information of one or more test reports""" - EntityView.run(entity_id, TestReport, format, output=output) + EntityView.run(entity_id, TestReport, format, unregistered=True, output=output) diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index a12ac5ea2..c56a57d41 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -138,23 +138,23 @@ def create_test_dataset( # TODO: existing dataset could make problems # make some changes since this is a test dataset config.tmp_paths.remove(data_creation.dataset.path) - data_creation.dataset.write() if skip_data_preparation_step: data_creation.make_dataset_prepared() dataset = data_creation.dataset + old_generated_uid = dataset.generated_uid + old_path = dataset.path # prepare/check dataset DataPreparation.run(dataset.generated_uid) # update dataset generated_uid - old_path = dataset.path - generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) - dataset.generated_uid = generated_uid - dataset.write() - if dataset.input_data_hash != dataset.generated_uid: + new_generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) + if new_generated_uid != old_generated_uid: # move to a correct location if it underwent preparation - new_path = old_path.replace(dataset.input_data_hash, generated_uid) + new_path = old_path.replace(old_generated_uid, new_generated_uid) remove_path(new_path) os.rename(old_path, new_path) + dataset.generated_uid = new_generated_uid + dataset.write() - return generated_uid + return new_generated_uid diff --git a/cli/medperf/commands/dataset/associate.py b/cli/medperf/commands/dataset/associate.py index 84359fd1d..e338bc831 100644 --- a/cli/medperf/commands/dataset/associate.py +++ b/cli/medperf/commands/dataset/associate.py @@ -1,52 +1,31 @@ -from medperf import config -from medperf.entities.dataset import Dataset -from medperf.entities.benchmark import Benchmark -from medperf.utils import dict_pretty_print, approval_prompt -from medperf.commands.result.create import BenchmarkExecution +from medperf.commands.dataset.associate_benchmark import AssociateBenchmarkDataset +from medperf.commands.dataset.associate_training import AssociateTrainingDataset from medperf.exceptions import InvalidArgumentError class AssociateDataset: @staticmethod - def run(data_uid: int, benchmark_uid: int, approved=False, no_cache=False): - """Associates a registered dataset with a benchmark - - Args: - data_uid (int): UID of the registered dataset to associate - benchmark_uid (int): UID of the benchmark to associate with - """ - comms = config.comms - ui = config.ui - dset = Dataset.get(data_uid) - if dset.id is None: - msg = 
"The provided dataset is not registered." - raise InvalidArgumentError(msg) - - benchmark = Benchmark.get(benchmark_uid) - - if dset.data_preparation_mlcube != benchmark.data_preparation_mlcube: + def run( + data_uid: int, + benchmark_uid: int = None, + training_exp_uid: int = None, + approved=False, + no_cache=False, + ): + """Associates a dataset with a benchmark or a training exp""" + too_many_resources = benchmark_uid and training_exp_uid + no_resource = benchmark_uid is None and training_exp_uid is None + if no_resource or too_many_resources: raise InvalidArgumentError( - "The specified dataset wasn't prepared for this benchmark" + "Must provide either a benchmark or a training experiment" ) - - result = BenchmarkExecution.run( - benchmark_uid, - data_uid, - [benchmark.reference_model_mlcube], - no_cache=no_cache, - )[0] - ui.print("These are the results generated by the compatibility test. ") - ui.print("This will be sent along the association request.") - ui.print("They will not be part of the benchmark.") - dict_pretty_print(result.results) - - msg = "Please confirm that you would like to associate" - msg += f" the dataset {dset.name} with the benchmark {benchmark.name}." - msg += " [Y/n]" - approved = approved or approval_prompt(msg) - if approved: - ui.print("Generating dataset benchmark association") - metadata = {"test_result": result.results} - comms.associate_dset(dset.id, benchmark_uid, metadata) - else: - ui.print("Dataset association operation cancelled.") + if benchmark_uid: + AssociateBenchmarkDataset.run( + data_uid, benchmark_uid, approved=approved, no_cache=no_cache + ) + if training_exp_uid: + if no_cache: + raise InvalidArgumentError( + "no_cache argument is only valid when associating with a benchmark" + ) + AssociateTrainingDataset.run(data_uid, training_exp_uid, approved=approved) diff --git a/cli/medperf/commands/dataset/associate_benchmark.py b/cli/medperf/commands/dataset/associate_benchmark.py new file mode 100644 index 000000000..9b937c36d --- /dev/null +++ b/cli/medperf/commands/dataset/associate_benchmark.py @@ -0,0 +1,52 @@ +from medperf import config +from medperf.entities.dataset import Dataset +from medperf.entities.benchmark import Benchmark +from medperf.utils import dict_pretty_print, approval_prompt +from medperf.commands.result.create import BenchmarkExecution +from medperf.exceptions import InvalidArgumentError + + +class AssociateBenchmarkDataset: + @staticmethod + def run(data_uid: int, benchmark_uid: int, approved=False, no_cache=False): + """Associates a registered dataset with a benchmark + + Args: + data_uid (int): UID of the registered dataset to associate + benchmark_uid (int): UID of the benchmark to associate with + """ + comms = config.comms + ui = config.ui + dset = Dataset.get(data_uid) + if dset.id is None: + msg = "The provided dataset is not registered." + raise InvalidArgumentError(msg) + + benchmark = Benchmark.get(benchmark_uid) + + if dset.data_preparation_mlcube != benchmark.data_preparation_mlcube: + raise InvalidArgumentError( + "The specified dataset wasn't prepared for this benchmark" + ) + + result = BenchmarkExecution.run( + benchmark_uid, + data_uid, + [benchmark.reference_model_mlcube], + no_cache=no_cache, + )[0] + ui.print("These are the results generated by the compatibility test. 
") + ui.print("This will be sent along the association request.") + ui.print("They will not be part of the benchmark.") + dict_pretty_print(result.results) + + msg = "Please confirm that you would like to associate" + msg += f" the dataset {dset.name} with the benchmark {benchmark.name}." + msg += " [Y/n]" + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating dataset benchmark association") + metadata = {"test_result": result.results} + comms.associate_benchmark_dataset(dset.id, benchmark_uid, metadata) + else: + ui.print("Dataset association operation cancelled.") diff --git a/cli/medperf/commands/dataset/associate_training.py b/cli/medperf/commands/dataset/associate_training.py new file mode 100644 index 000000000..7a3565089 --- /dev/null +++ b/cli/medperf/commands/dataset/associate_training.py @@ -0,0 +1,39 @@ +from medperf import config +from medperf.entities.dataset import Dataset +from medperf.entities.training_exp import TrainingExp +from medperf.utils import approval_prompt +from medperf.exceptions import InvalidArgumentError + + +class AssociateTrainingDataset: + @staticmethod + def run(data_uid: int, training_exp_uid: int, approved=False): + """Associates a dataset with a training experiment + + Args: + data_uid (int): UID of the registered dataset to associate + training_exp_uid (int): UID of the training experiment to associate with + """ + comms = config.comms + ui = config.ui + dset: Dataset = Dataset.get(data_uid) + if dset.id is None: + msg = "The provided dataset is not registered." + raise InvalidArgumentError(msg) + + training_exp = TrainingExp.get(training_exp_uid) + + if dset.data_preparation_mlcube != training_exp.data_preparation_mlcube: + raise InvalidArgumentError( + "The specified dataset wasn't prepared for this experiment" + ) + + msg = "Please confirm that you would like to associate" + msg += f" the dataset {dset.name} with the training experiment {training_exp.name}." 
+ msg += " [Y/n]" + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating dataset training experiment association") + comms.associate_training_dataset(dset.id, training_exp_uid) + else: + ui.print("Dataset association operation cancelled.") diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index a27e36814..9a3ab2f3a 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -10,6 +10,7 @@ from medperf.commands.dataset.prepare import DataPreparation from medperf.commands.dataset.set_operational import DatasetSetOperational from medperf.commands.dataset.associate import AssociateDataset +from medperf.commands.dataset.train import TrainingExecution app = typer.Typer() @@ -17,17 +18,19 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local datasets"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered datasets" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"), mlcube: int = typer.Option( None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube" ), ): - """List datasets stored locally and remotely from the user""" + """List datasets""" EntityList.run( Dataset, fields=["UID", "Name", "Data Preparation Cube UID", "State", "Status", "Owner"], - local_only=local, + unregistered=unregistered, mine_only=mine, mlcube=mlcube, ) @@ -122,23 +125,52 @@ def associate( ..., "--data_uid", "-d", help="Registered Dataset UID" ), benchmark_uid: int = typer.Option( - ..., "--benchmark_uid", "-b", help="Benchmark UID" + None, "--benchmark_uid", "-b", help="Benchmark UID" + ), + training_exp_uid: int = typer.Option( + None, "--training_exp_uid", "-t", help="Training experiment UID" ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), no_cache: bool = typer.Option( False, "--no-cache", - help="Execute the test even if results already exist", + help="Execute the benchmark association test even if results already exist", ), ): - """Associate a registered dataset with a specific benchmark. - The dataset and benchmark must share the same data preparation cube. - """ + """Associate a registered dataset with a specific benchmark or experiment.""" ui = config.ui - AssociateDataset.run(data_uid, benchmark_uid, approved=approval, no_cache=no_cache) + AssociateDataset.run( + data_uid, benchmark_uid, training_exp_uid, approved=approval, no_cache=no_cache + ) ui.print("✅ Done!") +@app.command("train") +@clean_except +def train( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + data_uid: int = typer.Option( + ..., "--data_uid", "-d", help="Registered Dataset UID" + ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite outputs if present" + ), + restart_on_failure: bool = typer.Option( + False, + "--restart_on_failure", + help="Keep restarting failing training processes until Keyboard interrupt", + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Runs training""" + TrainingExecution.run( + training_exp_id, data_uid, overwrite, approval, restart_on_failure + ) + config.ui.print("✅ Done!") + + @app.command("view") @clean_except def view( @@ -149,8 +181,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local datasets if dataset ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered datasets if dataset ID is not provided", ), mine: bool = typer.Option( False, @@ -165,4 +199,4 @@ def view( ), ): """Displays the information of one or more datasets""" - EntityView.run(entity_id, Dataset, format, local, mine, output) + EntityView.run(entity_id, Dataset, format, unregistered, mine, output) diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py new file mode 100644 index 000000000..08105f26c --- /dev/null +++ b/cli/medperf/commands/dataset/train.py @@ -0,0 +1,162 @@ +import os +from medperf import config +from medperf.account_management.account_management import get_medperf_user_data +from medperf.entities.ca import CA +from medperf.entities.event import TrainingEvent +from medperf.exceptions import ( + CleanExit, + ExecutionError, + InvalidArgumentError, + MedperfException, +) +from medperf.entities.training_exp import TrainingExp +from medperf.entities.dataset import Dataset +from medperf.entities.cube import Cube +from medperf.utils import ( + approval_prompt, + dict_pretty_print, + get_pki_assets_path, + get_participant_label, + remove_path, +) +from medperf.certificates import trust + + +class TrainingExecution: + @classmethod + def run( + cls, + training_exp_id: int, + data_uid: int, + overwrite: bool = False, + approved: bool = False, + restart_on_failure: bool = False, + ): + """Starts the aggregation server of a training experiment + + Args: + training_exp_id (int): Training experiment UID. + """ + if restart_on_failure: + approved = True + overwrite = True + execution = cls(training_exp_id, data_uid, overwrite, approved) + if restart_on_failure: + execution.confirm_restart_on_failure() + + while True: + execution.prepare() + execution.validate() + execution.check_existing_outputs() + execution.prepare_plan() + execution.prepare_pki_assets() + execution.confirm_run() + with config.ui.interactive(): + execution.prepare_training_cube() + try: + execution.run_experiment() + break + except ExecutionError as e: + print(str(e)) + if not restart_on_failure: + break + + def __init__( + self, training_exp_id: int, data_uid: int, overwrite: bool, approved: bool + ) -> None: + self.training_exp_id = training_exp_id + self.data_uid = data_uid + self.overwrite = overwrite + self.ui = config.ui + self.approved = approved + + def confirm_restart_on_failure(self): + msg = ( + "You chose to restart on failure. This means that the training process" + " will automatically restart, without your approval, even if training configuration" + " changes from the server side. Do you confirm? [Y/n] " + ) + if not approval_prompt(msg): + raise CleanExit( + "Training cancelled. Rerun without the --restart_on_failure flag." + ) + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Execution: {self.training_exp.name}") + self.event = TrainingEvent.from_experiment(self.training_exp_id) + self.dataset = Dataset.get(self.data_uid) + self.user_email: str = get_medperf_user_data()["email"] + self.out_logs = os.path.join(self.event.col_out_logs, str(self.dataset.id)) + + def validate(self): + if self.dataset.id is None: + msg = "The provided dataset is not registered." 
+ raise InvalidArgumentError(msg) + + if self.dataset.state != "OPERATION": + msg = "The provided dataset is not operational." + raise InvalidArgumentError(msg) + + if self.event.finished: + msg = "The provided training experiment has to start a training event." + raise InvalidArgumentError(msg) + + def check_existing_outputs(self): + msg = ( + "Outputs still exist from previous runs. Overwrite" + " them by rerunning the command with --overwrite" + ) + paths = [self.out_logs] + for path in paths: + if os.path.exists(path): + if not self.overwrite: + raise MedperfException(msg) + remove_path(path) + + def prepare_plan(self): + self.training_exp.prepare_plan() + + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + self.dataset_pki_assets = get_pki_assets_path(self.user_email, ca.name) + self.ca = ca + + def confirm_run(self): + msg = ( + "Above is the training configuration that will be used." + " Do you confirm starting training? [Y/n] " + ) + dict_pretty_print(self.training_exp.plan) + self.approved = self.approved or approval_prompt(msg) + + if not self.approved: + raise CleanExit("Training cancelled.") + + def prepare_training_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_mlcube, "FL") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = ( + "Retrieving and setting up training MLCube. This may take some time." + ) + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def run_experiment(self): + participant_label = get_participant_label(self.user_email, self.dataset.id) + env_dict = {"MEDPERF_PARTICIPANT_LABEL": participant_label} + params = { + "data_path": self.dataset.data_path, + "labels_path": self.dataset.labels_path, + "node_cert_folder": self.dataset_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "output_logs": self.out_logs, + } + + self.ui.text = "Running Training" + self.cube.run(task="train", env_dict=env_dict, **params) diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py index d8afb2244..79d6237b9 100644 --- a/cli/medperf/commands/execution.py +++ b/cli/medperf/commands/execution.py @@ -80,7 +80,7 @@ def run_inference(self): try: self.model.run( task="infer", - output_logs=self.model_logs_path, + output_logs_file=self.model_logs_path, timeout=infer_timeout, data_path=data_path, output_path=preds_path, @@ -105,7 +105,7 @@ def run_evaluation(self): try: self.evaluator.run( task="evaluate", - output_logs=self.metrics_logs_path, + output_logs_file=self.metrics_logs_path, timeout=evaluate_timeout, predictions=preds_path, labels=labels_path, diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 5fd462bf7..b5d6226a4 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -10,27 +10,29 @@ class EntityList: def run( entity_class, fields, - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, **kwargs, ): """Lists all local datasets Args: - local_only (bool, optional): Display all local results. Defaults to False. + unregistered (bool, optional): Display only local unregistered results. Defaults to False. mine_only (bool, optional): Display all current-user results. Defaults to False. kwargs (dict): Additional parameters for filtering entity lists. 
""" - entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs) + entity_list = EntityList( + entity_class, fields, unregistered, mine_only, **kwargs + ) entity_list.prepare() entity_list.validate() entity_list.filter() entity_list.display() - def __init__(self, entity_class, fields, local_only, mine_only, **kwargs): + def __init__(self, entity_class, fields, unregistered, mine_only, **kwargs): self.entity_class = entity_class self.fields = fields - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.filters = kwargs self.data = [] @@ -40,7 +42,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.display_dict() for entity in entities] diff --git a/cli/medperf/commands/mlcube/associate.py b/cli/medperf/commands/mlcube/associate.py index 8307caade..9ee7317cc 100644 --- a/cli/medperf/commands/mlcube/associate.py +++ b/cli/medperf/commands/mlcube/associate.py @@ -40,6 +40,6 @@ def run( if approved: ui.print("Generating mlcube benchmark association") metadata = {"test_result": results} - comms.associate_cube(cube_uid, benchmark_uid, metadata) + comms.associate_benchmark_model(cube_uid, benchmark_uid, metadata) else: ui.print("MLCube association operation cancelled") diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 4c365e574..bad358f8e 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -9,21 +9,48 @@ from medperf.commands.mlcube.create import CreateCube from medperf.commands.mlcube.submit import SubmitCube from medperf.commands.mlcube.associate import AssociateCube +from medperf.commands.mlcube.run import run_mlcube app = typer.Typer() +@app.command("run") +@clean_except +def run( + mlcube_path: str = typer.Option( + ..., "--mlcube", "-m", help="path to mlcube folder" + ), + task: str = typer.Option(..., "--task", "-t", help="mlcube task to run"), + out_logs: str = typer.Option( + None, "--output-logs", "-o", help="where to store stdout" + ), + port: str = typer.Option(None, "--port", "-P", help="port to expose"), + env: str = typer.Option( + "", "--env", "-e", help="comma separated list of key=value pairs" + ), + params: str = typer.Option( + "", "--params", "-p", help="comma separated list of key=value pairs" + ), +): + """List mlcubes stored locally and remotely from the user""" + params = dict([p.split("=") for p in params.strip().strip(",").split(",") if p]) + env = dict([p.split("=") for p in env.strip().strip(",").split(",") if p]) + run_mlcube(mlcube_path, task, out_logs, params, port, env) + + @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local mlcubes"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered mlcubes" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"), ): - """List mlcubes stored locally and remotely from the user""" + """List mlcubes""" EntityList.run( Cube, fields=["UID", "Name", "State", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -148,8 +175,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local mlcubes if mlcube ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered mlcubes if mlcube ID is not provided", ), mine: bool = typer.Option( False, @@ -164,4 +193,4 @@ def view( ), ): """Displays the information of one or more mlcubes""" - EntityView.run(entity_id, Cube, format, local, mine, output) + EntityView.run(entity_id, Cube, format, unregistered, mine, output) diff --git a/cli/medperf/commands/mlcube/run.py b/cli/medperf/commands/mlcube/run.py new file mode 100644 index 000000000..86cb626b0 --- /dev/null +++ b/cli/medperf/commands/mlcube/run.py @@ -0,0 +1,14 @@ +from medperf.tests.mocks.cube import TestCube +import os +from medperf import config + + +def run_mlcube(mlcube_path, task, out_logs, params, port, env): + c = TestCube() + c.cube_path = os.path.join(mlcube_path, config.cube_filename) + c.params_path = os.path.join( + mlcube_path, config.workspace_path, config.params_filename + ) + if config.platform == "singularity": + c._set_image_hash_from_registry() + c.run(task, out_logs, port=port, env_dict=env, **params) diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 42f97d990..1b8622810 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -1,5 +1,6 @@ import os from typing import List, Optional +from medperf.account_management.account_management import get_medperf_user_data from medperf.commands.execution import Execution from medperf.entities.result import Result from tabulate import tabulate @@ -100,6 +101,8 @@ def validate(self): if dset_prep_cube != bmark_prep_cube: msg = "The provided dataset is not compatible with the specified benchmark." raise InvalidArgumentError(msg) + # TODO: there is no check if dataset is associated with the benchmark + # Note that if it is present, this will break dataset association logic def prepare_models(self): if self.models_input_file: @@ -143,7 +146,9 @@ def __validate_models(self, benchmark_models): raise InvalidArgumentError(msg) def load_cached_results(self): - results = Result.all() + user_id = get_medperf_user_data()["id"] + results = Result.all(filters={"owner": user_id}) + results += Result.all(unregistered=True) benchmark_dset_results = [ result for result in results diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 6fbb3b08a..40b65c52e 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -62,17 +62,19 @@ def submit( @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local results"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered results" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user results"), benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given benchmark" ), ): - """List results stored locally and remotely from the user""" + """List results""" EntityList.run( Result, fields=["UID", "Benchmark", "Model", "Dataset", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, benchmark=benchmark, ) @@ -88,8 +90,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local results if result ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered results if result ID is not provided", ), mine: bool = typer.Option( False, @@ -107,4 +111,6 @@ def view( ), ): """Displays the information of one or more results""" - EntityView.run(entity_id, Result, format, local, mine, output, benchmark=benchmark) + EntityView.run( + entity_id, Result, format, unregistered, mine, output, benchmark=benchmark + ) diff --git a/cli/medperf/commands/training/close_event.py b/cli/medperf/commands/training/close_event.py new file mode 100644 index 000000000..2a922d97b --- /dev/null +++ b/cli/medperf/commands/training/close_event.py @@ -0,0 +1,59 @@ +import os +from medperf.entities.training_exp import TrainingExp +from medperf.entities.event import TrainingEvent +from medperf.utils import approval_prompt, dict_pretty_print +from medperf.exceptions import CleanExit, InvalidArgumentError +from medperf import config +import yaml + + +class CloseEvent: + """Used for both event cancellation (with custom report path) and for event closing + (with the expected report path generated by the aggregator)""" + + @classmethod + def run(cls, training_exp_id: int, report_path: str = None, approval: bool = False): + submission = cls(training_exp_id, report_path, approval) + submission.prepare() + submission.validate() + submission.read_report() + submission.submit() + submission.write() + + def __init__(self, training_exp_id: int, report_path: str, approval: bool): + self.training_exp_id = training_exp_id + self.approved = approval + self.report_path = report_path + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.event = TrainingEvent.from_experiment(self.training_exp_id) + self.report_path = self.report_path or self.event.report_path + + def validate(self): + if self.event.finished: + raise InvalidArgumentError("This experiment has already finished") + if not os.path.exists(self.report_path): + raise InvalidArgumentError(f"Report {self.report_path} does not exist.") + + def read_report(self): + with open(self.report_path) as f: + self.report = yaml.safe_load(f) + + def submit(self): + self.event.report = self.report + body = {"finished": True, "report": self.report} + dict_pretty_print(self.report) + msg = ( + f"You are about to close the event of training experiment {self.training_exp.name}." + " This will be the submitted report. Do you confirm? 
[Y/n] " + ) + self.approved = self.approved or approval_prompt(msg) + + if self.approved: + config.comms.update_training_event(self.event.id, body) + return + raise CleanExit("Event closing cancelled") + + def write(self): + self.event.write() diff --git a/cli/medperf/commands/training/get_experiment_status.py b/cli/medperf/commands/training/get_experiment_status.py new file mode 100644 index 000000000..28f9e2709 --- /dev/null +++ b/cli/medperf/commands/training/get_experiment_status.py @@ -0,0 +1,89 @@ +from medperf import config +from medperf.account_management.account_management import get_medperf_user_data +from medperf.entities.ca import CA +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.utils import ( + get_pki_assets_path, + generate_tmp_path, + dict_pretty_print, + remove_path, +) +from medperf.certificates import trust +import yaml +import os + + +class GetExperimentStatus: + @classmethod + def run(cls, training_exp_id: int, silent: bool = False): + """Starts the aggregation server of a training experiment + + Args: + training_exp_id (int): Training experiment UID. + """ + execution = cls(training_exp_id) + execution.prepare() + execution.prepare_plan() + execution.prepare_pki_assets() + with config.ui.interactive(): + execution.prepare_admin_cube() + execution.get_experiment_status() + if not silent: + execution.print_experiment_status() + execution.store_status() + + def __init__(self, training_exp_id: int) -> None: + self.training_exp_id = training_exp_id + self.ui = config.ui + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Experiment: {self.training_exp.name}") + self.user_email: str = get_medperf_user_data()["email"] + self.status_output = generate_tmp_path() + self.temp_dir = generate_tmp_path() + + def prepare_plan(self): + self.training_exp.prepare_plan() + + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + self.admin_pki_assets = get_pki_assets_path(self.user_email, ca.name) + self.ca = ca + + def prepare_admin_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_admin_mlcube, "FL Admin") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = ( + "Retrieving and setting up training MLCube. This may take some time." 
+ ) + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def get_experiment_status(self): + env_dict = {"MEDPERF_ADMIN_PARTICIPANT_CN": self.user_email} + params = { + "node_cert_folder": self.admin_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "output_status_file": self.status_output, + "temp_dir": self.temp_dir, + } + + self.ui.text = "Getting training experiment status" + self.cube.run(task="get_experiment_status", env_dict=env_dict, **params) + + def print_experiment_status(self): + with open(self.status_output) as f: + contents = yaml.safe_load(f) + dict_pretty_print(contents, skip_none_values=False) + + def store_status(self): + new_status_path = self.training_exp.status_path + remove_path(new_status_path) + os.rename(self.status_output, new_status_path) diff --git a/cli/medperf/commands/training/set_plan.py b/cli/medperf/commands/training/set_plan.py new file mode 100644 index 000000000..0a959eb02 --- /dev/null +++ b/cli/medperf/commands/training/set_plan.py @@ -0,0 +1,86 @@ +import medperf.config as config +from medperf.entities.aggregator import Aggregator +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.exceptions import CleanExit, InvalidArgumentError +from medperf.utils import approval_prompt, dict_pretty_print, generate_tmp_path +import os + +import yaml + + +class SetPlan: + @classmethod + def run( + cls, training_exp_id: int, training_config_path: str, approval: bool = False + ): + """Creates and submits the training plan + Args: + training_exp_id (int): training experiment + training_config_path (str): path to a training config file + approval (bool): skip approval + """ + planset = cls(training_exp_id, training_config_path, approval) + planset.validate() + planset.prepare() + planset.create_plan() + planset.update() + planset.write() + + def __init__(self, training_exp_id: int, training_config_path: str, approval: bool): + self.ui = config.ui + self.training_exp_id = training_exp_id + self.training_config_path = os.path.abspath(training_config_path) + self.approved = approval + self.plan_out_path = generate_tmp_path() + + def validate(self): + if not os.path.exists(self.training_config_path): + raise InvalidArgumentError("Provided training config path does not exist") + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.aggregator = Aggregator.from_experiment(self.training_exp_id) + self.mlcube = self.__get_cube(self.training_exp.fl_mlcube, "FL") + self.aggregator.prepare_config() + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = f"Retrieving {name} cube" + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def create_plan(self): + """Auto-generates dataset UIDs for both input and output paths""" + params = { + "training_config_path": self.training_config_path, + "aggregator_config_path": self.aggregator.config_path, + "plan_path": self.plan_out_path, + } + self.mlcube.run("generate_plan", **params) + + def update(self): + with open(self.plan_out_path) as f: + plan = yaml.safe_load(f) + self.training_exp.plan = plan + body = {"plan": plan} + dict_pretty_print(body) + msg = ( + "This is the training plan that will be submitted and used by the participants." 
+ " Do you confirm?[Y/n] " + ) + self.approved = self.approved or approval_prompt(msg) + + if self.approved: + config.comms.update_training_exp(self.training_exp.id, body) + return + + raise CleanExit("Setting the training plan was cancelled") + + def write(self) -> str: + """Writes the registration into disk + Args: + filename (str, optional): name of the file. Defaults to config.reg_file. + """ + self.training_exp.write() diff --git a/cli/medperf/commands/training/start_event.py b/cli/medperf/commands/training/start_event.py new file mode 100644 index 000000000..18f6400aa --- /dev/null +++ b/cli/medperf/commands/training/start_event.py @@ -0,0 +1,88 @@ +from medperf.entities.training_exp import TrainingExp +from medperf.entities.event import TrainingEvent +from medperf.utils import approval_prompt, dict_pretty_print, get_participant_label +from medperf.exceptions import CleanExit, InvalidArgumentError +import yaml +import os + + +class StartEvent: + @classmethod + def run( + cls, + training_exp_id: int, + name: str, + participants_list_file: str = None, + approval: bool = False, + ): + submission = cls(training_exp_id, name, participants_list_file, approval) + submission.prepare() + submission.validate() + submission.prepare_participants_list() + updated_body = submission.submit() + submission.write(updated_body) + + def __init__( + self, training_exp_id: int, name: str, participants_list_file: str, approval + ): + self.training_exp_id = training_exp_id + self.name = name + self.participants_list_file = participants_list_file + self.approved = approval + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + + def validate(self): + if self.training_exp.approval_status != "APPROVED": + raise InvalidArgumentError("This experiment has not been approved yet") + if self.participants_list_file is not None: + if not os.path.exists(self.participants_list_file): + raise InvalidArgumentError( + "Provided participants list path does not exist" + ) + + def prepare_participants_list(self): + if self.participants_list_file is None: + self._prepare_participants_list_from_associations() + else: + self._prepare_participants_list_from_file() + + def _prepare_participants_list_from_file(self): + with open(self.participants_list_file) as f: + self.participants_list = yaml.safe_load(f) + + def _prepare_participants_list_from_associations(self): + datasets_with_users = TrainingExp.get_datasets_with_users(self.training_exp_id) + participants_list = {} + for dataset in datasets_with_users: + user_email = dataset["owner"]["email"] + data_id = dataset["id"] + participant_label = get_participant_label(user_email, data_id) + participant_common_name = user_email + participants_list[participant_label] = participant_common_name + self.participants_list = participants_list + + def submit(self): + dict_pretty_print(self.participants_list) + msg = ( + f"You are about to start an event for the training experiment {self.training_exp.name}." + " This is the list of participants (participant label, participant common name)" + " that will be able to participate in your training experiment. Do you confirm? 
[Y/n] " + ) + self.approved = self.approved or approval_prompt(msg) + + self.event = TrainingEvent( + name=self.name, + training_exp=self.training_exp_id, + participants=self.participants_list, + ) + if self.approved: + updated_body = self.event.upload() + return updated_body + + raise CleanExit("Event creation cancelled") + + def write(self, updated_body): + event = TrainingEvent(**updated_body) + event.write() diff --git a/cli/medperf/commands/training/submit.py b/cli/medperf/commands/training/submit.py new file mode 100644 index 000000000..02221da91 --- /dev/null +++ b/cli/medperf/commands/training/submit.py @@ -0,0 +1,58 @@ +import medperf.config as config +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.utils import remove_path + + +class SubmitTrainingExp: + @classmethod + def run(cls, training_exp_info: dict): + """Submits a new cube to the medperf platform + Args: + benchmark_info (dict): benchmark information + expected keys: + name (str): benchmark name + description (str): benchmark description + docs_url (str): benchmark documentation url + demo_url (str): benchmark demo dataset url + demo_hash (str): benchmark demo dataset hash + data_preparation_mlcube (int): benchmark data preparation mlcube uid + reference_model_mlcube (int): benchmark reference model mlcube uid + evaluator_mlcube (int): benchmark data evaluator mlcube uid + """ + ui = config.ui + submission = cls(training_exp_info) + + with ui.interactive(): + ui.text = "Getting FL MLCube" + submission.get_fl_mlcube() + ui.text = "Getting FL admin MLCube" + submission.get_fl_admin_mlcube() + ui.print("> Completed retrieving FL MLCube") + ui.text = "Submitting TrainingExp to MedPerf" + updated_benchmark_body = submission.submit() + ui.print("Uploaded") + submission.write(updated_benchmark_body) + + def __init__(self, training_exp_info: dict): + self.ui = config.ui + self.training_exp = TrainingExp(**training_exp_info) + config.tmp_paths.append(self.training_exp.path) + + def get_fl_mlcube(self): + mlcube_id = self.training_exp.fl_mlcube + Cube.get(mlcube_id) + + def get_fl_admin_mlcube(self): + mlcube_id = self.training_exp.fl_admin_mlcube + if mlcube_id: + Cube.get(mlcube_id) + + def submit(self): + updated_body = self.training_exp.upload() + return updated_body + + def write(self, updated_body): + remove_path(self.training_exp.path) + training_exp = TrainingExp(**updated_body) + training_exp.write() diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py new file mode 100644 index 000000000..161d1f087 --- /dev/null +++ b/cli/medperf/commands/training/training.py @@ -0,0 +1,192 @@ +from typing import Optional +from medperf.entities.training_exp import TrainingExp +import typer + +import medperf.config as config +from medperf.decorators import clean_except + +from medperf.commands.training.submit import SubmitTrainingExp +from medperf.commands.training.set_plan import SetPlan +from medperf.commands.training.start_event import StartEvent +from medperf.commands.training.close_event import CloseEvent +from medperf.commands.list import EntityList +from medperf.commands.view import EntityView +from medperf.commands.training.get_experiment_status import GetExperimentStatus +from medperf.commands.training.update_plan import UpdatePlan + +app = typer.Typer() + + +@app.command("submit") +@clean_except +def submit( + name: str = typer.Option(..., "--name", "-n", help="Name of the benchmark"), + description: str = typer.Option( + 
..., "--description", "-d", help="Description of the benchmark" + ), + docs_url: str = typer.Option("", "--docs-url", "-u", help="URL to documentation"), + prep_mlcube: int = typer.Option(..., "--prep-mlcube", "-p", help="prep MLCube UID"), + fl_mlcube: int = typer.Option( + ..., "--fl-mlcube", "-m", help="Reference Model MLCube UID" + ), + fl_admin_mlcube: int = typer.Option( + None, "--fl-admin-mlcube", "-a", help="FL admin interface MLCube" + ), + operational: bool = typer.Option( + False, + "--operational", + help="Submit the experiment as OPERATIONAL", + ), +): + """Submits a new benchmark to the platform""" + training_exp_info = { + "name": name, + "description": description, + "docs_url": docs_url, + "fl_mlcube": fl_mlcube, + "fl_admin_mlcube": fl_admin_mlcube, + "demo_dataset_tarball_url": "link", + "demo_dataset_tarball_hash": "hash", + "demo_dataset_generated_uid": "uid", + "data_preparation_mlcube": prep_mlcube, + "state": "OPERATION" if operational else "DEVELOPMENT", + } + SubmitTrainingExp.run(training_exp_info) + config.ui.print("✅ Done!") + + +@app.command("set_plan") +@clean_except +def set_plan( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + training_config_path: str = typer.Option( + ..., "--config-path", "-c", help="config path" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + SetPlan.run(training_exp_id, training_config_path, approval) + config.ui.print("✅ Done!") + + +@app.command("start_event") +@clean_except +def start_event( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + name: str = typer.Option(..., "--name", "-n", help="Name of the benchmark"), + participants_list_file: str = typer.Option( + None, "--participants_list_file", "-p", help="Name of the benchmark" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + StartEvent.run(training_exp_id, name, participants_list_file, approval) + config.ui.print("✅ Done!") + + +@app.command("get_experiment_status") +@clean_except +def get_experiment_status( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + silent: bool = typer.Option(False, "--silent", help="don't print"), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + GetExperimentStatus.run(training_exp_id, silent) + config.ui.print("✅ Done!") + + +@app.command("update_plan") +@clean_except +def update_plan( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + field_name: str = typer.Option( + ..., "--field_name", "-f", help="UID of the desired benchmark" + ), + value: str = typer.Option( + ..., "--value", "-v", help="UID of the desired benchmark" + ), +): + """Runtime-update of a scalar field of the training plan""" + UpdatePlan.run(training_exp_id, field_name, value) + config.ui.print("✅ Done!") + + +@app.command("close_event") +@clean_except +def close_event( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Runs the benchmark execution step for a given benchmark, 
+ CloseEvent.run(training_exp_id, approval=approval) + config.ui.print("✅ Done!") + + +@app.command("cancel_event") +@clean_except +def cancel_event( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired training experiment" + ), + report_path: str = typer.Option(..., "--report-path", "-r", help="Path to the event report file to submit"), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Cancels the current training event of a training experiment using a custom report""" + CloseEvent.run(training_exp_id, report_path=report_path, approval=approval) + config.ui.print("✅ Done!") + + +@app.command("ls") +@clean_except +def list( + unregistered: bool = typer.Option(False, "--unregistered", help="Get unregistered training experiments"), + mine: bool = typer.Option(False, "--mine", help="Get current-user training experiments"), +): + """List training experiments""" + EntityList.run( + TrainingExp, + fields=["UID", "Name", "State", "Approval Status", "Registered"], + unregistered=unregistered, + mine_only=mine, + ) + + +@app.command("view") +@clean_except +def view( + entity_id: Optional[int] = typer.Argument(None, help="Training experiment ID"), + format: str = typer.Option( + "yaml", + "-f", + "--format", + help="Format to display contents. Available formats: [yaml, json]", + ), + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered training experiments if an ID is not provided", + ), + mine: bool = typer.Option( + False, + "--mine", + help="Display current-user training experiments if an ID is not provided", + ), + output: str = typer.Option( + None, + "--output", + "-o", + help="Output file to store contents. If not provided, the output will be displayed", + ), +): + """Displays the information of one or more training experiments""" + EntityView.run(entity_id, TrainingExp, format, unregistered, mine, output) diff --git a/cli/medperf/commands/training/update_plan.py b/cli/medperf/commands/training/update_plan.py new file mode 100644 index 000000000..baa064300 --- /dev/null +++ b/cli/medperf/commands/training/update_plan.py @@ -0,0 +1,74 @@ +from medperf import config +from medperf.account_management.account_management import get_medperf_user_data +from medperf.entities.ca import CA +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.utils import get_pki_assets_path, generate_tmp_path +from medperf.certificates import trust + + +class UpdatePlan: + @classmethod + def run(cls, training_exp_id: int, field_name: str, field_value: str): + """Performs a runtime update of a single field of the training plan + + Args: + training_exp_id (int): Training experiment UID.
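+ field_name (str): Name of the plan field to update. + field_value (str): New value to assign to that field.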
+ """ + execution = cls(training_exp_id, field_name, field_value) + execution.prepare() + execution.prepare_plan() + execution.prepare_pki_assets() + with config.ui.interactive(): + execution.prepare_admin_cube() + execution.update_plan() + + def __init__(self, training_exp_id: int, field_name: str, field_value: str) -> None: + self.training_exp_id = training_exp_id + self.field_name = field_name + self.field_value = field_value + self.ui = config.ui + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Experiment: {self.training_exp.name}") + self.user_email: str = get_medperf_user_data()["email"] + self.temp_dir = generate_tmp_path() + + def prepare_plan(self): + self.training_exp.prepare_plan() + + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + self.admin_pki_assets = get_pki_assets_path(self.user_email, ca.name) + self.ca = ca + + def prepare_admin_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_admin_mlcube, "FL Admin") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = ( + "Retrieving and setting up training MLCube. This may take some time." + ) + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def update_plan(self): + env_dict = { + "MEDPERF_ADMIN_PARTICIPANT_CN": self.user_email, + "MEDPERF_UPDATE_FIELD_NAME": self.field_name, + "MEDPERF_UPDATE_FIELD_VALUE": self.field_value, + } + + params = { + "node_cert_folder": self.admin_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "temp_dir": self.temp_dir, + } + + self.ui.text = "Updating plan" + self.cube.run(task="update_plan", env_dict=env_dict, **params) diff --git a/cli/medperf/commands/view.py b/cli/medperf/commands/view.py index b4c242f0a..8c2a4179f 100644 --- a/cli/medperf/commands/view.py +++ b/cli/medperf/commands/view.py @@ -14,7 +14,7 @@ def run( entity_id: Union[int, str], entity_class: Entity, format: str = "yaml", - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, output: str = None, **kwargs, @@ -24,14 +24,14 @@ def run( Args: entity_id (Union[int, str]): Entity identifies entity_class (Entity): Entity type - local_only (bool, optional): Display all local entities. Defaults to False. + unregistered (bool, optional): Display only local unregistered entities. Defaults to False. mine_only (bool, optional): Display all current-user entities. Defaults to False. format (str, optional): What format to use to display the contents. Valid formats: [yaml, json]. Defaults to yaml. output (str, optional): Path to a file for storing the entity contents. If not provided, the contents are printed. kwargs (dict): Additional parameters for filtering entity lists. 
""" entity_view = EntityView( - entity_id, entity_class, format, local_only, mine_only, output, **kwargs + entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ) entity_view.validate() entity_view.prepare() @@ -41,12 +41,12 @@ def run( entity_view.store() def __init__( - self, entity_id, entity_class, format, local_only, mine_only, output, **kwargs + self, entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ): self.entity_id = entity_id self.entity_class = entity_class self.format = format - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.output = output self.filters = kwargs @@ -65,7 +65,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.todict() for entity in entities] diff --git a/cli/medperf/comms/interface.py b/cli/medperf/comms/interface.py index 01436e435..45516034f 100644 --- a/cli/medperf/comms/interface.py +++ b/cli/medperf/comms/interface.py @@ -1,4 +1,4 @@ -from typing import List +# from typing import List from abc import ABC, abstractmethod @@ -13,289 +13,275 @@ def __init__(self, source: str): token (str, Optional): authentication token to be used throughout communication. Defaults to None. """ - @classmethod - @abstractmethod - def parse_url(self, url: str) -> str: - """Parse the source URL so that it can be used by the comms implementation. - It should handle protocols and versioning to be able to communicate with the API. - - Args: - url (str): base URL - - Returns: - str: parsed URL with protocol and version - """ - - @abstractmethod - def get_current_user(self): - """Retrieve the currently-authenticated user information""" - - @abstractmethod - def get_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks in the platform. - - Returns: - List[dict]: all benchmarks information. - """ - - @abstractmethod - def get_benchmark(self, benchmark_uid: int) -> dict: - """Retrieves the benchmark specification file from the server - - Args: - benchmark_uid (int): uid for the desired benchmark - - Returns: - dict: benchmark specification - """ - - @abstractmethod - def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: - """Retrieves all the model associations of a benchmark. - - Args: - benchmark_uid (int): UID of the desired benchmark - - Returns: - list[int]: List of benchmark model associations - """ - - @abstractmethod - def get_user_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks created by the user - - Returns: - List[dict]: Benchmarks data - """ - - @abstractmethod - def get_cubes(self) -> List[dict]: - """Retrieves all MLCubes in the platform - - Returns: - List[dict]: List containing the data of all MLCubes - """ - - @abstractmethod - def get_cube_metadata(self, cube_uid: int) -> dict: - """Retrieves metadata about the specified cube - - Args: - cube_uid (int): UID of the desired cube. - - Returns: - dict: Dictionary containing url and hashes for the cube files - """ - - @abstractmethod - def get_user_cubes(self) -> List[dict]: - """Retrieves metadata from all cubes registered by the user - - Returns: - List[dict]: List of dictionaries containing the mlcubes registration information - """ - - @abstractmethod - def upload_benchmark(self, benchmark_dict: dict) -> int: - """Uploads a new benchmark to the server. 
- - Args: - benchmark_dict (dict): benchmark_data to be uploaded - - Returns: - int: UID of newly created benchmark - """ - - @abstractmethod - def upload_mlcube(self, mlcube_body: dict) -> int: - """Uploads an MLCube instance to the platform - - Args: - mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes - - Returns: - int: id of the created mlcube instance on the platform - """ - - @abstractmethod - def get_datasets(self) -> List[dict]: - """Retrieves all datasets in the platform + # @classmethod + # @abstractmethod + # def parse_url(self, url: str) -> str: + # """Parse the source URL so that it can be used by the comms implementation. + # It should handle protocols and versioning to be able to communicate with the API. - Returns: - List[dict]: List of data from all datasets - """ + # Args: + # url (str): base URL - @abstractmethod - def get_dataset(self, dset_uid: str) -> dict: - """Retrieves a specific dataset + # Returns: + # str: parsed URL with protocol and version + # """ - Args: - dset_uid (str): Dataset UID + # @abstractmethod + # def get_current_user(self): + # """Retrieve the currently-authenticated user information""" - Returns: - dict: Dataset metadata - """ + # @abstractmethod + # def get_benchmarks(self) -> List[dict]: + # """Retrieves all benchmarks in the platform. - @abstractmethod - def get_user_datasets(self) -> dict: - """Retrieves all datasets registered by the user + # Returns: + # List[dict]: all benchmarks information. + # """ - Returns: - dict: dictionary with the contents of each dataset registration query - """ + # @abstractmethod + # def get_benchmark(self, benchmark_uid: int) -> dict: + # """Retrieves the benchmark specification file from the server - @abstractmethod - def upload_dataset(self, reg_dict: dict) -> int: - """Uploads registration data to the server, under the sha name of the file. + # Args: + # benchmark_uid (int): uid for the desired benchmark - Args: - reg_dict (dict): Dictionary containing registration information. + # Returns: + # dict: benchmark specification + # """ - Returns: - int: id of the created dataset registration. - """ + # @abstractmethod + # def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: + # """Retrieves all the model associations of a benchmark. - @abstractmethod - def get_results(self) -> List[dict]: - """Retrieves all results + # Args: + # benchmark_uid (int): UID of the desired benchmark - Returns: - List[dict]: List of results - """ + # Returns: + # list[int]: List of benchmark model associations + # """ - @abstractmethod - def get_result(self, result_uid: str) -> dict: - """Retrieves a specific result data + # @abstractmethod + # def get_user_benchmarks(self) -> List[dict]: + # """Retrieves all benchmarks created by the user - Args: - result_uid (str): Result UID + # Returns: + # List[dict]: Benchmarks data + # """ - Returns: - dict: Result metadata - """ - - @abstractmethod - def get_user_results(self) -> dict: - """Retrieves all results registered by the user - - Returns: - dict: dictionary with the contents of each dataset registration query - """ - - @abstractmethod - def get_benchmark_results(self, benchmark_id: int) -> dict: - """Retrieves all results for a given benchmark - - Args: - benchmark_id (int): benchmark ID to retrieve results from - - Returns: - dict: dictionary with the contents of each result in the specified benchmark - """ - - @abstractmethod - def upload_result(self, results_dict: dict) -> int: - """Uploads result to the server. 
+ # @abstractmethod + # def get_cubes(self) -> List[dict]: + # """Retrieves all MLCubes in the platform - Args: - results_dict (dict): Dictionary containing results information. + # Returns: + # List[dict]: List containing the data of all MLCubes + # """ - Returns: - int: id of the generated results entry - """ + # @abstractmethod + # def get_cube_metadata(self, cube_uid: int) -> dict: + # """Retrieves metadata about the specified cube - @abstractmethod - def associate_dset(self, data_uid: int, benchmark_uid: int, metadata: dict = {}): - """Create a Dataset Benchmark association + # Args: + # cube_uid (int): UID of the desired cube. - Args: - data_uid (int): Registered dataset UID - benchmark_uid (int): Benchmark UID - metadata (dict, optional): Additional metadata. Defaults to {}. - """ + # Returns: + # dict: Dictionary containing url and hashes for the cube files + # """ - @abstractmethod - def associate_cube(self, cube_uid: str, benchmark_uid: int, metadata: dict = {}): - """Create an MLCube-Benchmark association + # @abstractmethod + # def get_user_cubes(self) -> List[dict]: + # """Retrieves metadata from all cubes registered by the user - Args: - cube_uid (str): MLCube UID - benchmark_uid (int): Benchmark UID - metadata (dict, optional): Additional metadata. Defaults to {}. - """ + # Returns: + # List[dict]: List of dictionaries containing the mlcubes registration information + # """ - @abstractmethod - def set_dataset_association_approval( - self, dataset_uid: str, benchmark_uid: str, status: str - ): - """Approves a dataset association + # @abstractmethod + # def upload_benchmark(self, benchmark_dict: dict) -> int: + # """Uploads a new benchmark to the server. - Args: - dataset_uid (str): Dataset UID - benchmark_uid (str): Benchmark UID - status (str): Approval status to set for the association - """ + # Args: + # benchmark_dict (dict): benchmark_data to be uploaded - @abstractmethod - def set_mlcube_association_approval( - self, mlcube_uid: str, benchmark_uid: str, status: str - ): - """Approves an mlcube association + # Returns: + # int: UID of newly created benchmark + # """ - Args: - mlcube_uid (str): Dataset UID - benchmark_uid (str): Benchmark UID - status (str): Approval status to set for the association - """ + # @abstractmethod + # def upload_mlcube(self, mlcube_body: dict) -> int: + # """Uploads an MLCube instance to the platform - @abstractmethod - def get_datasets_associations(self) -> List[dict]: - """Get all dataset associations related to the current user + # Args: + # mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes - Returns: - List[dict]: List containing all associations information - """ + # Returns: + # int: id of the created mlcube instance on the platform + # """ - @abstractmethod - def get_cubes_associations(self) -> List[dict]: - """Get all cube associations related to the current user + # @abstractmethod + # def get_datasets(self) -> List[dict]: + # """Retrieves all datasets in the platform - Returns: - List[dict]: List containing all associations information - """ + # Returns: + # List[dict]: List of data from all datasets + # """ - @abstractmethod - def set_mlcube_association_priority( - self, benchmark_uid: str, mlcube_uid: str, priority: int - ): - """Sets the priority of an mlcube-benchmark association + # @abstractmethod + # def get_dataset(self, dset_uid: str) -> dict: + # """Retrieves a specific dataset - Args: - mlcube_uid (str): MLCube UID - benchmark_uid (str): Benchmark UID - priority (int): priority value to 
set for the association - """ + # Args: + # dset_uid (str): Dataset UID - @abstractmethod - def update_dataset(self, dataset_id: int, data: dict): - """Updates the contents of a datasets identified by dataset_id to the new data dictionary. - Updates may be partial. + # Returns: + # dict: Dataset metadata + # """ - Args: - dataset_id (int): ID of the dataset to update - data (dict): Updated information of the dataset. - """ + # @abstractmethod + # def get_user_datasets(self) -> dict: + # """Retrieves all datasets registered by the user - @abstractmethod - def get_user(self, user_id: int) -> dict: - """Retrieves the specified user. This will only return if - the current user has permission to view the requested user, - either by being himself, an admin or an owner of a data preparation - mlcube used by the requested user + # Returns: + # dict: dictionary with the contents of each dataset registration query + # """ - Args: - user_id (int): User UID + # @abstractmethod + # def upload_dataset(self, reg_dict: dict) -> int: + # """Uploads registration data to the server, under the sha name of the file. - Returns: - dict: Requested user information - """ + # Args: + # reg_dict (dict): Dictionary containing registration information. + + # Returns: + # int: id of the created dataset registration. + # """ + + # @abstractmethod + # def get_results(self) -> List[dict]: + # """Retrieves all results + + # Returns: + # List[dict]: List of results + # """ + + # @abstractmethod + # def get_result(self, result_uid: str) -> dict: + # """Retrieves a specific result data + + # Args: + # result_uid (str): Result UID + + # Returns: + # dict: Result metadata + # """ + + # @abstractmethod + # def get_user_results(self) -> dict: + # """Retrieves all results registered by the user + + # Returns: + # dict: dictionary with the contents of each dataset registration query + # """ + + # @abstractmethod + # def get_benchmark_results(self, benchmark_id: int) -> dict: + # """Retrieves all results for a given benchmark + + # Args: + # benchmark_id (int): benchmark ID to retrieve results from + + # Returns: + # dict: dictionary with the contents of each result in the specified benchmark + # """ + + # @abstractmethod + # def upload_result(self, results_dict: dict) -> int: + # """Uploads result to the server. + + # Args: + # results_dict (dict): Dictionary containing results information. + + # Returns: + # int: id of the generated results entry + # """ + + # @abstractmethod + # def associate_dset(self, data_uid: int, benchmark_uid: int, metadata: dict = {}): + # """Create a Dataset Benchmark association + + # Args: + # data_uid (int): Registered dataset UID + # benchmark_uid (int): Benchmark UID + # metadata (dict, optional): Additional metadata. Defaults to {}. + # """ + + # @abstractmethod + # def associate_cube(self, cube_uid: str, benchmark_uid: int, metadata: dict = {}): + # """Create an MLCube-Benchmark association + + # Args: + # cube_uid (str): MLCube UID + # benchmark_uid (int): Benchmark UID + # metadata (dict, optional): Additional metadata. Defaults to {}. 
+ # """ + + # @abstractmethod + # def set_dataset_association_approval( + # self, dataset_uid: str, benchmark_uid: str, status: str + # ): + # """Approves a dataset association + + # Args: + # dataset_uid (str): Dataset UID + # benchmark_uid (str): Benchmark UID + # status (str): Approval status to set for the association + # """ + + # @abstractmethod + # def set_mlcube_association_approval( + # self, mlcube_uid: str, benchmark_uid: str, status: str + # ): + # """Approves an mlcube association + + # Args: + # mlcube_uid (str): Dataset UID + # benchmark_uid (str): Benchmark UID + # status (str): Approval status to set for the association + # """ + + # @abstractmethod + # def get_datasets_associations(self) -> List[dict]: + # """Get all dataset associations related to the current user + + # Returns: + # List[dict]: List containing all associations information + # """ + + # @abstractmethod + # def get_cubes_associations(self) -> List[dict]: + # """Get all cube associations related to the current user + + # Returns: + # List[dict]: List containing all associations information + # """ + + # @abstractmethod + # def set_mlcube_association_priority( + # self, benchmark_uid: str, mlcube_uid: str, priority: int + # ): + # """Sets the priority of an mlcube-benchmark association + + # Args: + # mlcube_uid (str): MLCube UID + # benchmark_uid (str): Benchmark UID + # priority (int): priority value to set for the association + # """ + + # @abstractmethod + # def update_dataset(self, dataset_id: int, data: dict): + # """Updates the contents of a datasets identified by dataset_id to the new data dictionary. + # Updates may be partial. + + # Args: + # dataset_id (int): ID of the dataset to update + # data (dict): Updated information of the dataset. + # """ diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 5ac236f93..994827919 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -5,12 +5,7 @@ from medperf.enums import Status import medperf.config as config from medperf.comms.interface import Comms -from medperf.utils import ( - sanitize_json, - log_response_error, - format_errors_dict, - filter_latest_associations, -) +from medperf.utils import sanitize_json, log_response_error, format_errors_dict from medperf.exceptions import ( CommunicationError, CommunicationRetrievalError, @@ -81,6 +76,7 @@ def __get_list( page_size=config.default_page_size, offset=0, binary_reduction=False, + error_msg: str = "", ): """Retrieves a list of elements from a URL by iterating over pages until num_elements is obtained. If num_elements is None, then iterates until all elements have been retrieved. @@ -110,16 +106,15 @@ def __get_list( if not binary_reduction: log_response_error(res) details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"there was an error retrieving the current list: {details}" - ) + raise CommunicationRetrievalError(f"{error_msg}: {details}") log_response_error(res, warn=True) details = format_errors_dict(res.json()) if page_size <= 1: - raise CommunicationRetrievalError( - f"Could not retrieve list. Minimum page size achieved without success: {details}" + logging.debug( + "Could not retrieve list. 
Minimum page size achieved without success" ) + raise CommunicationRetrievalError(f"{error_msg}: {details}") page_size = page_size // 2 continue else: @@ -133,33 +128,54 @@ def __get_list( return el_list[:num_elements] return el_list - def __set_approval_status(self, url: str, status: str) -> requests.Response: - """Sets the approval status of a resource + def __get(self, url: str, error_msg: str) -> dict: + """self.__auth_get with error handling""" + res = self.__auth_get(url) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError(f"{error_msg}: {details}") + return res.json() - Args: - url (str): URL to the resource to update - status (str): approval status to set + def __post(self, url: str, json: dict, error_msg: str) -> int: + """self.__auth_post with error handling""" + res = self.__auth_post(url, json=json) + if res.status_code != 201: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError(f"{error_msg}: {details}") + return res.json() - Returns: - requests.Response: Response object returned by the update - """ - data = {"approval_status": status} - res = self.__auth_put(url, json=data) - return res + def __put(self, url: str, json: dict, error_msg: str): + """self.__auth_put with error handling""" + res = self.__auth_put(url, json=json) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRequestError(f"{error_msg}: {details}") def get_current_user(self): """Retrieve the currently-authenticated user information""" - res = self.__auth_get(f"{self.server_url}/me/") - return res.json() + url = f"{self.server_url}/me/" + error_msg = "Could not get current user" + return self.__get(url, error_msg) - def get_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks in the platform. + # get object + def get_user(self, user_id: int) -> dict: + """Retrieves the specified user. This will only return if + the current user has permission to view the requested user, + either by being himself, an admin or an owner of a data preparation + mlcube used by the requested user + + Args: + user_id (int): User UID Returns: - List[dict]: all benchmarks information. + dict: Requested user information """ - bmks = self.__get_list(f"{self.server_url}/benchmarks/") - return bmks + url = f"{self.server_url}/users/{user_id}/" + error_msg = "Could not retrieve user" + return self.__get(url, error_msg) def get_benchmark(self, benchmark_uid: int) -> dict: """Retrieves the benchmark specification file from the server @@ -170,103 +186,161 @@ def get_benchmark(self, benchmark_uid: int) -> dict: Returns: dict: benchmark specification """ - res = self.__auth_get(f"{self.server_url}/benchmarks/{benchmark_uid}") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"the specified benchmark doesn't exist: {details}" - ) - return res.json() + url = f"{self.server_url}/benchmarks/{benchmark_uid}" + error_msg = "Could not retrieve benchmark" + return self.__get(url, error_msg) - def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: - """Retrieves all the model associations of a benchmark. + def get_cube_metadata(self, cube_uid: int) -> dict: + """Retrieves metadata about the specified cube Args: - benchmark_uid (int): UID of the desired benchmark + cube_uid (int): UID of the desired cube. 
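The __get, __post, and __put helpers introduced above all follow the same shape: issue the authenticated request, and on an unexpected status code log the response, pull the serializer errors, and raise with the caller-supplied message plus the details. A minimal standalone sketch of that pattern, using plain requests and hypothetical names rather than MedPerf's session handling:

import requests


class CommunicationRetrievalError(Exception):
    """Stand-in for medperf.exceptions.CommunicationRetrievalError."""


def get_json(url: str, error_msg: str, token: str) -> dict:
    # One request, one status check, one uniform error message with the
    # server-side details appended -- the same shape as the new __get helper.
    res = requests.get(url, headers={"Authorization": f"Bearer {token}"})
    if res.status_code != 200:
        raise CommunicationRetrievalError(f"{error_msg}: {res.json()}")
    return res.json()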
Returns: - list[int]: List of benchmark model associations + dict: Dictionary containing url and hashes for the cube files """ - assocs = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models") - return filter_latest_associations(assocs, "model_mlcube") + url = f"{self.server_url}/mlcubes/{cube_uid}/" + error_msg = "Could not retrieve mlcube" + return self.__get(url, error_msg) - def get_user_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks created by the user + def get_dataset(self, dset_uid: int) -> dict: + """Retrieves a specific dataset + + Args: + dset_uid (int): Dataset UID Returns: - List[dict]: Benchmarks data + dict: Dataset metadata """ - bmks = self.__get_list(f"{self.server_url}/me/benchmarks/") - return bmks + url = f"{self.server_url}/datasets/{dset_uid}/" + error_msg = "Could not retrieve dataset" + return self.__get(url, error_msg) - def get_cubes(self) -> List[dict]: - """Retrieves all MLCubes in the platform + def get_result(self, result_uid: int) -> dict: + """Retrieves a specific result data + + Args: + result_uid (int): Result UID Returns: - List[dict]: List containing the data of all MLCubes + dict: Result metadata """ - cubes = self.__get_list(f"{self.server_url}/mlcubes/") - return cubes + url = f"{self.server_url}/results/{result_uid}/" + error_msg = "Could not retrieve result" + return self.__get(url, error_msg) - def get_cube_metadata(self, cube_uid: int) -> dict: - """Retrieves metadata about the specified cube + def get_training_exp(self, training_exp_id: int) -> dict: + """Retrieves the training_exp specification file from the server Args: - cube_uid (int): UID of the desired cube. + training_exp_id (int): uid for the desired training_exp Returns: - dict: Dictionary containing url and hashes for the cube files + dict: training_exp specification """ - res = self.__auth_get(f"{self.server_url}/mlcubes/{cube_uid}/") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"the specified cube doesn't exist {details}" - ) - return res.json() + url = f"{self.server_url}/training/{training_exp_id}/" + error_msg = "Could not retrieve training experiment" + return self.__get(url, error_msg) - def get_user_cubes(self) -> List[dict]: - """Retrieves metadata from all cubes registered by the user + def get_aggregator(self, aggregator_id: int) -> dict: + """Retrieves the aggregator specification file from the server + + Args: + benchmark_uid (int): uid for the desired benchmark Returns: - List[dict]: List of dictionaries containing the mlcubes registration information + dict: benchmark specification """ - cubes = self.__get_list(f"{self.server_url}/me/mlcubes/") - return cubes + url = f"{self.server_url}/aggregators/{aggregator_id}" + error_msg = "Could not retrieve aggregator" + return self.__get(url, error_msg) - def upload_benchmark(self, benchmark_dict: dict) -> int: - """Uploads a new benchmark to the server. 
+ def get_ca(self, ca_id: int) -> dict:
+ """Retrieves the CA specification from the server
 Args:
- benchmark_dict (dict): benchmark_data to be uploaded
+ ca_id (int): uid for the desired CA
 Returns:
- int: UID of newly created benchmark
+ dict: CA specification
 """
- res = self.__auth_post(f"{self.server_url}/benchmarks/", json=benchmark_dict)
- if res.status_code != 201:
- log_response_error(res)
- details = format_errors_dict(res.json())
- raise CommunicationRetrievalError(f"Could not upload benchmark: {details}")
- return res.json()
+ url = f"{self.server_url}/cas/{ca_id}"
+ error_msg = "Could not retrieve ca"
+ return self.__get(url, error_msg)
- def upload_mlcube(self, mlcube_body: dict) -> int:
- """Uploads an MLCube instance to the platform
+ def get_training_event(self, event_id: int) -> dict:
+ """Retrieves the training event specification from the server
 Args:
- mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes
+ event_id (int): uid for the desired training event
 Returns:
- int: id of the created mlcube instance on the platform
+ dict: training event specification
 """
- res = self.__auth_post(f"{self.server_url}/mlcubes/", json=mlcube_body)
- if res.status_code != 201:
- log_response_error(res)
- details = format_errors_dict(res.json())
- raise CommunicationRetrievalError(f"Could not upload the mlcube: {details}")
- return res.json()
+ url = f"{self.server_url}/training/events/{event_id}"
+ error_msg = "Could not retrieve training event"
+ return self.__get(url, error_msg)
+
+ # get object of an object
+ def get_experiment_event(self, training_exp_id: int) -> dict:
+ """Retrieves the training experiment's event object from the server
+
+ Args:
+ training_exp_id (int): uid for the training experiment
+
+ Returns:
+ dict: event specification
+ """
+ url = f"{self.server_url}/training/{training_exp_id}/event/"
+ error_msg = "Could not retrieve training experiment event"
+ return self.__get(url, error_msg)
+
+ def get_experiment_aggregator(self, training_exp_id: int) -> dict:
+ """Retrieves the training experiment's aggregator object from the server
+
+ Args:
+ training_exp_id (int): uid for the training experiment
+
+ Returns:
+ dict: aggregator specification
+ """
+ url = f"{self.server_url}/training/{training_exp_id}/aggregator/"
+ error_msg = "Could not retrieve training experiment aggregator"
+ return self.__get(url, error_msg)
+
+ def get_experiment_ca(self, training_exp_id: int) -> dict:
+ """Retrieves the training experiment's ca object from the server
+
+ Args:
+ training_exp_id (int): uid for the training experiment
+
+ Returns:
+ dict: ca specification
+ """
+ url = f"{self.server_url}/training/{training_exp_id}/ca/"
+ error_msg = "Could not retrieve training experiment ca"
+ return self.__get(url, error_msg)
+
+ # get list
+ def get_benchmarks(self) -> List[dict]:
+ """Retrieves all benchmarks in the platform.
+
+ Returns:
+ List[dict]: all benchmarks information.
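The experiment-scoped getters added above take only the training experiment ID and resolve the single associated object on the server side. A hypothetical usage sketch, assuming comms has already been initialized by the CLI (profile activated and logged in):

import medperf.config as config

# Hypothetical experiment ID for illustration only.
exp_id = 1
event = config.comms.get_experiment_event(exp_id)
aggregator = config.comms.get_experiment_aggregator(exp_id)
ca = config.comms.get_experiment_ca(exp_id)
print(event["id"], aggregator["id"], ca["id"])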
+ """ + url = f"{self.server_url}/benchmarks/" + error_msg = "Could not retrieve benchmarks" + return self.__get_list(url, error_msg=error_msg) + + def get_cubes(self) -> List[dict]: + """Retrieves all MLCubes in the platform + + Returns: + List[dict]: List containing the data of all MLCubes + """ + url = f"{self.server_url}/mlcubes/" + error_msg = "Could not retrieve mlcubes" + return self.__get_list(url, error_msg=error_msg) def get_datasets(self) -> List[dict]: """Retrieves all datasets in the platform @@ -274,26 +348,70 @@ def get_datasets(self) -> List[dict]: Returns: List[dict]: List of data from all datasets """ - dsets = self.__get_list(f"{self.server_url}/datasets/") - return dsets + url = f"{self.server_url}/datasets/" + error_msg = "Could not retrieve datasets" + return self.__get_list(url, error_msg=error_msg) - def get_dataset(self, dset_uid: int) -> dict: - """Retrieves a specific dataset + def get_results(self) -> List[dict]: + """Retrieves all results - Args: - dset_uid (int): Dataset UID + Returns: + List[dict]: List of results + """ + url = f"{self.server_url}/results/" + error_msg = "Could not retrieve results" + return self.__get_list(url, error_msg=error_msg) + + def get_training_exps(self) -> List[dict]: + """Retrieves all training_exps Returns: - dict: Dataset metadata + List[dict]: List of training_exps """ - res = self.__auth_get(f"{self.server_url}/datasets/{dset_uid}/") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"Could not retrieve the specified dataset from server: {details}" - ) - return res.json() + url = f"{self.server_url}/training/" + error_msg = "Could not retrieve training experiments" + return self.__get_list(url, error_msg=error_msg) + + def get_aggregators(self) -> List[dict]: + """Retrieves all aggregators + + Returns: + List[dict]: List of aggregators + """ + url = f"{self.server_url}/aggregators/" + error_msg = "Could not retrieve aggregators" + return self.__get_list(url, error_msg=error_msg) + + def get_cas(self) -> List[dict]: + """Retrieves all cas + + Returns: + List[dict]: List of cas + """ + url = f"{self.server_url}/cas/" + error_msg = "Could not retrieve cas" + return self.__get_list(url, error_msg=error_msg) + + def get_training_events(self) -> List[dict]: + """Retrieves all training events + + Returns: + List[dict]: List of training events + """ + url = f"{self.server_url}/training/events/" + error_msg = "Could not retrieve training events" + return self.__get_list(url, error_msg=error_msg) + + # get user list + def get_user_cubes(self) -> List[dict]: + """Retrieves metadata from all cubes registered by the user + + Returns: + List[dict]: List of dictionaries containing the mlcubes registration information + """ + url = f"{self.server_url}/me/mlcubes/" + error_msg = "Could not retrieve user mlcubes" + return self.__get_list(url, error_msg=error_msg) def get_user_datasets(self) -> dict: """Retrieves all datasets registered by the user @@ -301,8 +419,147 @@ def get_user_datasets(self) -> dict: Returns: dict: dictionary with the contents of each dataset registration query """ - dsets = self.__get_list(f"{self.server_url}/me/datasets/") - return dsets + url = f"{self.server_url}/me/datasets/" + error_msg = "Could not retrieve user datasets" + return self.__get_list(url, error_msg=error_msg) + + def get_user_benchmarks(self) -> List[dict]: + """Retrieves all benchmarks created by the user + + Returns: + List[dict]: Benchmarks data + """ + url = 
f"{self.server_url}/me/benchmarks/" + error_msg = "Could not retrieve user benchmarks" + return self.__get_list(url, error_msg=error_msg) + + def get_user_results(self) -> dict: + """Retrieves all results registered by the user + + Returns: + dict: dictionary with the contents of each result registration query + """ + url = f"{self.server_url}/me/results/" + error_msg = "Could not retrieve user results" + return self.__get_list(url, error_msg=error_msg) + + def get_user_training_exps(self) -> dict: + """Retrieves all training_exps registered by the user + + Returns: + dict: dictionary with the contents of each result registration query + """ + url = f"{self.server_url}/me/training/" + error_msg = "Could not retrieve user training experiments" + return self.__get_list(url, error_msg=error_msg) + + def get_user_aggregators(self) -> dict: + """Retrieves all aggregators registered by the user + + Returns: + dict: dictionary with the contents of each result registration query + """ + url = f"{self.server_url}/me/aggregators/" + error_msg = "Could not retrieve user aggregators" + return self.__get_list(url, error_msg=error_msg) + + def get_user_cas(self) -> dict: + """Retrieves all cas registered by the user + + Returns: + dict: dictionary with the contents of each result registration query + """ + url = f"{self.server_url}/me/cas/" + error_msg = "Could not retrieve user cas" + return self.__get_list(url, error_msg=error_msg) + + def get_user_training_events(self) -> dict: + """Retrieves all training events registered by the user + + Returns: + dict: dictionary with the contents of each result registration query + """ + url = f"{self.server_url}/me/training/events/" + error_msg = "Could not retrieve user training events" + return self.__get_list(url, error_msg=error_msg) + + # get user associations list + def get_user_benchmarks_datasets_associations(self) -> List[dict]: + """Get all dataset associations related to the current user + + Returns: + List[dict]: List containing all associations information + """ + url = f"{self.server_url}/me/datasets/associations/" + error_msg = "Could not retrieve user datasets benchmark associations" + return self.__get_list(url, error_msg=error_msg) + + def get_user_benchmarks_models_associations(self) -> List[dict]: + """Get all cube associations related to the current user + + Returns: + List[dict]: List containing all associations information + """ + url = f"{self.server_url}/me/mlcubes/associations/" + error_msg = "Could not retrieve user mlcubes benchmark associations" + return self.__get_list(url, error_msg=error_msg) + + def get_user_training_datasets_associations(self) -> List[dict]: + """Get all training dataset associations related to the current user + + Returns: + List[dict]: List containing all associations information + """ + url = f"{self.server_url}/me/datasets/training_associations/" + error_msg = "Could not retrieve user datasets training associations" + return self.__get_list(url, error_msg=error_msg) + + def get_user_training_aggregators_associations(self) -> List[dict]: + """Get all aggregator associations related to the current user + + Returns: + List[dict]: List containing all associations information + """ + url = f"{self.server_url}/me/aggregators/training_associations/" + error_msg = "Could not retrieve user aggregators training associations" + return self.__get_list(url, error_msg=error_msg) + + def get_user_training_cas_associations(self) -> List[dict]: + """Get all ca associations related to the current user + + Returns: + 
List[dict]: List containing all associations information + """ + url = f"{self.server_url}/me/cas/training_associations/" + error_msg = "Could not retrieve user cas training associations" + return self.__get_list(url, error_msg=error_msg) + + # upload + def upload_benchmark(self, benchmark_dict: dict) -> int: + """Uploads a new benchmark to the server. + + Args: + benchmark_dict (dict): benchmark_data to be uploaded + + Returns: + int: UID of newly created benchmark + """ + url = f"{self.server_url}/benchmarks/" + error_msg = "could not upload benchmark" + return self.__post(url, json=benchmark_dict, error_msg=error_msg) + + def upload_mlcube(self, mlcube_body: dict) -> int: + """Uploads an MLCube instance to the platform + + Args: + mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes + + Returns: + int: id of the created mlcube instance on the platform + """ + url = f"{self.server_url}/mlcubes/" + error_msg = "could not upload mlcube" + return self.__post(url, json=mlcube_body, error_msg=error_msg) def upload_dataset(self, reg_dict: dict) -> int: """Uploads registration data to the server, under the sha name of the file. @@ -313,84 +570,79 @@ def upload_dataset(self, reg_dict: dict) -> int: Returns: int: id of the created dataset registration. """ - res = self.__auth_post(f"{self.server_url}/datasets/", json=reg_dict) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError(f"Could not upload the dataset: {details}") - return res.json() + url = f"{self.server_url}/datasets/" + error_msg = "could not upload dataset" + return self.__post(url, json=reg_dict, error_msg=error_msg) - def get_results(self) -> List[dict]: - """Retrieves all results + def upload_result(self, results_dict: dict) -> int: + """Uploads result to the server. + + Args: + results_dict (dict): Dictionary containing results information. Returns: - List[dict]: List of results + dicr: generated results entry """ - res = self.__get_list(f"{self.server_url}/results") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError(f"Could not retrieve results: {details}") - return res.json() + url = f"{self.server_url}/results/" + error_msg = "could not upload result" + return self.__post(url, json=results_dict, error_msg=error_msg) - def get_result(self, result_uid: int) -> dict: - """Retrieves a specific result data + def upload_training_exp(self, training_exp_dict: dict) -> int: + """Uploads a new training_exp to the server. Args: - result_uid (int): Result UID + training_exp_dict (dict): training_exp to be uploaded Returns: - dict: Result metadata + dict: newly created training_exp """ - res = self.__auth_get(f"{self.server_url}/results/{result_uid}/") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"Could not retrieve the specified result: {details}" - ) - return res.json() + url = f"{self.server_url}/training/" + error_msg = "could not upload training experiment" + return self.__post(url, json=training_exp_dict, error_msg=error_msg) - def get_user_results(self) -> dict: - """Retrieves all results registered by the user + def upload_aggregator(self, aggregator_dict: dict) -> int: + """Uploads a new aggregator to the server. 
+
+ Args:
+ aggregator_dict (dict): aggregator data to be uploaded
 Returns:
- dict: dictionary with the contents of each result registration query
+ int: UID of the newly created aggregator
 """
- results = self.__get_list(f"{self.server_url}/me/results/")
- return results
+ url = f"{self.server_url}/aggregators/"
+ error_msg = "could not upload aggregator"
+ return self.__post(url, json=aggregator_dict, error_msg=error_msg)
- def get_benchmark_results(self, benchmark_id: int) -> dict:
- """Retrieves all results for a given benchmark
+ def upload_ca(self, ca_dict: dict) -> int:
+ """Uploads a new ca to the server.
 Args:
- benchmark_id (int): benchmark ID to retrieve results from
+ ca_dict (dict): CA data to be uploaded
 Returns:
- dict: dictionary with the contents of each result in the specified benchmark
+ int: UID of the newly created CA
 """
- results = self.__get_list(
- f"{self.server_url}/benchmarks/{benchmark_id}/results"
- )
- return results
+ url = f"{self.server_url}/cas/"
+ error_msg = "could not upload ca"
+ return self.__post(url, json=ca_dict, error_msg=error_msg)
- def upload_result(self, results_dict: dict) -> int:
- """Uploads result to the server.
+ def upload_training_event(self, training_event_dict: dict) -> int:
+ """Uploads a new training event to the server.
 Args:
- results_dict (dict): Dictionary containing results information.
+ training_event_dict (dict): training event data to be uploaded
 Returns:
- int: id of the generated results entry
+ int: UID of the newly created training event
 """
- res = self.__auth_post(f"{self.server_url}/results/", json=results_dict)
- if res.status_code != 201:
- log_response_error(res)
- details = format_errors_dict(res.json())
- raise CommunicationRequestError(f"Could not upload the results: {details}")
- return res.json()
+ url = f"{self.server_url}/training/events/"
+ error_msg = "could not upload training event"
+ return self.__post(url, json=training_event_dict, error_msg=error_msg)
- def associate_dset(self, data_uid: int, benchmark_uid: int, metadata: dict = {}):
+ # Association creation
+ def associate_benchmark_dataset(
+ self, data_uid: int, benchmark_uid: int, metadata: dict = {}
+ ):
 """Create a Dataset Benchmark association
 Args:
@@ -398,21 +650,19 @@ def associate_dset(self, data_uid: int, benchmark_uid: int, metadata: dict = {})
 benchmark_uid (int): Benchmark UID
 metadata (dict, optional): Additional metadata. Defaults to {}.
 """
+ url = f"{self.server_url}/datasets/benchmarks/"
 data = {
 "dataset": data_uid,
 "benchmark": benchmark_uid,
 "approval_status": Status.PENDING.value,
 "metadata": metadata,
 }
- res = self.__auth_post(f"{self.server_url}/datasets/benchmarks/", json=data)
- if res.status_code != 201:
- log_response_error(res)
- details = format_errors_dict(res.json())
- raise CommunicationRequestError(
- f"Could not associate dataset to benchmark: {details}"
- )
+ error_msg = "Could not associate dataset to benchmark"
+ return self.__post(url, json=data, error_msg=error_msg)
- def associate_cube(self, cube_uid: int, benchmark_uid: int, metadata: dict = {}):
+ def associate_benchmark_model(
+ self, cube_uid: int, benchmark_uid: int, metadata: dict = {}
+ ):
 """Create an MLCube-Benchmark association
 Args:
@@ -420,22 +670,68 @@ def associate_cube(self, cube_uid: int, benchmark_uid: int, metadata: dict = {})
 benchmark_uid (int): Benchmark UID
 metadata (dict, optional): Additional metadata. Defaults to {}.
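The benchmark and training association calls in this part of the diff all POST a small JSON body that pairs the entity with the experiment and starts the association in a pending state. A sketch of the payload shape, with placeholder IDs:

from medperf.enums import Status

# Placeholder IDs for illustration only.
dataset_id, training_exp_id = 7, 3
payload = {
    "dataset": dataset_id,
    "training_exp": training_exp_id,
    "approval_status": Status.PENDING.value,  # associations start out pending approval
}
# Aggregator and CA associations swap the first key for "aggregator" / "ca".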
""" + url = f"{self.server_url}/mlcubes/benchmarks/" data = { "approval_status": Status.PENDING.value, "model_mlcube": cube_uid, "benchmark": benchmark_uid, "metadata": metadata, } - res = self.__auth_post(f"{self.server_url}/mlcubes/benchmarks/", json=data) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not associate mlcube to benchmark: {details}" - ) + error_msg = "Could not associate mlcube to benchmark" + return self.__post(url, json=data, error_msg=error_msg) + + def associate_training_dataset(self, data_uid: int, training_exp_id: int): + """Create a Dataset experiment association + + Args: + data_uid (int): Registered dataset UID + benchmark_uid (int): Benchmark UID + metadata (dict, optional): Additional metadata. Defaults to {}. + """ + url = f"{self.server_url}/datasets/training/" + data = { + "dataset": data_uid, + "training_exp": training_exp_id, + "approval_status": Status.PENDING.value, + } + error_msg = "Could not associate dataset to training_exp" + return self.__post(url, json=data, error_msg=error_msg) + + def associate_training_aggregator(self, aggregator_id: int, training_exp_id: int): + """Create a aggregator experiment association + + Args: + aggregator_id (int): Registered aggregator UID + training_exp_id (int): training experiment UID + """ + url = f"{self.server_url}/aggregators/training/" + data = { + "aggregator": aggregator_id, + "training_exp": training_exp_id, + "approval_status": Status.PENDING.value, + } + error_msg = "Could not associate aggregator to training_exp" + return self.__post(url, json=data, error_msg=error_msg) + + def associate_training_ca(self, ca_id: int, training_exp_id: int): + """Create a ca experiment association - def set_dataset_association_approval( - self, benchmark_uid: int, dataset_uid: int, status: str + Args: + ca_id (int): Registered ca UID + training_exp_id (int): training experiment UID + """ + url = f"{self.server_url}/cas/training/" + data = { + "ca": ca_id, + "training_exp": training_exp_id, + "approval_status": Status.PENDING.value, + } + error_msg = "Could not associate ca to training_exp" + return self.__post(url, json=data, error_msg=error_msg) + + # updates associations + def update_benchmark_dataset_association( + self, benchmark_uid: int, dataset_uid: int, data: str ): """Approves a dataset association @@ -445,16 +741,11 @@ def set_dataset_association_approval( status (str): Approval status to set for the association """ url = f"{self.server_url}/datasets/{dataset_uid}/benchmarks/{benchmark_uid}/" - res = self.__set_approval_status(url, status) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not approve association between dataset {dataset_uid} and benchmark {benchmark_uid}: {details}" - ) + error_msg = f"Could not update association: dataset {dataset_uid}, benchmark {benchmark_uid}" + self.__put(url, json=data, error_msg=error_msg) - def set_mlcube_association_approval( - self, benchmark_uid: int, mlcube_uid: int, status: str + def update_benchmark_model_association( + self, benchmark_uid: int, mlcube_uid: int, data: dict ): """Approves an mlcube association @@ -464,60 +755,93 @@ def set_mlcube_association_approval( status (str): Approval status to set for the association """ url = f"{self.server_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" - res = self.__set_approval_status(url, status) - if res.status_code != 200: - 
log_response_error(res)
- details = format_errors_dict(res.json())
- raise CommunicationRequestError(
- f"Could not approve association between mlcube {mlcube_uid} and benchmark {benchmark_uid}: {details}"
- )
+ error_msg = (
+ f"Could not update association: mlcube {mlcube_uid}, benchmark {benchmark_uid}"
+ )
+ self.__put(url, json=data, error_msg=error_msg)
- def get_datasets_associations(self) -> List[dict]:
- """Get all dataset associations related to the current user
+ def update_training_aggregator_association(
+ self, training_exp_id: int, aggregator_id: int, data: dict
+ ):
+ """Updates an aggregator association
- Returns:
- List[dict]: List containing all associations information
+ Args:
+ aggregator_id (int): Aggregator UID
+ training_exp_id (int): Training experiment UID
+ data (dict): Updated fields to set for the association
 """
- assocs = self.__get_list(f"{self.server_url}/me/datasets/associations/")
- return filter_latest_associations(assocs, "dataset")
+ url = (
+ f"{self.server_url}/aggregators/{aggregator_id}/training/{training_exp_id}/"
+ )
+ error_msg = (
+ "Could not update association: aggregator"
+ f" {aggregator_id}, training_exp {training_exp_id}"
+ )
+ self.__put(url, json=data, error_msg=error_msg)
- def get_cubes_associations(self) -> List[dict]:
- """Get all cube associations related to the current user
+ def update_training_dataset_association(
+ self, training_exp_id: int, dataset_uid: int, data: dict
+ ):
+ """Updates a training dataset association
- Returns:
- List[dict]: List containing all associations information
+ Args:
+ dataset_uid (int): Dataset UID
+ training_exp_id (int): Training experiment UID
+ data (dict): Updated fields to set for the association
 """
- assocs = self.__get_list(f"{self.server_url}/me/mlcubes/associations/")
- return filter_latest_associations(assocs, "model_mlcube")
+ url = f"{self.server_url}/datasets/{dataset_uid}/training/{training_exp_id}/"
+ error_msg = (
+ "Could not update association: dataset"
+ f" {dataset_uid}, training_exp {training_exp_id}"
+ )
+ self.__put(url, json=data, error_msg=error_msg)
- def set_mlcube_association_priority(
- self, benchmark_uid: int, mlcube_uid: int, priority: int
+ def update_training_ca_association(
+ self, training_exp_id: int, ca_uid: int, data: dict
 ):
- """Sets the priority of an mlcube-benchmark association
+ """Updates a training CA association
 Args:
- mlcube_uid (int): MLCube UID
+ ca_uid (int): CA UID
 benchmark_uid (int): Benchmark UID
- priority (int): priority value to set for the association
+ data (dict): Updated fields to set for the association
 """
- url = f"{self.server_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/"
- data = {"priority": priority}
- res = self.__auth_put(url, json=data)
- if res.status_code != 200:
- log_response_error(res)
- details = format_errors_dict(res.json())
- raise CommunicationRequestError(
- f"Could not set the priority of mlcube {mlcube_uid} within the benchmark {benchmark_uid}: {details}"
- )
+ url = f"{self.server_url}/cas/{ca_uid}/training/{training_exp_id}/"
+ error_msg = (
+ "Could not update association: ca"
+ f" {ca_uid}, training_exp {training_exp_id}"
+ )
+ self.__put(url, json=data, error_msg=error_msg)
+ # update objects
 def update_dataset(self, dataset_id: int, data: dict):
 url = f"{self.server_url}/datasets/{dataset_id}/"
- res = self.__auth_put(url, json=data)
- if res.status_code != 200:
- log_response_error(res)
- details = format_errors_dict(res.json())
- raise CommunicationRequestError(f"Could not update dataset: {details}")
- return
res.json() + error_msg = "Could not update dataset" + return self.__put(url, json=data, error_msg=error_msg) + + def update_training_exp(self, training_exp_id: int, data: dict): + url = f"{self.server_url}/training/{training_exp_id}/" + error_msg = "Could not update training experiment" + return self.__put(url, json=data, error_msg=error_msg) + + def update_training_event(self, training_event_id: int, data: dict): + url = f"{self.server_url}/training/events/{training_event_id}/" + error_msg = "Could not update training event" + return self.__put(url, json=data, error_msg=error_msg) + + # misc + def get_benchmark_results(self, benchmark_id: int) -> dict: + """Retrieves all results for a given benchmark + + Args: + benchmark_id (int): benchmark ID to retrieve results from + + Returns: + dict: dictionary with the contents of each result in the specified benchmark + """ + url = f"{self.server_url}/benchmarks/{benchmark_id}/results/" + error_msg = "Could not get benchmark results" + return self.__get_list(url, error_msg=error_msg) def get_mlcube_datasets(self, mlcube_id: int) -> dict: """Retrieves all datasets that have the specified mlcube as the prep mlcube @@ -528,26 +852,45 @@ def get_mlcube_datasets(self, mlcube_id: int) -> dict: Returns: dict: dictionary with the contents of each dataset """ + url = f"{self.server_url}/mlcubes/{mlcube_id}/datasets/" + error_msg = "Could not get mlcube datasets" + return self.__get_list(url, error_msg=error_msg) - datasets = self.__get_list(f"{self.server_url}/mlcubes/{mlcube_id}/datasets/") - return datasets + def get_training_datasets_associations(self, training_exp_id: int) -> dict: + """Retrieves all datasets for a given training_exp - def get_user(self, user_id: int) -> dict: - """Retrieves the specified user. This will only return if - the current user has permission to view the requested user, - either by being himself, an admin or an owner of a data preparation - mlcube used by the requested user + Args: + benchmark_id (int): benchmark ID to retrieve results from + + Returns: + dict: dictionary with the contents of each result in the specified benchmark + """ + url = f"{self.server_url}/training/{training_exp_id}/datasets" + error_msg = "Could not get training experiment datasets associations" + return self.__get_list(url, error_msg=error_msg) + + def get_benchmark_models_associations(self, benchmark_uid: int) -> List[int]: + """Retrieves all the model associations of a benchmark. 
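These listing endpoints are all served by the paginated __get_list helper shown earlier, which halves the page size and retries when a page request fails before giving up. A rough standalone sketch of that retry idea, assuming a fetch_page(url, limit, offset) callable that raises RuntimeError on failure:

def fetch_all(fetch_page, url, page_size=32):
    # On a failed page, halve the page size and retry the same offset;
    # give up only once the page size cannot shrink any further.
    items, offset = [], 0
    while True:
        try:
            page = fetch_page(url, limit=page_size, offset=offset)
        except RuntimeError:
            if page_size <= 1:
                raise
            page_size //= 2
            continue
        items.extend(page)
        if len(page) < page_size:
            return items
        offset += page_size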
Args: - user_id (int): User UID + benchmark_uid (int): UID of the desired benchmark Returns: - dict: Requested user information + list[int]: List of benchmark model associations """ - url = f"{self.server_url}/users/{user_id}/" - res = self.__auth_get(url) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError(f"Could not retrieve user: {details}") - return res.json() + url = f"{self.server_url}/benchmarks/{benchmark_uid}/models" + error_msg = "Could not get benchmark models associations" + return self.__get_list(url, error_msg=error_msg) + + def get_training_datasets_with_users(self, training_exp_id: int) -> dict: + """Retrieves all datasets for a given training_exp and their owner information + + Args: + training_exp_id (int): training exp ID + + Returns: + dict: dictionary with the contents of dataset IDs and owner info + """ + url = f"{self.server_url}/training/{training_exp_id}/participants_info/" + error_msg = "Could not get training experiment participants info" + return self.__get_list(url, error_msg=error_msg) diff --git a/cli/medperf/config.py b/cli/medperf/config.py index f43af8f4b..b748db74e 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -49,6 +49,7 @@ auth_jwks_file = str(config_storage / ".jwks") creds_folder = str(config_storage / ".tokens") tokens_db = str(config_storage / ".tokens_db") +pki_assets = str(config_storage / ".pki_assets") images_folder = ".images" trash_folder = ".trash" @@ -62,6 +63,10 @@ results_folder = "results" predictions_folder = "predictions" tests_folder = "tests" +training_folder = "training" +aggregators_folder = "aggregators" +cas_folder = "cas" +training_events_folder = "training_events" default_base_storage = str(Path.home().resolve() / ".medperf") @@ -110,6 +115,22 @@ "base": default_base_storage, "name": tests_folder, }, + "training_folder": { + "base": default_base_storage, + "name": training_folder, + }, + "aggregators_folder": { + "base": default_base_storage, + "name": aggregators_folder, + }, + "cas_folder": { + "base": default_base_storage, + "name": cas_folder, + }, + "training_events_folder": { + "base": default_base_storage, + "name": training_events_folder, + }, } root_folders = [ @@ -126,6 +147,10 @@ "results_folder", "predictions_folder", "tests_folder", + "training_folder", + "aggregators_folder", + "cas_folder", + "training_events_folder", ] # MedPerf filenames conventions @@ -133,12 +158,27 @@ benchmarks_filename = "benchmark.yaml" test_report_file = "test_report.yaml" reg_file = "registration-info.yaml" +agg_file = "agg-info.yaml" +ca_file = "ca-info.yaml" +training_event_file = "event.yaml" cube_metadata_filename = "mlcube-meta.yaml" log_file = "medperf.log" log_package_file = "medperf_logs.tar.gz" tarball_filename = "tmp.tar.gz" demo_dset_paths_file = "paths.yaml" mlcube_cache_file = ".cache_metadata.yaml" +training_exps_filename = "training-info.yaml" +participants_list_filename = "cols.yaml" +training_exp_plan_filename = "plan.yaml" +training_exp_status_filename = "status.yaml" +training_report_file = "report.yaml" +training_report_folder = "report" +training_out_agg_logs = "agg_logs" +training_out_col_logs = "col_logs" +training_out_weights = "weights" +ca_cert_folder = "ca_cert" +ca_config_file = "ca_config.json" +agg_config_file = "aggregator_config.yaml" report_file = "report.yaml" metadata_folder = "metadata" statistics_filename = "statistics.yaml" diff --git a/cli/medperf/entities/aggregator.py 
b/cli/medperf/entities/aggregator.py new file mode 100644 index 000000000..47187df8b --- /dev/null +++ b/cli/medperf/entities/aggregator.py @@ -0,0 +1,108 @@ +import os +from pydantic import validator + +from medperf.entities.interface import Entity +from medperf.entities.schemas import MedperfSchema + +import medperf.config as config +from medperf.account_management import get_medperf_user_data +import yaml + + +class Aggregator(Entity, MedperfSchema): + """ + Class representing a compatibility test report entry + + A test report is comprised of the components of a test execution: + - data used, which can be: + - a demo aggregator url and its hash, or + - a raw data path and its labels path, or + - a prepared aggregator uid + - Data preparation cube if the data used was not already prepared + - model cube + - evaluator cube + - results + """ + + metadata: dict = {} + config: dict + aggregation_mlcube: int + + @staticmethod + def get_type(): + return "aggregator" + + @staticmethod + def get_storage_path(): + return config.aggregators_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_aggregator + + @staticmethod + def get_metadata_filename(): + return config.agg_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_aggregator + + @validator("config", pre=True, always=True) + def check_config(cls, v, *, values, **kwargs): + keys = set(v.keys()) + allowed_keys = { + "address", + "port", + } + if keys != allowed_keys: + raise ValueError("config must contain two keys only: address and port") + return v + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.generated_uid = self.name + + self.address = self.config["address"] + self.port = self.config["port"] + + self.config_path = os.path.join(self.path, config.agg_config_file) + + @classmethod + def from_experiment(cls, training_exp_uid: int) -> "Aggregator": + meta = config.comms.get_experiment_aggregator(training_exp_uid) + agg = cls(**meta) + agg.write() + return agg + + @classmethod + def _Entity__remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + comms_fn = config.comms.get_aggregators + if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: + comms_fn = config.comms.get_user_aggregators + return comms_fn + + def prepare_config(self): + with open(self.config_path, "w") as f: + yaml.dump(self.config, f) + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Generated Hash": self.generated_uid, + "Address": self.address, + "MLCube": int(self.aggregation_mlcube), + "Port": self.port, + "Created At": self.created_at, + "Registered": self.is_registered, + } diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 849ea3fcd..9edd76a7e 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -1,18 +1,14 @@ -import os -from medperf.exceptions import MedperfException -import yaml -import logging -from typing import List, Optional, Union +from typing import List, Optional +from medperf.commands.association.utils import get_associations_list from pydantic import HttpUrl, Field import medperf.config as config -from medperf.entities.interface import Entity, Uploadable -from medperf.exceptions import 
CommunicationRetrievalError, InvalidArgumentError +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema from medperf.account_management import get_medperf_user_data -class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema): +class Benchmark(Entity, MedperfSchema, ApprovableSchema, DeployableSchema): """ Class representing a Benchmark @@ -35,6 +31,26 @@ class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableS user_metadata: dict = {} is_active: bool = True + @staticmethod + def get_type(): + return "benchmark" + + @staticmethod + def get_storage_path(): + return config.benchmarks_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_benchmark + + @staticmethod + def get_metadata_filename(): + return config.benchmarks_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_benchmark + def __init__(self, *args, **kwargs): """Creates a new benchmark instance @@ -44,53 +60,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" - path = config.benchmarks_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]: - """Gets and creates instances of all retrievable benchmarks - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Benchmark]: a list of Benchmark instances. - """ - logging.info("Retrieving all benchmarks") - benchmarks = [] - - if not local_only: - benchmarks = cls.__remote_all(filters=filters) - - remote_uids = set([bmk.id for bmk in benchmarks]) - - local_benchmarks = cls.__local_all() - - benchmarks += [bmk for bmk in local_benchmarks if bmk.id not in remote_uids] - - return benchmarks @classmethod - def __remote_all(cls, filters: dict) -> List["Benchmark"]: - benchmarks = [] - try: - comms_fn = cls.__remote_prefilter(filters) - bmks_meta = comms_fn() - benchmarks = [cls(**meta) for meta in bmks_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all benchmarks from the server" - logging.warning(msg) - - return benchmarks - - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,104 +76,6 @@ def __remote_prefilter(cls, filters: dict) -> callable: comms_fn = config.comms.get_user_benchmarks return comms_fn - @classmethod - def __local_all(cls) -> List["Benchmark"]: - benchmarks = [] - bmks_storage = config.benchmarks_folder - try: - uids = next(os.walk(bmks_storage))[1] - except StopIteration: - msg = "Couldn't iterate over benchmarks directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - benchmark = cls(**meta) - benchmarks.append(benchmark) - - return benchmarks - - @classmethod - def get( - cls, benchmark_uid: Union[str, int], local_only: bool = False - ) -> "Benchmark": - """Retrieves and creates a Benchmark instance from the server. 
- If benchmark already exists in the platform then retrieve that - version. - - Args: - benchmark_uid (str): UID of the benchmark. - comms (Comms): Instance of a communication interface. - - Returns: - Benchmark: a Benchmark instance with the retrieved data. - """ - - if not str(benchmark_uid).isdigit() or local_only: - return cls.__local_get(benchmark_uid) - - try: - return cls.__remote_get(benchmark_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Benchmark {benchmark_uid} from comms failed") - logging.info(f"Looking for benchmark {benchmark_uid} locally") - return cls.__local_get(benchmark_uid) - - @classmethod - def __remote_get(cls, benchmark_uid: int) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} remotely") - benchmark_dict = config.comms.get_benchmark(benchmark_uid) - benchmark = cls(**benchmark_dict) - benchmark.write() - return benchmark - - @classmethod - def __local_get(cls, benchmark_uid: Union[str, int]) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} locally") - benchmark_dict = cls.__get_local_dict(benchmark_uid) - benchmark = cls(**benchmark_dict) - return benchmark - - @classmethod - def __get_local_dict(cls, benchmark_uid) -> dict: - """Retrieves a local benchmark information - - Args: - benchmark_uid (str): uid of the local benchmark - - Returns: - dict: information of the benchmark - """ - logging.info(f"Retrieving benchmark {benchmark_uid} from local storage") - storage = config.benchmarks_folder - bmk_storage = os.path.join(storage, str(benchmark_uid)) - bmk_file = os.path.join(bmk_storage, config.benchmarks_filename) - if not os.path.exists(bmk_file): - raise InvalidArgumentError("No benchmark with the given uid could be found") - with open(bmk_file, "r") as f: - data = yaml.safe_load(f) - - return data - @classmethod def get_models_uids(cls, benchmark_uid: int) -> List[int]: """Retrieves the list of models associated to the benchmark @@ -213,51 +87,12 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: Returns: List[int]: List of mlcube uids """ - associations = config.comms.get_benchmark_model_associations(benchmark_uid) - models_uids = [ - assoc["model_mlcube"] - for assoc in associations - if assoc["approval_status"] == "APPROVED" - ] + associations = get_associations_list( + "benchmark", "model_mlcube", "APPROVED", experiment_id=benchmark_uid + ) + models_uids = [assoc["model_mlcube"] for assoc in associations] return models_uids - def todict(self) -> dict: - """Dictionary representation of the benchmark instance - - Returns: - dict: Dictionary containing benchmark information - """ - return self.extended_dict() - - def write(self) -> str: - """Writes the benchmark into disk - - Args: - filename (str, optional): name of the file. Defaults to config.benchmarks_filename. 
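The storage, lookup, and upload boilerplate being deleted from Benchmark here presumably now lives once in the shared Entity base class, driven by the static hooks each entity defines (get_storage_path, get_metadata_filename, get_comms_retriever, get_comms_uploader). A condensed sketch of that shape, with simplified names and no error handling:

import os

import yaml


class EntityBase:
    @staticmethod
    def get_storage_path() -> str:
        raise NotImplementedError

    @staticmethod
    def get_metadata_filename() -> str:
        raise NotImplementedError

    @classmethod
    def local_get(cls, uid):
        # Generic stand-in for the per-class __local_get/__get_local_dict pairs:
        # read <storage>/<uid>/<metadata file> and rebuild the entity from it.
        path = os.path.join(
            cls.get_storage_path(), str(uid), cls.get_metadata_filename()
        )
        with open(path) as f:
            return cls(**yaml.safe_load(f))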
- - Returns: - str: path to the created benchmark file - """ - data = self.todict() - bmk_file = os.path.join(self.path, config.benchmarks_filename) - if not os.path.exists(bmk_file): - os.makedirs(self.path, exist_ok=True) - with open(bmk_file, "w") as f: - yaml.dump(data, f) - return bmk_file - - def upload(self): - """Uploads a benchmark to the server - - Args: - comms (Comms): communications entity to submit through - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test benchmarks.") - body = self.todict() - updated_body = config.comms.upload_benchmark(body) - return updated_body - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/ca.py b/cli/medperf/entities/ca.py new file mode 100644 index 000000000..7c945ac04 --- /dev/null +++ b/cli/medperf/entities/ca.py @@ -0,0 +1,115 @@ +import json +import os +from medperf.entities.interface import Entity +from medperf.entities.schemas import MedperfSchema +from pydantic import validator +import medperf.config as config +from medperf.account_management import get_medperf_user_data + + +class CA(Entity, MedperfSchema): + """ + Class representing a compatibility test report entry + + A test report is comprised of the components of a test execution: + - data used, which can be: + - a demo aggregator url and its hash, or + - a raw data path and its labels path, or + - a prepared aggregator uid + - Data preparation cube if the data used was not already prepared + - model cube + - evaluator cube + - results + """ + + metadata: dict = {} + client_mlcube: int + server_mlcube: int + ca_mlcube: int + config: dict + + @staticmethod + def get_type(): + return "ca" + + @staticmethod + def get_storage_path(): + return config.cas_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_ca + + @staticmethod + def get_metadata_filename(): + return config.ca_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_ca + + @validator("config", pre=True, always=True) + def check_config(cls, v, *, values, **kwargs): + keys = set(v.keys()) + allowed_keys = { + "address", + "port", + "fingerprint", + "client_provisioner", + "server_provisioner", + } + if keys != allowed_keys: + raise ValueError("config must contain two keys only: address and port") + return v + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.generated_uid = self.name + + self.address = self.config["address"] + self.port = self.config["port"] + self.fingerprint = self.config["fingerprint"] + self.client_provisioner = self.config["client_provisioner"] + self.server_provisioner = self.config["server_provisioner"] + + self.config_path = os.path.join(self.path, config.ca_config_file) + self.pki_assets = os.path.join(self.path, config.ca_cert_folder) + + @classmethod + def from_experiment(cls, training_exp_uid: int) -> "CA": + meta = config.comms.get_experiment_ca(training_exp_uid) + ca = cls(**meta) + ca.write() + return ca + + @classmethod + def _Entity__remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + comms_fn = config.comms.get_cas + if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: + comms_fn = config.comms.get_user_cas + return comms_fn + + def prepare_config(self): + with open(self.config_path, "w") as f: + 
json.dump(self.config, f) + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Generated Hash": self.generated_uid, + "Address": self.address, + "fingerprint": self.fingerprint, + "Port": self.port, + "Created At": self.created_at, + "Registered": self.is_registered, + } diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 98d2b95a8..b3ffae880 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -1,7 +1,7 @@ import os import yaml import logging -from typing import List, Dict, Optional, Union +from typing import Dict, Optional, Union from pydantic import Field from pathlib import Path @@ -12,21 +12,15 @@ generate_tmp_path, spawn_and_kill, ) -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - ExecutionError, - InvalidEntityError, - MedperfException, - CommunicationRetrievalError, -) +from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError import medperf.config as config from medperf.comms.entity_resources import resources from medperf.account_management import get_medperf_user_data -class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Cube(Entity, MedperfSchema, DeployableSchema): """ Class representing an MLCube Container @@ -48,6 +42,26 @@ class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): metadata: dict = {} user_metadata: dict = {} + @staticmethod + def get_type(): + return "cube" + + @staticmethod + def get_storage_path(): + return config.cubes_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_cube_metadata + + @staticmethod + def get_metadata_filename(): + return config.cube_metadata_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_mlcube + def __init__(self, *args, **kwargs): """Creates a Cube instance @@ -57,59 +71,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = self.name - path = config.cubes_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - # NOTE: maybe have these as @property, to have the same entity reusable - # before and after submission - self.path = path - self.cube_path = os.path.join(path, config.cube_filename) + self.cube_path = os.path.join(self.path, config.cube_filename) self.params_path = None if self.git_parameters_url: - self.params_path = os.path.join(path, config.params_filename) - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]: - """Class method for retrieving all retrievable MLCubes - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
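The entities in this PR (CA above, Cube here, and the others below) now only declare a handful of static hooks, with storage, retrieval, writing, and uploading inherited from the shared `Entity` base. A hypothetical sketch of what a brand-new entity would need to provide; `widgets_folder`, `get_widget`, and `upload_widget` are placeholders, not real MedPerf config or comms members:

```python
# Sketch of plugging a new entity into the shared Entity base class.
import medperf.config as config
from medperf.entities.interface import Entity
from medperf.entities.schemas import MedperfSchema


class Widget(Entity, MedperfSchema):
    metadata: dict = {}

    @staticmethod
    def get_type():
        return "widget"

    @staticmethod
    def get_storage_path():
        return config.widgets_folder  # placeholder storage folder

    @staticmethod
    def get_comms_retriever():
        return config.comms.get_widget  # placeholder comms getter

    @staticmethod
    def get_metadata_filename():
        return "widget.yaml"

    @staticmethod
    def get_comms_uploader():
        return config.comms.upload_widget  # placeholder comms uploader

    def display_dict(self):
        return {"UID": self.identifier, "Name": self.name}
```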
- - Returns: - List[Cube]: List containing all cubes - """ - logging.info("Retrieving all cubes") - cubes = [] - if not local_only: - cubes = cls.__remote_all(filters=filters) - - remote_uids = set([cube.id for cube in cubes]) - - local_cubes = cls.__local_all() - - cubes += [cube for cube in local_cubes if cube.id not in remote_uids] - - return cubes + self.params_path = os.path.join(self.path, config.params_filename) @classmethod - def __remote_all(cls, filters: dict) -> List["Cube"]: - cubes = [] - - try: - comms_fn = cls.__remote_prefilter(filters) - cubes_meta = comms_fn() - cubes = [cls(**meta) for meta in cubes_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all cubes from the server" - logging.warning(msg) - - return cubes - - @classmethod - def __remote_prefilter(cls, filters: dict): + def _Entity__remote_prefilter(cls, filters: dict): """Applies filtering logic that must be done before retrieving remote entities Args: @@ -124,25 +92,6 @@ def __remote_prefilter(cls, filters: dict): return comms_fn - @classmethod - def __local_all(cls) -> List["Cube"]: - cubes = [] - cubes_folder = config.cubes_folder - try: - uids = next(os.walk(cubes_folder))[1] - logging.debug(f"Local cubes found: {uids}") - except StopIteration: - msg = "Couldn't iterate over cubes directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - cube = cls(**meta) - cubes.append(cube) - - return cubes - @classmethod def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": """Retrieves and creates a Cube instance from the comms. If cube already exists @@ -155,36 +104,12 @@ def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": Cube : a Cube instance with the retrieved data. """ - if not str(cube_uid).isdigit() or local_only: - cube = cls.__local_get(cube_uid) - else: - try: - cube = cls.__remote_get(cube_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting MLCube {cube_uid} from comms failed") - logging.info(f"Retrieving MLCube {cube_uid} from local storage") - cube = cls.__local_get(cube_uid) - + cube = super().get(cube_uid, local_only) if not cube.is_valid: raise InvalidEntityError("The requested MLCube is marked as INVALID.") cube.download_config_files() return cube - @classmethod - def __remote_get(cls, cube_uid: int) -> "Cube": - logging.debug(f"Retrieving mlcube {cube_uid} remotely") - meta = config.comms.get_cube_metadata(cube_uid) - cube = cls(**meta) - cube.write() - return cube - - @classmethod - def __local_get(cls, cube_uid: Union[str, int]) -> "Cube": - logging.debug(f"Retrieving cube {cube_uid} locally") - local_meta = cls.__get_local_dict(cube_uid) - cube = cls(**local_meta) - return cube - def download_mlcube(self): url = self.git_mlcube_url path, file_hash = resources.get_cube(url, self.path, self.mlcube_hash) @@ -302,10 +227,13 @@ def download_run_files(self): def run( self, task: str, - output_logs: str = None, + output_logs_file: str = None, string_params: Dict[str, str] = {}, timeout: int = None, read_protected_input: bool = True, + port=None, + publish_on=None, + env_dict: dict = {}, **kwargs, ): """Executes a given task on the cube instance @@ -318,9 +246,24 @@ def run( read_protected_input (bool, optional): Wether to disable write permissions on input volumes. Defaults to True. kwargs (dict): additional arguments that are passed directly to the mlcube command """ + # TODO: refactor this function. 
Move things to MLCube if possible kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" - cmd += f" --mlcube=\"{self.cube_path}\" --task={task} --platform={config.platform} --network=none" + cmd += ( + f' --mlcube="{self.cube_path}" --task={task} --platform={config.platform}' + ) + if task not in [ + "train", + "start_aggregator", + "trust", + "get_client_cert", + "get_server_cert", + "get_experiment_status", + "add_collaborator", + "remove_collaborator", + "update_plan", + ]: + cmd += " --network=none" if config.gpus is not None: cmd += f" --gpus={config.gpus}" if read_protected_input: @@ -330,6 +273,13 @@ def run( cmd = " ".join([cmd, cmd_arg]) container_loglevel = config.container_loglevel + if container_loglevel: + env_dict["MEDPERF_LOGLEVEL"] = container_loglevel.upper() + + env_args_string = "" + for key, val in env_dict.items(): + env_args_string += f"--env {key}={val} " + env_args_string = env_args_string.strip() # TODO: we should override run args instead of what we are doing below # we shouldn't allow arbitrary run args unless our client allows it @@ -339,16 +289,29 @@ def run( gpu_args = self.get_config("docker.gpu_args") or "" cpu_args = " ".join([cpu_args, "-u $(id -u):$(id -g)"]).strip() gpu_args = " ".join([gpu_args, "-u $(id -u):$(id -g)"]).strip() + if port is not None: + if publish_on: + cpu_args += f" -p {publish_on}:{port}:{port}" + gpu_args += f" -p {publish_on}:{port}:{port}" + else: + cpu_args += f" -p {port}:{port}" + gpu_args += f" -p {port}:{port}" cmd += f' -Pdocker.cpu_args="{cpu_args}"' cmd += f' -Pdocker.gpu_args="{gpu_args}"' + if env_args_string: # TODO: why MLCube UI is so brittle? + env_args = self.get_config("docker.env_args") or "" + env_args = " ".join([env_args, env_args_string]).strip() + cmd += f' -Pdocker.env_args="{env_args}"' - if container_loglevel: - cmd += f' -Pdocker.env_args="-e MEDPERF_LOGLEVEL={container_loglevel.upper()}"' elif config.platform == "singularity": # use -e to discard host env vars, -C to isolate the container (see singularity run --help) run_args = self.get_config("singularity.run_args") or "" run_args = " ".join([run_args, "-eC"]).strip() + run_args += " " + env_args_string cmd += f' -Psingularity.run_args="{run_args}"' + # TODO: check if ports are already exposed. 
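The new `port`, `publish_on`, and `env_dict` parameters of `Cube.run` ultimately translate into extra docker run arguments. A rough illustration of just the string assembly, assuming the same flag conventions as the diff (the real plumbing goes through the MLCube `-Pdocker.*` overrides shown above):

```python
# Sketch of the extra docker arguments built for FL tasks: port publishing on
# an optional interface plus --env entries collected from env_dict.
def docker_extra_args(port=None, publish_on=None, env_dict=None):
    args = "-u $(id -u):$(id -g)"
    if port is not None:
        # Bind to a specific interface when publish_on is given,
        # otherwise publish the port on all interfaces.
        mapping = f"{publish_on}:{port}:{port}" if publish_on else f"{port}:{port}"
        args += f" -p {mapping}"
    env_args = " ".join(f"--env {k}={v}" for k, v in (env_dict or {}).items())
    return args, env_args


cpu_args, env_args = docker_extra_args(
    port=50051, publish_on="10.0.0.5", env_dict={"MEDPERF_LOGLEVEL": "DEBUG"}
)
print(cpu_args)  # -u $(id -u):$(id -g) -p 10.0.0.5:50051:50051
print(env_args)  # --env MEDPERF_LOGLEVEL=DEBUG
```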
Think if this is OK + # TODO: check about exposing to specific network interfaces + # TODO: check if --env works # set image name in case of running docker image with singularity # Assuming we only accept mlcube.yamls with either singularity or docker sections @@ -358,7 +321,6 @@ def run( cmd += ( f' -Psingularity.image="{self._converted_singularity_image_name}"' ) - # TODO: pass logging env for singularity also there else: raise InvalidArgumentError("Unsupported platform") @@ -371,8 +333,8 @@ def run( proc = proc_wrapper.proc proc_out = combine_proc_sp_text(proc) - if output_logs is not None: - with open(output_logs, "w") as f: + if output_logs_file is not None: + with open(output_logs_file, "w") as f: f.write(proc_out) if proc.exitstatus != 0: raise ExecutionError("There was an error while executing the cube") @@ -430,36 +392,6 @@ def get_config(self, identifier): return cube - def todict(self) -> Dict: - return self.extended_dict() - - def write(self): - cube_loc = str(Path(self.cube_path).parent) - meta_file = os.path.join(cube_loc, config.cube_metadata_filename) - os.makedirs(cube_loc, exist_ok=True) - with open(meta_file, "w") as f: - yaml.dump(self.todict(), f) - return meta_file - - def upload(self): - if self.for_test: - raise InvalidArgumentError("Cannot upload test mlcubes.") - cube_dict = self.todict() - updated_cube_dict = config.comms.upload_mlcube(cube_dict) - return updated_cube_dict - - @classmethod - def __get_local_dict(cls, uid): - cubes_folder = config.cubes_folder - meta_file = os.path.join(cubes_folder, str(uid), config.cube_metadata_filename) - if not os.path.exists(meta_file): - raise InvalidArgumentError( - "The requested mlcube information could not be found locally" - ) - with open(meta_file, "r") as f: - meta = yaml.safe_load(f) - return meta - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 4c210431f..f50e8d680 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -1,22 +1,17 @@ import os import yaml -import logging from pydantic import Field, validator -from typing import List, Optional, Union +from typing import Optional, Union from medperf.utils import remove_path -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - MedperfException, - CommunicationRetrievalError, -) + import medperf.config as config from medperf.account_management import get_medperf_user_data -class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Dataset(Entity, MedperfSchema, DeployableSchema): """ Class representing a Dataset @@ -37,6 +32,26 @@ class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): report: dict = {} submitted_as_prepared: bool + @staticmethod + def get_type(): + return "dataset" + + @staticmethod + def get_storage_path(): + return config.datasets_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_dataset + + @staticmethod + def get_metadata_filename(): + return config.reg_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_dataset + @validator("data_preparation_mlcube", pre=True, always=True) def check_data_preparation_mlcube(cls, v, *, values, **kwargs): if not isinstance(v, int) and not values["for_test"]: @@ -48,13 +63,6 @@ def check_data_preparation_mlcube(cls, v, *, values, 
**kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - path = config.datasets_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") self.report_path = os.path.join(self.path, config.report_file) @@ -86,48 +94,8 @@ def is_ready(self): flag_file = os.path.join(self.path, config.ready_flag_file) return os.path.exists(flag_file) - def todict(self): - return self.extended_dict() - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]: - """Gets and creates instances of all the locally prepared datasets - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Dataset]: a list of Dataset instances. - """ - logging.info("Retrieving all datasets") - dsets = [] - if not local_only: - dsets = cls.__remote_all(filters=filters) - - remote_uids = set([dset.id for dset in dsets]) - - local_dsets = cls.__local_all() - - dsets += [dset for dset in local_dsets if dset.id not in remote_uids] - - return dsets - - @classmethod - def __remote_all(cls, filters: dict) -> List["Dataset"]: - dsets = [] - try: - comms_fn = cls.__remote_prefilter(filters) - dsets_meta = comms_fn() - dsets = [cls(**meta) for meta in dsets_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all datasets from the server" - logging.warning(msg) - - return dsets - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -149,111 +117,6 @@ def func(): return comms_fn - @classmethod - def __local_all(cls) -> List["Dataset"]: - dsets = [] - datasets_folder = config.datasets_folder - try: - uids = next(os.walk(datasets_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - dset = cls(**local_meta) - dsets.append(dset) - - return dsets - - @classmethod - def get(cls, dset_uid: Union[str, int], local_only: bool = False) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - if not str(dset_uid).isdigit() or local_only: - return cls.__local_get(dset_uid) - - try: - return cls.__remote_get(dset_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Dataset {dset_uid} from comms failed") - logging.info(f"Looking for dataset {dset_uid} locally") - return cls.__local_get(dset_uid) - - @classmethod - def __remote_get(cls, dset_uid: int) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
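The subclasses now override the remote prefilter under the mangled name `_Entity__remote_prefilter`. That is because the base class calls `cls.__remote_prefilter(...)`, which Python name-mangles at compile time to `cls._Entity__remote_prefilter(...)`, so only a method registered under the mangled name is picked up. A toy demonstration of the mechanism, unrelated to the real entities:

```python
# Why subclasses define the mangled hook name explicitly.
class Base:
    @classmethod
    def fetch(cls):
        return cls.__prefilter()  # compiled as cls._Base__prefilter()

    @classmethod
    def __prefilter(cls):
        return "base"


class Child(Base):
    @classmethod
    def _Base__prefilter(cls):  # overrides the mangled hook
        return "child"


print(Base.fetch())   # base
print(Child.fetch())  # child
```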
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} remotely") - meta = config.comms.get_dataset(dset_uid) - dataset = cls(**meta) - dataset.write() - return dataset - - @classmethod - def __local_get(cls, dset_uid: Union[str, int]) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} locally") - local_meta = cls.__get_local_dict(dset_uid) - dataset = cls(**local_meta) - return dataset - - def write(self): - logging.info(f"Updating registration information for dataset: {self.id}") - logging.debug(f"registration information: {self.todict()}") - regfile = os.path.join(self.path, config.reg_file) - os.makedirs(self.path, exist_ok=True) - with open(regfile, "w") as f: - yaml.dump(self.todict(), f) - return regfile - - def upload(self): - """Uploads the registration information to the comms. - - Args: - comms (Comms): Instance of the comms interface. - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test datasets.") - dataset_dict = self.todict() - updated_dataset_dict = config.comms.upload_dataset(dataset_dict) - return updated_dataset_dict - - @classmethod - def __get_local_dict(cls, data_uid): - dataset_path = os.path.join(config.datasets_folder, str(data_uid)) - regfile = os.path.join(dataset_path, config.reg_file) - if not os.path.exists(regfile): - raise InvalidArgumentError( - "The requested dataset information could not be found locally" - ) - with open(regfile, "r") as f: - reg = yaml.safe_load(f) - return reg - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py new file mode 100644 index 000000000..cd528524a --- /dev/null +++ b/cli/medperf/entities/event.py @@ -0,0 +1,110 @@ +from datetime import datetime +import os +from typing import Optional +from medperf.entities.interface import Entity +import medperf.config as config +from medperf.entities.schemas import MedperfSchema +from medperf.account_management import get_medperf_user_data +import yaml + + +class TrainingEvent(Entity, MedperfSchema): + """ + Class representing a compatibility test report entry + + A test report is comprised of the components of a test execution: + - data used, which can be: + - a demo aggregator url and its hash, or + - a raw data path and its labels path, or + - a prepared aggregator uid + - Data preparation cube if the data used was not already prepared + - model cube + - evaluator cube + - results + """ + + training_exp: int + participants: dict + finished: bool = False + finished_at: Optional[datetime] + report: Optional[dict] + + @staticmethod + def get_type(): + return "training event" + + @staticmethod + def get_storage_path(): + return config.training_events_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_training_event + + @staticmethod + def get_metadata_filename(): + return config.training_event_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_training_event + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.generated_uid = self.name + self.participants_list_path = os.path.join( + self.path, config.participants_list_filename + 
) + timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + self.agg_out_logs = os.path.join( + self.path, config.training_out_agg_logs + timestamp + ) + self.col_out_logs = os.path.join(self.path, config.training_out_col_logs) + self.out_weights = os.path.join( + self.path, config.training_out_weights + timestamp + ) + self.report_path = os.path.join( + self.path, + config.training_report_folder + timestamp, + config.training_report_file, + ) + + @classmethod + def from_experiment(cls, training_exp_uid: int) -> "TrainingEvent": + meta = config.comms.get_experiment_event(training_exp_uid) + ca = cls(**meta) + ca.write() + return ca + + @classmethod + def _Entity__remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + comms_fn = config.comms.get_training_events + if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: + comms_fn = config.comms.get_user_training_events + return comms_fn + + def prepare_participants_list(self): + with open(self.participants_list_path, "w") as f: + yaml.dump(self.participants, f) + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Experiment": self.training_exp, + "Generated Hash": self.generated_uid, + "Participants": self.participants, + "Created At": self.created_at, + "Registered": self.is_registered, + "Finished": self.finished, + "Report": self.report, + } diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index af2afabd7..7a5f0b5ef 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -1,77 +1,215 @@ from typing import List, Dict, Union -from abc import ABC, abstractmethod +from abc import ABC +import logging +import os +import yaml +from medperf.exceptions import MedperfException, InvalidArgumentError +from medperf.entities.schemas import MedperfBaseSchema -class Entity(ABC): - @abstractmethod - def all( - cls, local_only: bool = False, comms_func: callable = None - ) -> List["Entity"]: +class Entity(MedperfBaseSchema, ABC): + @staticmethod + def get_type(): + raise NotImplementedError() + + @staticmethod + def get_storage_path(): + raise NotImplementedError() + + @staticmethod + def get_comms_retriever(): + raise NotImplementedError() + + @staticmethod + def get_metadata_filename(): + raise NotImplementedError() + + @staticmethod + def get_comms_uploader(): + raise NotImplementedError() + + @property + def identifier(self): + return self.id or self.generated_uid + + @property + def is_registered(self): + return self.id is not None + + @property + def path(self): + storage_path = self.get_storage_path() + return os.path.join(storage_path, str(self.identifier)) + + @classmethod + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: """Gets a list of all instances of the respective entity. - Wether the list is local or remote depends on the implementation. + Whether the list is local or remote depends on the implementation. Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - comms_func (callable, optional): Function to use to retrieve remote entities. - If not provided, will use the default entrypoint. + unregistered (bool, optional): Wether to retrieve only unregistered local entities. Defaults to False. 
+ filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. + Returns: List[Entity]: a list of entities. """ + logging.info(f"Retrieving all {cls.get_type()} entities") + if unregistered: + if filters: + raise InvalidArgumentError( + "Filtering is not supported for unregistered entities" + ) + return cls.__unregistered_all() + return cls.__remote_all(filters=filters) + + @classmethod + def __remote_all(cls, filters: dict) -> List["Entity"]: + comms_fn = cls.__remote_prefilter(filters) + entity_meta = comms_fn() + entities = [cls(**meta) for meta in entity_meta] + return entities + + @classmethod + def __unregistered_all(cls) -> List["Entity"]: + entities = [] + storage_path = cls.get_storage_path() + try: + uids = next(os.walk(storage_path))[1] + except StopIteration: + msg = f"Couldn't iterate over the {cls.get_type()} storage" + logging.warning(msg) + raise MedperfException(msg) + + for uid in uids: + if uid.isdigit(): + continue + meta = cls.__get_local_dict(uid) + entity = cls(**meta) + entities.append(entity) + + return entities + + @classmethod + def __remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + raise NotImplementedError - @abstractmethod - def get(cls, uid: Union[str, int]) -> "Entity": + @classmethod + def get(cls, uid: Union[str, int], local_only: bool = False) -> "Entity": """Gets an instance of the respective entity. Wether this requires only local read or remote calls depends on the implementation. Args: uid (str): Unique Identifier to retrieve the entity + local_only (bool): If True, the entity will be retrieved locally Returns: Entity: Entity Instance associated to the UID """ - @abstractmethod - def todict(self) -> Dict: - """Dictionary representation of the entity + if not str(uid).isdigit() or local_only: + return cls.__local_get(uid) + return cls.__remote_get(uid) + + @classmethod + def __remote_get(cls, uid: int) -> "Entity": + """Retrieves and creates an entity instance from the comms instance. + + Args: + uid (int): server UID of the entity Returns: - Dict: Dictionary containing information about the entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} remotely") + comms_func = cls.get_comms_retriever() + entity_dict = comms_func(uid) + entity = cls(**entity_dict) + entity.write() + return entity - @abstractmethod - def write(self) -> str: - """Writes the entity to the local storage + @classmethod + def __local_get(cls, uid: Union[str, int]) -> "Entity": + """Retrieves and creates an entity instance from the local storage. 
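Taken together, the consolidated `get` and `all` implement one lookup rule for every entity: numeric UIDs are fetched from the server and cached, non-numeric UIDs (or `local_only=True`) are read from local storage, and `all(unregistered=True)` walks the storage folder while skipping numeric directory names. A usage sketch under the assumption of a configured profile and existing local entries; the UIDs are made up:

```python
from medperf.entities.dataset import Dataset

# Numeric UID: fetched remotely and written to local storage.
dset = Dataset.get(42)

# Non-numeric UID (a generated hash) or local_only=True: resolved purely from
# the metadata file under the entity's storage folder.
local_dset = Dataset.get("a1b2c3", local_only=True)

# all() returns server entities by default; unregistered=True instead lists
# only the local, not-yet-submitted entries (non-numeric folder names).
registered = Dataset.all()
unsubmitted = Dataset.all(unregistered=True)
```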
+ + Args: + uid (str|int): UID of the entity Returns: - str: Path to the stored entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} locally") + entity_dict = cls.__get_local_dict(uid) + entity = cls(**entity_dict) + return entity - @abstractmethod - def display_dict(self) -> dict: - """Returns a dictionary of entity properties that can be displayed - to a user interface using a verbose name of the property rather than - the internal names + @classmethod + def __get_local_dict(cls, uid: Union[str, int]) -> dict: + """Retrieves a local entity information + + Args: + uid (str): uid of the local entity Returns: - dict: the display dictionary + dict: information of the entity """ + logging.info(f"Retrieving {cls.get_type()} {uid} from local storage") + storage_path = cls.get_storage_path() + metadata_filename = cls.get_metadata_filename() + bmk_file = os.path.join(storage_path, str(uid), metadata_filename) + if not os.path.exists(bmk_file): + raise InvalidArgumentError( + f"No {cls.get_type()} with the given uid could be found" + ) + with open(bmk_file, "r") as f: + data = yaml.safe_load(f) + + return data + + def write(self) -> str: + """Writes the entity to the local storage + Returns: + str: Path to the stored entity + """ + data = self.todict() + metadata_filename = self.get_metadata_filename() + entity_file = os.path.join(self.path, metadata_filename) + os.makedirs(self.path, exist_ok=True) + with open(entity_file, "w") as f: + yaml.dump(data, f) + return entity_file -class Uploadable: - @abstractmethod def upload(self) -> Dict: """Upload the entity-related information to the communication's interface Returns: Dict: Dictionary with the updated entity information """ + if self.for_test: + raise InvalidArgumentError( + f"This test {self.get_type()} cannot be uploaded." 
+ ) + body = self.todict() + comms_func = self.get_comms_uploader() + updated_body = comms_func(body) + return updated_body - @property - def identifier(self): - return self.id or self.generated_uid + def display_dict(self) -> dict: + """Returns a dictionary of entity properties that can be displayed + to a user interface using a verbose name of the property rather than + the internal names - @property - def is_registered(self): - return self.id is not None + Returns: + dict: the display dictionary + """ + raise NotImplementedError diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index c76f09894..65147e558 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -1,16 +1,11 @@ import hashlib -import os -import yaml -import logging from typing import List, Union, Optional -from medperf.entities.schemas import MedperfBaseSchema import medperf.config as config -from medperf.exceptions import InvalidArgumentError from medperf.entities.interface import Entity -class TestReport(Entity, MedperfBaseSchema): +class TestReport(Entity): """ Class representing a compatibility test report entry @@ -35,11 +30,23 @@ class TestReport(Entity, MedperfBaseSchema): data_evaluator_mlcube: Union[int, str] results: Optional[dict] + @staticmethod + def get_type(): + return "report" + + @staticmethod + def get_storage_path(): + return config.tests_folder + + @staticmethod + def get_metadata_filename(): + return config.test_report_file + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.id = None + self.for_test = True self.generated_uid = self.__generate_uid() - path = config.tests_folder - self.path = os.path.join(path, self.generated_uid) def __generate_uid(self): """A helper that generates a unique hash for a test report.""" @@ -52,71 +59,14 @@ def set_results(self, results): self.results = results @classmethod - def all( - cls, local_only: bool = False, mine_only: bool = False - ) -> List["TestReport"]: - """Gets and creates instances of test reports. - Arguments are only specified for compatibility with - `Entity.List` and `Entity.View`, but they don't contribute to - the logic. 
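Test reports never exist on the server, so the overrides below pin them to local storage: `all` only supports the unregistered listing, and `get` always resolves locally. A short usage sketch with an illustrative report hash:

```python
from medperf.entities.report import TestReport

reports = TestReport.all(unregistered=True)  # the only supported listing mode
report = TestReport.get("f3ab12")            # always read from local storage

# TestReport.all() without unregistered=True trips the assertion,
# since there is nothing to list remotely.
```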
- - Returns: - List[TestReport]: List containing all test reports - """ - logging.info("Retrieving all reports") - reports = [] - tests_folder = config.tests_folder - try: - uids = next(os.walk(tests_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the tests directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - report = cls(**local_meta) - reports.append(report) - - return reports - - @classmethod - def get(cls, report_uid: str) -> "TestReport": - """Retrieves and creates a TestReport instance obtained the user's machine - - Args: - report_uid (str): UID of the TestReport instance - - Returns: - TestReport: Specified TestReport instance - """ - logging.debug(f"Retrieving report {report_uid}") - report_dict = cls.__get_local_dict(report_uid) - report = cls(**report_dict) - report.write() - return report - - def todict(self): - return self.extended_dict() - - def write(self): - report_file = os.path.join(self.path, config.test_report_file) - os.makedirs(self.path, exist_ok=True) - with open(report_file, "w") as f: - yaml.dump(self.todict(), f) - return report_file + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: + assert unregistered, "Reports are only unregistered" + assert filters == {}, "Reports cannot be filtered" + return super().all(unregistered=True, filters={}) @classmethod - def __get_local_dict(cls, local_uid): - report_path = os.path.join(config.tests_folder, str(local_uid)) - report_file = os.path.join(report_path, config.test_report_file) - if not os.path.exists(report_file): - raise InvalidArgumentError( - f"The requested report {local_uid} could not be retrieved" - ) - with open(report_file, "r") as f: - report_info = yaml.safe_load(f) - return report_info + def get(cls, report_uid: str, local_only: bool = False) -> "TestReport": + return super().get(report_uid, local_only=True) def display_dict(self): if self.data_path: diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index c82add87b..af4098521 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -1,16 +1,10 @@ -import os -import yaml -import logging -from typing import List, Union - -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, ApprovableSchema import medperf.config as config -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.account_management import get_medperf_user_data -class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): +class Result(Entity, MedperfSchema, ApprovableSchema): """ Class representing a Result entry @@ -28,59 +22,34 @@ class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): metadata: dict = {} user_metadata: dict = {} - def __init__(self, *args, **kwargs): - """Creates a new result instance""" - super().__init__(*args, **kwargs) - - self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" - path = config.results_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]: - """Gets and creates instances of all the user's results - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. 
- filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Result]: List containing all results - """ - logging.info("Retrieving all results") - results = [] - if not local_only: - results = cls.__remote_all(filters=filters) - - remote_uids = set([result.id for result in results]) + @staticmethod + def get_type(): + return "result" - local_results = cls.__local_all() + @staticmethod + def get_storage_path(): + return config.results_folder - results += [res for res in local_results if res.id not in remote_uids] + @staticmethod + def get_comms_retriever(): + return config.comms.get_result - return results + @staticmethod + def get_metadata_filename(): + return config.results_info_file - @classmethod - def __remote_all(cls, filters: dict) -> List["Result"]: - results = [] + @staticmethod + def get_comms_uploader(): + return config.comms.upload_result - try: - comms_fn = cls.__remote_prefilter(filters) - results_meta = comms_fn() - results = [cls(**meta) for meta in results_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all results from the server" - logging.warning(msg) + def __init__(self, *args, **kwargs): + """Creates a new result instance""" + super().__init__(*args, **kwargs) - return results + self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,113 +73,6 @@ def get_benchmark_results(): return comms_fn - @classmethod - def __local_all(cls) -> List["Result"]: - results = [] - results_folder = config.results_folder - try: - uids = next(os.walk(results_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - result = cls(**local_meta) - results.append(result) - - return results - - @classmethod - def get(cls, result_uid: Union[str, int], local_only: bool = False) -> "Result": - """Retrieves and creates a Result instance obtained from the platform. - If the result instance already exists in the user's machine, it loads - the local instance - - Args: - result_uid (str): UID of the Result instance - - Returns: - Result: Specified Result instance - """ - if not str(result_uid).isdigit() or local_only: - return cls.__local_get(result_uid) - - try: - return cls.__remote_get(result_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Result {result_uid} from comms failed") - logging.info(f"Looking for result {result_uid} locally") - return cls.__local_get(result_uid) - - @classmethod - def __remote_get(cls, result_uid: int) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} remotely") - meta = config.comms.get_result(result_uid) - result = cls(**meta) - result.write() - return result - - @classmethod - def __local_get(cls, result_uid: Union[str, int]) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
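The result's generated UID is derived from its benchmark/model/dataset triplet, which is what allows unregistered results to be located deterministically on disk (for example when checking for cached executions). A small sketch of that derivation:

```python
# Folder name used for an unregistered result, derived from the triplet.
def result_generated_uid(benchmark: int, model: int, dataset: int) -> str:
    return f"b{benchmark}m{model}d{dataset}"


print(result_generated_uid(1, 10, 7))  # b1m10d7
```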
- - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} locally") - local_meta = cls.__get_local_dict(result_uid) - result = cls(**local_meta) - return result - - def todict(self): - return self.extended_dict() - - def upload(self): - """Uploads the results to the comms - - Args: - comms (Comms): Instance of the communications interface. - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test results.") - results_info = self.todict() - updated_results_info = config.comms.upload_result(results_info) - return updated_results_info - - def write(self): - result_file = os.path.join(self.path, config.results_info_file) - os.makedirs(self.path, exist_ok=True) - with open(result_file, "w") as f: - yaml.dump(self.todict(), f) - return result_file - - @classmethod - def __get_local_dict(cls, local_uid): - result_path = os.path.join(config.results_folder, str(local_uid)) - result_file = os.path.join(result_path, config.results_info_file) - if not os.path.exists(result_file): - raise InvalidArgumentError( - f"The requested result {local_uid} could not be retrieved" - ) - with open(result_file, "r") as f: - results_info = yaml.safe_load(f) - return results_info - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index 0e7a54291..cac3d3a01 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -46,7 +46,7 @@ def dict(self, *args, **kwargs) -> dict: out_dict = {k: v for k, v in model_dict.items() if k in valid_fields} return out_dict - def extended_dict(self) -> dict: + def todict(self) -> dict: """Dictionary containing both original and alias fields Returns: @@ -74,7 +74,7 @@ class Config: use_enum_values = True -class MedperfSchema(MedperfBaseSchema): +class MedperfSchema(BaseModel): for_test: bool = False id: Optional[int] name: str = Field(..., max_length=64) diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py new file mode 100644 index 000000000..d9d44c385 --- /dev/null +++ b/cli/medperf/entities/training_exp.py @@ -0,0 +1,132 @@ +import os +from medperf.commands.association.utils import get_associations_list +import yaml +from typing import List, Optional +from pydantic import HttpUrl, Field + +import medperf.config as config +from medperf.entities.interface import Entity +from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema +from medperf.account_management import get_medperf_user_data + + +class TrainingExp(Entity, MedperfSchema, ApprovableSchema, DeployableSchema): + """ + Class representing a TrainingExp + + a training_exp is a bundle of assets that enables quantitative + measurement of the performance of AI models for a specific + clinical problem. A TrainingExp instance contains information + regarding how to prepare datasets for execution, as well as + what models to run and how to evaluate them. 
+ """ + + description: Optional[str] = Field(None, max_length=20) + docs_url: Optional[HttpUrl] + demo_dataset_tarball_url: str + demo_dataset_tarball_hash: str + demo_dataset_generated_uid: str + data_preparation_mlcube: int + fl_mlcube: int + fl_admin_mlcube: Optional[int] + plan: dict = {} + metadata: dict = {} + user_metadata: dict = {} + + @staticmethod + def get_type(): + return "training experiment" + + @staticmethod + def get_storage_path(): + return config.training_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_training_exp + + @staticmethod + def get_metadata_filename(): + return config.training_exps_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_training_exp + + def __init__(self, *args, **kwargs): + """Creates a new training_exp instance + + Args: + training_exp_desc (Union[dict, TrainingExpModel]): TrainingExp instance description + """ + super().__init__(*args, **kwargs) + + self.generated_uid = self.name + self.plan_path = os.path.join(self.path, config.training_exp_plan_filename) + self.status_path = os.path.join(self.path, config.training_exp_status_filename) + + @classmethod + def _Entity__remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + comms_fn = config.comms.get_training_exps + if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: + comms_fn = config.comms.get_user_training_exps + return comms_fn + + @classmethod + def get_datasets_uids(cls, training_exp_uid: int) -> List[int]: + """Retrieves the list of models associated to the training_exp + + Args: + training_exp_uid (int): UID of the training_exp. + comms (Comms): Instance of the communications interface. + + Returns: + List[int]: List of mlcube uids + """ + associations = get_associations_list( + "training_exp", "dataset", "APPROVED", experiment_id=training_exp_uid + ) + datasets_uids = [assoc["dataset"] for assoc in associations] + return datasets_uids + + @classmethod + def get_datasets_with_users(cls, training_exp_uid: int) -> List[int]: + """Retrieves the list of models associated to the training_exp + + Args: + training_exp_uid (int): UID of the training_exp. + comms (Comms): Instance of the communications interface. 
+ + Returns: + List[int]: List of mlcube uids + """ + uids_with_users = config.comms.get_training_datasets_with_users( + training_exp_uid + ) + return uids_with_users + + def prepare_plan(self): + with open(self.plan_path, "w") as f: + yaml.dump(self.plan, f) + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Description": self.description, + "Documentation": self.docs_url, + "Created At": self.created_at, + "FL MLCube": int(self.fl_mlcube), + "Plan": self.plan, + "State": self.state, + "Registered": self.is_registered, + "Approval Status": self.approval_status, + } diff --git a/cli/medperf/storage/__init__.py b/cli/medperf/storage/__init__.py index acebb6f31..59097b2ff 100644 --- a/cli/medperf/storage/__init__.py +++ b/cli/medperf/storage/__init__.py @@ -2,7 +2,7 @@ import shutil from medperf import config -from medperf.config_management import read_config, write_config +from medperf.config_management import read_config, write_config, ConfigManager from .utils import full_folder_path @@ -19,12 +19,7 @@ def init_storage(): os.makedirs(folder, exist_ok=True) -def apply_configuration_migrations(): - if not os.path.exists(config.config_path): - return - - config_p = read_config() - +def __apply_logs_migrations(config_p: ConfigManager): if "logs_folder" not in config_p.storage: return @@ -35,4 +30,27 @@ def apply_configuration_migrations(): del config_p.storage["logs_folder"] + +def __apply_training_migrations(config_p: ConfigManager): + + for folder in [ + "aggregators_folder", + "cas_folder", + "training_events_folder", + "training_folder", + ]: + if folder not in config_p.storage: + # Assuming for now all folders are always moved together + # I used here "benchmarks_folder" arbitrarily + config_p.storage[folder] = config_p.storage["benchmarks_folder"] + + +def apply_configuration_migrations(): + if not os.path.exists(config.config_path): + return + + config_p = read_config() + __apply_logs_migrations(config_p) + __apply_training_migrations(config_p) + write_config(config_p) diff --git a/cli/medperf/tests/commands/association/test_approve.py b/cli/medperf/tests/commands/association/test_approve.py index 23c50c721..490351490 100644 --- a/cli/medperf/tests/commands/association/test_approve.py +++ b/cli/medperf/tests/commands/association/test_approve.py @@ -25,7 +25,7 @@ def test_run_fails_if_invalid_arguments(mocker, comms, ui, dset_uid, mlcube_uid) @pytest.mark.parametrize("status", [Status.APPROVED, Status.REJECTED]) def test_run_calls_comms_dset_approval_with_status(mocker, comms, ui, dset_uid, status): # Arrange - spy = mocker.patch.object(comms, "set_dataset_association_approval") + spy = mocker.patch.object(comms, "update_benchmark_dataset_association") # Act Approval.run(1, status, dataset_uid=dset_uid) @@ -40,7 +40,7 @@ def test_run_calls_comms_mlcube_approval_with_status( mocker, comms, ui, mlcube_uid, status ): # Arrange - spy = mocker.patch.object(comms, "set_mlcube_association_approval") + spy = mocker.patch.object(comms, "update_benchmark_model_association") # Act Approval.run(1, status, mlcube_uid=mlcube_uid) diff --git a/cli/medperf/tests/commands/association/test_priority.py b/cli/medperf/tests/commands/association/test_priority.py index 8d7a70392..81b67e602 100644 --- a/cli/medperf/tests/commands/association/test_priority.py +++ b/cli/medperf/tests/commands/association/test_priority.py @@ -36,7 +36,7 @@ def setup_comms(mocker, comms, associations): ) mocker.patch.object( comms, - "set_mlcube_association_priority", + 
"update_benchmark_model_association", side_effect=set_priority_behavior(associations), ) diff --git a/cli/medperf/tests/commands/benchmark/test_associate.py b/cli/medperf/tests/commands/benchmark/test_associate.py index 461a968f8..92c5b704f 100644 --- a/cli/medperf/tests/commands/benchmark/test_associate.py +++ b/cli/medperf/tests/commands/benchmark/test_associate.py @@ -11,8 +11,8 @@ def test_run_fails_if_model_and_dset_passed(mocker, model_uid, data_uid, comms, ui): # Arrange num_arguments = int(data_uid is None) + int(model_uid is None) - mocker.patch.object(comms, "associate_cube") - mocker.patch.object(comms, "associate_dset") + mocker.patch.object(comms, "associate_benchmark_model") + mocker.patch.object(comms, "associate_benchmark_dataset") mocker.patch(PATCH_ASSOC.format("AssociateCube.run")) mocker.patch(PATCH_ASSOC.format("AssociateDataset.run")) diff --git a/cli/medperf/tests/commands/dataset/test_associate.py b/cli/medperf/tests/commands/dataset/test_associate.py index 647eb6a13..8f1b2c1e7 100644 --- a/cli/medperf/tests/commands/dataset/test_associate.py +++ b/cli/medperf/tests/commands/dataset/test_associate.py @@ -71,7 +71,7 @@ def test_associates_if_approved( ): # Arrange result = TestResult() - assoc_func = "associate_dset" + assoc_func = "associate_benchmark_dataset" mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=True) exec_ret = [result] mocker.patch(PATCH_ASSOC.format("BenchmarkExecution.run"), return_value=exec_ret) @@ -93,7 +93,7 @@ def test_stops_if_not_approved(mocker, comms, ui, dataset, benchmark): exec_ret = [result] mocker.patch(PATCH_ASSOC.format("BenchmarkExecution.run"), return_value=exec_ret) spy = mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=False) - assoc_spy = mocker.patch.object(comms, "associate_dset") + assoc_spy = mocker.patch.object(comms, "associate_benchmark_dataset") # Act AssociateDataset.run(1, 1) @@ -110,7 +110,7 @@ def test_associate_calls_allows_cache_by_default(mocker, comms, ui, dataset, ben result = TestResult() data_uid = 1562 benchmark_uid = 3557 - assoc_func = "associate_dset" + assoc_func = "associate_benchmark_dataset" mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=True) exec_ret = [result] spy = mocker.patch( diff --git a/cli/medperf/tests/commands/mlcube/test_associate.py b/cli/medperf/tests/commands/mlcube/test_associate.py index cf72ab574..ab8743f12 100644 --- a/cli/medperf/tests/commands/mlcube/test_associate.py +++ b/cli/medperf/tests/commands/mlcube/test_associate.py @@ -30,7 +30,7 @@ def test_run_associates_cube_with_comms( mocker, cube, benchmark, cube_uid, benchmark_uid, comms, ui ): # Arrange - spy = mocker.patch.object(comms, "associate_cube") + spy = mocker.patch.object(comms, "associate_benchmark_model") comp_ret = ("", {}) mocker.patch.object(ui, "prompt", return_value="y") mocker.patch( @@ -70,7 +70,7 @@ def test_stops_if_not_approved(mocker, comms, ui, cube, benchmark): PATCH_ASSOC.format("CompatibilityTestExecution.run"), return_value=comp_ret ) spy = mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=False) - assoc_spy = mocker.patch.object(comms, "associate_cube") + assoc_spy = mocker.patch.object(comms, "associate_benchmark_model") # Act AssociateCube.run(1, 1) diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py index 74299c77e..c69544781 100644 --- a/cli/medperf/tests/commands/result/test_create.py +++ b/cli/medperf/tests/commands/result/test_create.py @@ -57,6 +57,9 @@ def 
mock_result_all(mocker, state_variables): TestResult(benchmark=triplet[0], model=triplet[1], dataset=triplet[2]) for triplet in cached_results_triplets ] + mocker.patch( + PATCH_EXECUTION.format("get_medperf_user_data", return_value={"id": 1}) + ) mocker.patch(PATCH_EXECUTION.format("Result.all"), return_value=results) diff --git a/cli/medperf/tests/commands/test_execution.py b/cli/medperf/tests/commands/test_execution.py index 669d7dfd9..09df143d1 100644 --- a/cli/medperf/tests/commands/test_execution.py +++ b/cli/medperf/tests/commands/test_execution.py @@ -169,14 +169,14 @@ def test_cube_run_are_called_properly(mocker, setup): exp_model_call = call( task="infer", - output_logs=exp_model_logs_path, + output_logs_file=exp_model_logs_path, timeout=config.infer_timeout, data_path=INPUT_DATASET.data_path, output_path=exp_preds_path, ) exp_eval_call = call( task="evaluate", - output_logs=exp_metrics_logs_path, + output_logs_file=exp_metrics_logs_path, timeout=config.evaluate_timeout, predictions=exp_preds_path, labels=INPUT_DATASET.labels_path, diff --git a/cli/medperf/tests/commands/test_list.py b/cli/medperf/tests/commands/test_list.py index 1c2dc3267..ce7035960 100644 --- a/cli/medperf/tests/commands/test_list.py +++ b/cli/medperf/tests/commands/test_list.py @@ -47,18 +47,18 @@ def set_common_attributes(self, setup): self.state_variables = state_variables self.spies = spies - @pytest.mark.parametrize("local_only", [False, True]) + @pytest.mark.parametrize("unregistered", [False, True]) @pytest.mark.parametrize("mine_only", [False, True]) - def test_entity_all_is_called_properly(self, mocker, local_only, mine_only): + def test_entity_all_is_called_properly(self, mocker, unregistered, mine_only): # Arrange filters = {"owner": 1} if mine_only else {} # Act - EntityList.run(Entity, [], local_only, mine_only) + EntityList.run(Entity, [], unregistered, mine_only) # Assert self.spies["all"].assert_called_once_with( - local_only=local_only, filters=filters + unregistered=unregistered, filters=filters ) @pytest.mark.parametrize("fields", [["UID", "MLCube"]]) diff --git a/cli/medperf/tests/commands/test_view.py b/cli/medperf/tests/commands/test_view.py index a2dddfeda..0ffe0fb13 100644 --- a/cli/medperf/tests/commands/test_view.py +++ b/cli/medperf/tests/commands/test_view.py @@ -1,143 +1,86 @@ import pytest -import yaml -import json from medperf.entities.interface import Entity -from medperf.exceptions import InvalidArgumentError from medperf.commands.view import EntityView - -def expected_output(entities, format): - if isinstance(entities, list): - data = [entity.todict() for entity in entities] - else: - data = entities.todict() - - if format == "yaml": - return yaml.dump(data) - if format == "json": - return json.dumps(data) - - -def generate_entity(id, mocker): - entity = mocker.create_autospec(spec=Entity) - mocker.patch.object(entity, "todict", return_value={"id": id}) - return entity +PATCH_VIEW = "medperf.commands.view.{}" @pytest.fixture -def ui_spy(mocker, ui): - return mocker.patch.object(ui, "print") +def entity(mocker): + return mocker.create_autospec(Entity) -@pytest.fixture( - params=[{"local": ["1", "2", "3"], "remote": ["4", "5", "6"], "user": ["4"]}] -) -def setup(request, mocker): - local_ids = request.param.get("local", []) - remote_ids = request.param.get("remote", []) - user_ids = request.param.get("user", []) - all_ids = list(set(local_ids + remote_ids + user_ids)) - - local_entities = [generate_entity(id, mocker) for id in local_ids] - remote_entities = [generate_entity(id, 
mocker) for id in remote_ids] - user_entities = [generate_entity(id, mocker) for id in user_ids] - all_entities = list(set(local_entities + remote_entities + user_entities)) - - def mock_all(filters={}, local_only=False): - if "owner" in filters: - return user_entities - if local_only: - return local_entities - return all_entities - - def mock_get(entity_id): - if entity_id in all_ids: - return generate_entity(entity_id, mocker) - else: - raise InvalidArgumentError - - mocker.patch("medperf.commands.view.get_medperf_user_data", return_value={"id": 1}) - mocker.patch.object(Entity, "all", side_effect=mock_all) - mocker.patch.object(Entity, "get", side_effect=mock_get) - - return local_entities, remote_entities, user_entities, all_entities - - -class TestViewEntityID: - def test_view_displays_entity_if_given(self, mocker, setup, ui_spy): - # Arrange - entity_id = "1" - entity = generate_entity(entity_id, mocker) - output = expected_output(entity, "yaml") - - # Act - EntityView.run(entity_id, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_all_if_no_id(self, setup, ui_spy): - # Arrange - *_, entities = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - - -class TestViewFilteredEntities: - def test_view_displays_local_entities(self, setup, ui_spy): - # Arrange - entities, *_ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, local_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_user_entities(self, setup, ui_spy): - # Arrange - *_, entities, _ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, mine_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - -@pytest.mark.parametrize("entity_id", ["4", None]) -@pytest.mark.parametrize("format", ["yaml", "json"]) -class TestViewOutput: - @pytest.fixture - def output(self, setup, mocker, entity_id, format): - if entity_id is None: - *_, entities = setup - return expected_output(entities, format) - else: - entity = generate_entity(entity_id, mocker) - return expected_output(entity, format) - - def test_view_displays_specified_format(self, entity_id, output, ui_spy, format): - # Act - EntityView.run(entity_id, Entity, format=format) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_stores_specified_format(self, entity_id, output, format, fs): - # Arrange - filename = "file.txt" - - # Act - EntityView.run(entity_id, Entity, format=format, output=filename) - - # Assert - contents = open(filename, "r").read() - assert contents == output +@pytest.fixture +def entity_view(mocker): + view_class = EntityView(None, Entity, "", "", "", "") + return view_class + + +def test_prepare_with_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = 1 + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + get_spy.assert_called_once_with(1) + all_spy.assert_not_called() + assert not isinstance(entity_view.data, list) + + +def test_prepare_with_no_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), 
return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once() + get_spy.assert_not_called() + assert isinstance(entity_view.data, list) + + +@pytest.mark.parametrize("unregistered", [False, True]) +def test_prepare_with_no_id_calls_all_with_unregistered_properly( + mocker, entity_view, entity, unregistered +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = unregistered + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=unregistered, filters={}) + + +@pytest.mark.parametrize("filters", [{}, {"f1": "v1"}]) +@pytest.mark.parametrize("mine_only", [False, True]) +def test_prepare_with_no_id_calls_all_with_proper_filters( + mocker, entity_view, entity, filters, mine_only +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = False + entity_view.filters = filters + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + mocker.patch(PATCH_VIEW.format("get_medperf_user_data"), return_value={"id": 1}) + if mine_only: + filters["owner"] = 1 + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=False, filters=filters) diff --git a/cli/medperf/tests/comms/test_rest.py b/cli/medperf/tests/comms/test_rest.py index fb3596c98..e17f99f48 100644 --- a/cli/medperf/tests/comms/test_rest.py +++ b/cli/medperf/tests/comms/test_rest.py @@ -52,7 +52,7 @@ def server(mocker, ui): {"json": {}}, ), ( - "associate_dset", + "associate_benchmark_dataset", "post", 201, [1, 1], @@ -115,7 +115,7 @@ def test_methods_run_authorized_method(mocker, server, method_params): ("get_cube_metadata", [1], {}, CommunicationRetrievalError), ("upload_dataset", [{}], {"id": "invalid id"}, CommunicationRequestError), ("upload_result", [{}], {"id": "invalid id"}, CommunicationRequestError), - ("associate_dset", [1, 1], {}, CommunicationRequestError), + ("associate_benchmark_dataset", [1, 1], {}, CommunicationRequestError), ], ) def test_methods_exit_if_status_not_200(mocker, server, status, method_params): @@ -462,7 +462,9 @@ def test_upload_results_returns_result_body(mocker, server, body): @pytest.mark.parametrize("cube_uid", [2156, 915]) @pytest.mark.parametrize("benchmark_uid", [1206, 3741]) -def test_associate_cube_posts_association_data(mocker, server, cube_uid, benchmark_uid): +def test_associate_benchmark_model_posts_association_data( + mocker, server, cube_uid, benchmark_uid +): # Arrange data = { "approval_status": Status.PENDING.value, @@ -474,7 +476,7 @@ def test_associate_cube_posts_association_data(mocker, server, cube_uid, benchma spy = mocker.patch(patch_server.format("REST._REST__auth_post"), return_value=res) # Act - server.associate_cube(cube_uid, benchmark_uid) + server.associate_benchmark_model(cube_uid, benchmark_uid) # Assert spy.assert_called_once_with(ANY, json=data) @@ -483,7 +485,7 @@ def test_associate_cube_posts_association_data(mocker, server, cube_uid, benchma @pytest.mark.parametrize("dataset_uid", [4417, 1057]) @pytest.mark.parametrize("benchmark_uid", [1011, 635]) @pytest.mark.parametrize("status", [Status.APPROVED.value, Status.REJECTED.value]) -def test_set_dataset_association_approval_sets_approval( +def test_update_benchmark_dataset_association_sets_approval( mocker, server, dataset_uid, benchmark_uid, status ): # Arrange @@ -494,7 +496,7 @@ def 
test_set_dataset_association_approval_sets_approval( exp_url = f"{full_url}/datasets/{dataset_uid}/benchmarks/{benchmark_uid}/" # Act - server.set_dataset_association_approval(benchmark_uid, dataset_uid, status) + server.update_benchmark_dataset_association(benchmark_uid, dataset_uid, status) # Assert spy.assert_called_once_with(exp_url, status) @@ -503,7 +505,7 @@ def test_set_dataset_association_approval_sets_approval( @pytest.mark.parametrize("mlcube_uid", [4596, 3530]) @pytest.mark.parametrize("benchmark_uid", [3966, 4188]) @pytest.mark.parametrize("status", [Status.APPROVED.value, Status.REJECTED.value]) -def test_set_mlcube_association_approval_sets_approval( +def test_update_benchmark_model_association_sets_approval( mocker, server, mlcube_uid, benchmark_uid, status ): # Arrange @@ -514,7 +516,7 @@ def test_set_mlcube_association_approval_sets_approval( exp_url = f"{full_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" # Act - server.set_mlcube_association_approval(benchmark_uid, mlcube_uid, status) + server.update_benchmark_model_association(benchmark_uid, mlcube_uid, status) # Assert spy.assert_called_once_with(exp_url, status) @@ -576,7 +578,7 @@ def test_upload_benchmark_returns_benchmark_body(mocker, server, body): @pytest.mark.parametrize("mlcube_uid", [4596, 3530]) @pytest.mark.parametrize("benchmark_uid", [3966, 4188]) @pytest.mark.parametrize("priority", [2, -10]) -def test_set_mlcube_association_priority_sets_priority( +def test_update_benchmark_model_association_sets_priority( mocker, server, mlcube_uid, benchmark_uid, priority ): # Arrange @@ -585,7 +587,7 @@ def test_set_mlcube_association_priority_sets_priority( exp_url = f"{full_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" # Act - server.set_mlcube_association_priority(benchmark_uid, mlcube_uid, priority) + server.update_benchmark_model_association(benchmark_uid, mlcube_uid, priority) # Assert spy.assert_called_once_with(exp_url, json={"priority": priority}) diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index 3f1fde2e2..c36771e12 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -9,8 +9,9 @@ @pytest.fixture( params={ - "local": [1, 2, 3], - "remote": [4, 5, 6], + "unregistered": ["b1", "b2"], + "local": ["b1", "b2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], "user": [4], "models": [10, 11], } diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index 96f81dba0..89e7cc5a9 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -24,7 +24,14 @@ } -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["c1", "c2"], + "local": ["c1", "c2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, fs): local_ents = request.param.get("local", []) remote_ents = request.param.get("remote", []) @@ -282,7 +289,9 @@ def test_run_stops_execution_if_child_fails(self, mocker, setup, task): cube.run(task) -@pytest.mark.parametrize("setup", [{"local": [DEFAULT_CUBE]}], indirect=True) +@pytest.mark.parametrize( + "setup", [{"local": [DEFAULT_CUBE], "remote": [DEFAULT_CUBE]}], indirect=True +) @pytest.mark.parametrize("task", ["task"]) @pytest.mark.parametrize( "out_key,out_value", diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index c636b2c26..b9d309f39 
100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -15,7 +15,7 @@ setup_result_fs, setup_result_comms, ) -from medperf.exceptions import InvalidArgumentError +from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError @pytest.fixture(params=[Benchmark, Cube, Dataset, Result]) @@ -23,7 +23,14 @@ def Implementation(request): return request.param -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, Implementation, fs): local_ids = request.param.get("local", []) remote_ids = request.param.get("remote", []) @@ -54,39 +61,52 @@ def setup(request, mocker, comms, Implementation, fs): @pytest.mark.parametrize( "setup", - [{"local": [283, 17, 493], "remote": [283, 1, 2], "user": [2]}], + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 283], + "remote": [283, 1, 2], + "user": [2], + } + ], indirect=True, ) class TestAll: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): self.ids = setup + self.unregistered_ids = set(self.ids["unregistered"]) self.local_ids = set(self.ids["local"]) self.remote_ids = set(self.ids["remote"]) self.user_ids = set(self.ids["user"]) - def test_all_returns_all_remote_and_local(self, Implementation): - # Arrange - all_ids = self.local_ids.union(self.remote_ids) - + def test_all_returns_all_remote_by_default(self, Implementation): # Act entities = Implementation.all() # Assert retrieved_ids = set([e.todict()["id"] for e in entities]) - assert all_ids == retrieved_ids + assert self.remote_ids == retrieved_ids - def test_all_local_only_returns_all_local(self, Implementation): + def test_all_unregistered_returns_all_unregistered(self, Implementation): # Act - entities = Implementation.all(local_only=True) + entities = Implementation.all(unregistered=True) # Assert - retrieved_ids = set([e.todict()["id"] for e in entities]) - assert self.local_ids == retrieved_ids + retrieved_names = set([e.name for e in entities]) + assert self.unregistered_ids == retrieved_names @pytest.mark.parametrize( - "setup", [{"local": [78], "remote": [479, 42, 7, 1]}], indirect=True + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 479], + "remote": [479, 42, 7, 1], + } + ], + indirect=True, ) class TestGet: def test_get_retrieves_entity_from_server(self, Implementation, setup): @@ -99,30 +119,20 @@ def test_get_retrieves_entity_from_server(self, Implementation, setup): # Assert assert entity.todict()["id"] == id - def test_get_retrieves_entity_local_if_not_on_server(self, Implementation, setup): - # Arrange - id = setup["local"][0] - - # Act - entity = Implementation.get(id) - - # Assert - assert entity.todict()["id"] == id - def test_get_raises_error_if_nonexistent(self, Implementation, setup): # Arrange id = str(19283) # Act & Assert - with pytest.raises(InvalidArgumentError): + with pytest.raises(CommunicationRetrievalError): Implementation.get(id) -@pytest.mark.parametrize("setup", [{"local": [742]}], indirect=True) +@pytest.mark.parametrize("setup", [{"remote": [742]}], indirect=True) class TestToDict: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): - self.id = setup["local"][0] + self.id = setup["remote"][0] def test_todict_returns_dict_representation(self, Implementation): # Arrange @@ -147,7 +157,16 @@ def 
test_todict_can_recreate_object(self, Implementation): assert ent_dict == ent_copy_dict -@pytest.mark.parametrize("setup", [{"local": [36]}], indirect=True) +@pytest.mark.parametrize( + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2"], + } + ], + indirect=True, +) class TestUpload: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index 522251ca7..19c3178e3 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -15,14 +15,17 @@ # Setup Benchmark def setup_benchmark_fs(ents, fs): - bmks_path = config.benchmarks_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - bmk_filepath = os.path.join(bmks_path, str(id), config.benchmarks_filename) - bmk_contents = TestBenchmark(**ent) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + bmk_contents = TestBenchmark(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + bmk_contents = TestBenchmark(id=str(ent)) + else: + bmk_contents = TestBenchmark(id=None, name=ent) + bmk_contents.generated_uid = ent + + bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) cubes_ids = [] cubes_ids.append(bmk_contents.data_preparation_mlcube) cubes_ids.append(bmk_contents.reference_model_mlcube) @@ -30,7 +33,7 @@ def setup_benchmark_fs(ents, fs): cubes_ids = list(set(cubes_ids)) setup_cube_fs(cubes_ids, fs) try: - fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.dict())) + fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.todict())) except FileExistsError: pass @@ -51,17 +54,18 @@ def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Cube def setup_cube_fs(ents, fs): - cubes_path = config.cubes_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - meta_cube_file = os.path.join( - cubes_path, str(id), config.cube_metadata_filename - ) - cube = TestCube(**ent) - meta = cube.dict() + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + cube = TestCube(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + cube = TestCube(id=str(ent)) + else: + cube = TestCube(id=None, name=ent) + cube.generated_uid = ent + + meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) + meta = cube.todict() try: fs.create_file(meta_cube_file, contents=yaml.dump(meta)) except FileExistsError: @@ -124,18 +128,21 @@ def setup_cube_comms_downloads(mocker, fs): # Setup Dataset def setup_dset_fs(ents, fs): - dsets_path = config.datasets_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - reg_dset_file = os.path.join(dsets_path, str(id), config.reg_file) - dset_contents = TestDataset(**ent) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + dset_contents = TestDataset(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + dset_contents = TestDataset(id=str(ent)) + else: + dset_contents = TestDataset(id=None, name=ent) + dset_contents.generated_uid = ent + + reg_dset_file = os.path.join(dset_contents.path, config.reg_file) cube_id = dset_contents.data_preparation_mlcube setup_cube_fs([cube_id], fs) try: - fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.dict())) + 
fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.todict())) except FileExistsError: pass @@ -155,22 +162,26 @@ def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Result def setup_result_fs(ents, fs): - results_path = config.results_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - result_file = os.path.join(results_path, str(id), config.results_info_file) - bmk_id = ent.get("benchmark", 1) - cube_id = ent.get("model", 1) - dataset_id = ent.get("dataset", 1) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + result_contents = TestResult(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + result_contents = TestResult(id=str(ent)) + else: + result_contents = TestResult(id=None, name=ent) + result_contents.generated_uid = ent + + result_file = os.path.join(result_contents.path, config.results_info_file) + bmk_id = result_contents.benchmark + cube_id = result_contents.model + dataset_id = result_contents.dataset setup_benchmark_fs([bmk_id], fs) setup_cube_fs([cube_id], fs) setup_dset_fs([dataset_id], fs) - result_contents = TestResult(**ent) + try: - fs.create_file(result_file, contents=yaml.dump(result_contents.dict())) + fs.create_file(result_file, contents=yaml.dump(result_contents.todict())) except FileExistsError: pass diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 35aa697d6..00d06abd8 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import re import os import signal @@ -15,7 +16,6 @@ import shutil from pexpect import spawn from datetime import datetime -from pydantic.datetime_parse import parse_datetime from typing import List from colorama import Fore, Style from pexpect.exceptions import TIMEOUT @@ -404,30 +404,6 @@ def get_cube_image_name(cube_path: str) -> str: raise MedperfException(msg) -def filter_latest_associations(associations, entity_key): - """Given a list of entity-benchmark associations, this function - retrieves a list containing the latest association of each - entity instance. - - Args: - associations (list[dict]): the list of associations - entity_key (str): either "dataset" or "model_mlcube" - - Returns: - list[dict]: the list containing the latest association of each - entity instance. - """ - - associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"])) - latest_associations = {} - for assoc in associations: - entity_id = assoc[entity_key] - latest_associations[entity_id] = assoc - - latest_associations = list(latest_associations.values()) - return latest_associations - - def check_for_updates() -> None: """Check if the current branch is up-to-date with its remote counterpart using GitPython.""" repo = Repo(config.BASE_DIR) @@ -506,3 +482,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.proc.wait() # Return False to propagate exceptions, if any return False + + +def get_pki_assets_path(common_name: str, ca_name: str): + # Base64 encoding is used just to avoid special characters used in emails + # and server domains/ipaddresses. 
+ cn_encoded = base64.b64encode(common_name.encode("utf-8")).decode("utf-8") + cn_encoded = cn_encoded.rstrip("=") + return os.path.join(config.pki_assets, cn_encoded, ca_name) + + +def get_participant_label(email, data_id): + # return f"d{data_id}" + return f"{email}" diff --git a/cli/requirements.txt b/cli/requirements.txt index 02d8ee05a..94384378c 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -22,6 +22,7 @@ setuptools<=66.1.1 email-validator==2.0.0 auth0-python==4.3.0 pandas==2.1.0 +numpy==1.26.4 watchdog==3.0.0 GitPython==3.1.41 psutil==5.9.8 diff --git a/cli/tests_setup.sh b/cli/tests_setup.sh index 91887ea0a..5b1cdce89 100644 --- a/cli/tests_setup.sh +++ b/cli/tests_setup.sh @@ -1,13 +1,12 @@ #! /bin/bash -while getopts s:d:c:ft: flag -do - case "${flag}" in - s) SERVER_URL=${OPTARG};; - d) DIRECTORY=${OPTARG};; - c) CLEANUP="true";; - f) FRESH="true";; - t) TIMEOUT=${OPTARG};; - esac +while getopts s:d:c:ft: flag; do + case "${flag}" in + s) SERVER_URL=${OPTARG} ;; + d) DIRECTORY=${OPTARG} ;; + c) CLEANUP="true" ;; + f) FRESH="true" ;; + t) TIMEOUT=${OPTARG} ;; + esac done SERVER_URL="${SERVER_URL:-https://localhost:8000}" @@ -16,7 +15,7 @@ CLEANUP="${CLEANUP:-false}" FRESH="${FRESH:-false}" TEST_ROOT="/tmp/medperf_tests_$(date +%Y%m%d%H%M%S)" -export MEDPERF_CONFIG_PATH="$TEST_ROOT/config.yaml" # env var +export MEDPERF_CONFIG_PATH="$TEST_ROOT/config.yaml" # env var MEDPERF_STORAGE="$TEST_ROOT/storage" SNAPSHOTS_FOLDER=$TEST_ROOT/snapshots @@ -31,18 +30,18 @@ SQLITE3_FILE="$(dirname $(dirname $(realpath "$0")))/server/db.sqlite3" echo "Server URL: $SERVER_URL" print_eval() { - local timestamp=$(date +%m%d%H%M%S) - local formatted_cmd=$(echo "$@" | sed 's/[^a-zA-Z0-9]\+/_/g' | cut -c 1-50) - LAST_SNAPSHOT_PATH="$SNAPSHOTS_FOLDER/${timestamp}_${formatted_cmd}.sqlite3" - cp $SQLITE3_FILE "$LAST_SNAPSHOT_PATH" - echo ">> $@" - eval "$@" - # local exit_code=$? - # echo "Exit code: $exit_code" - # return $exit_code + local timestamp=$(date +%m%d%H%M%S) + local formatted_cmd=$(echo "$@" | sed 's/[^a-zA-Z0-9]\+/_/g' | cut -c 1-50) + LAST_SNAPSHOT_PATH="$SNAPSHOTS_FOLDER/${timestamp}_${formatted_cmd}.sqlite3" + cp $SQLITE3_FILE "$LAST_SNAPSHOT_PATH" + echo ">> $@" + eval "$@" + # local exit_code=$? + # echo "Exit code: $exit_code" + # return $exit_code } # frequently used -clean(){ +clean() { echo "=====================================" echo "Cleaning up medperf tmp files" echo "=====================================" @@ -50,9 +49,13 @@ clean(){ rm -fr $DIRECTORY rm -fr $TEST_ROOT } -checkFailed(){ - if [ "$?" -ne "0" ]; then - if [ "$?" -eq 124 ]; then +checkFailed() { + EXITSTATUS="$?" + if [ -n "$2" ]; then + EXITSTATUS="1" + fi + if [ $EXITSTATUS -ne "0" ]; then + if [ $EXITSTATUS -eq 124 ]; then echo "Process timed out" fi echo $1 @@ -73,7 +76,7 @@ checkSucceeded() { if [ "$?" 
-eq 0 ]; then i_am_a_command_that_does_not_exist_and_hence_changes_the_last_exit_status_to_nonzero fi - checkFailed $1 + checkFailed "$1" } if ${FRESH}; then @@ -93,6 +96,7 @@ DEMO_URL="${ASSETS_URL}/assets/datasets/demo_dset1.tar.gz" # prep cubes PREP_MLCUBE="$ASSETS_URL/prep-sep/mlcube/mlcube.yaml" PREP_PARAMS="$ASSETS_URL/prep-sep/mlcube/workspace/parameters.yaml" +PREP_TRAINING_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube.yaml" # model cubes FAILING_MODEL_MLCUBE="$ASSETS_URL/model-bug/mlcube/mlcube.yaml" # doesn't fail with association @@ -114,17 +118,26 @@ MODEL_LOG_DEBUG_PARAMS="$ASSETS_URL/model-debug-logging/mlcube/workspace/paramet METRIC_MLCUBE="$ASSETS_URL/metrics/mlcube/mlcube.yaml" METRIC_PARAMS="$ASSETS_URL/metrics/mlcube/workspace/parameters.yaml" +# FL cubes +TRAIN_MLCUBE="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/fl/mlcube/mlcube.yaml" +TRAIN_WEIGHTS="https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz" +FLADMIN_MLCUBE="https://raw.githubusercontent.com/hasan7n/medperf/bc431ffe6c3b761b28674816e6f26511e8b27042/examples/fl/fl_admin/mlcube/mlcube.yaml" + # test users credentials MODELOWNER="testmo@example.com" DATAOWNER="testdo@example.com" BENCHMARKOWNER="testbo@example.com" ADMIN="testadmin@example.com" +DATAOWNER2="testdo2@example.com" +AGGOWNER="testao@example.com" +FLADMIN="testfladmin@example.com" # local MLCubes for local compatibility tests PREP_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/data_preparator/mlcube" MODEL_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/model_custom_cnn/mlcube" METRIC_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/metrics/mlcube" +TRAINING_CONFIG="$(dirname $(dirname $(realpath "$0")))/examples/fl/fl/mlcube/workspace/training_config.yaml" # create storage folders mkdir -p "$TEST_ROOT" mkdir -p "$MEDPERF_STORAGE" @@ -141,5 +154,5 @@ print_eval medperf profile ls checkFailed "Creating config failed" echo "Moving storage setting to a new folder: ${MEDPERF_STORAGE}" -python $(dirname $(realpath "$0"))/cli_tests_move_storage.py $MEDPERF_CONFIG_PATH $MEDPERF_STORAGE +python $(dirname $(realpath "$0"))/cli_tests_move_storage.py $MEDPERF_CONFIG_PATH $MEDPERF_STORAGE checkFailed "Moving storage failed" diff --git a/examples/fl/cert/build.sh b/examples/fl/cert/build.sh new file mode 100644 index 000000000..d56304274 --- /dev/null +++ b/examples/fl/cert/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/cert/mlcube/mlcube.yaml b/examples/fl/cert/mlcube/mlcube.yaml new file mode 100644 index 000000000..5612dd379 --- /dev/null +++ b/examples/fl/cert/mlcube/mlcube.yaml @@ -0,0 +1,38 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/medperf-step-cli:0.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. 
+ build_file: "Dockerfile" + +tasks: + trust: + entrypoint: /bin/sh /mlcube_project/trust.sh trust + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_client_cert: + entrypoint: /bin/sh /mlcube_project/get_cert.sh get_client_cert + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_server_cert: + entrypoint: /bin/sh /mlcube_project/get_cert.sh get_server_cert + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ diff --git a/examples/fl/cert/mlcube/workspace/ca_config.json b/examples/fl/cert/mlcube/workspace/ca_config.json new file mode 100644 index 000000000..b701d00f8 --- /dev/null +++ b/examples/fl/cert/mlcube/workspace/ca_config.json @@ -0,0 +1,7 @@ +{ + "address": "https://flcerts.medperf.org", + "port": 443, + "fingerprint": "", + "client_provisioner": "auth0", + "server_provisioner": "acme" +} \ No newline at end of file diff --git a/examples/fl/cert/project/Dockerfile b/examples/fl/cert/project/Dockerfile new file mode 100644 index 000000000..20f966889 --- /dev/null +++ b/examples/fl/cert/project/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3.11.9-alpine + +RUN apk update && apk add jq curl + +ARG VERSION=0.26.1 +RUN wget https://dl.smallstep.com/gh-release/cli/gh-release-header/v${VERSION}/step_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step_linux_${VERSION}_amd64.tar.gz \ + && cp step_${VERSION}/bin/step /usr/bin + +COPY . /mlcube_project + +ENTRYPOINT ["/bin/sh"] \ No newline at end of file diff --git a/examples/fl/cert/project/get_cert.sh b/examples/fl/cert/project/get_cert.sh new file mode 100644 index 000000000..c0252b4be --- /dev/null +++ b/examples/fl/cert/project/get_cert.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# Read arguments +while [ "${1:-}" != "" ]; do + case "$1" in + "--ca_config"*) + ca_config="${1#*=}" + ;; + "--pki_assets"*) + pki_assets="${1#*=}" + ;; + *) + task=$1 + ;; + esac + shift +done + +# validate arguments +if [ -z "$ca_config" ]; then + echo "--ca_config is required" + exit 1 +fi + +if [ -z "$pki_assets" ]; then + echo "--pki_assets is required" + exit 1 +fi + +if [ -z "$MEDPERF_INPUT_CN" ]; then + echo "MEDPERF_INPUT_CN environment variable must be set" + exit 1 +fi + +CA_ADDRESS=$(jq -r '.address' $ca_config) +CA_PORT=$(jq -r '.port' $ca_config) +CA_FINGERPRINT=$(jq -r '.fingerprint' $ca_config) +CA_CLIENT_PROVISIONER=$(jq -r '.client_provisioner' $ca_config) +CA_SERVER_PROVISIONER=$(jq -r '.server_provisioner' $ca_config) + +export STEPPATH=$pki_assets/.step + +if [ "$task" = "get_server_cert" ]; then + PROVISIONER_ARGS="--provisioner $CA_SERVER_PROVISIONER" +elif [ "$task" = "get_client_cert" ]; then + PROVISIONER_ARGS="--provisioner $CA_CLIENT_PROVISIONER --console" +else + echo "Invalid task: Task should be get_server_cert or get_client_cert" + exit 1 +fi + +cert_path=$pki_assets/crt.crt +key_path=$pki_assets/key.key + +if [ -e $STEPPATH ]; then + echo ".step folder already exists" + exit 1 +fi + +if [ -e $cert_path ]; then + echo "cert file already exists" + exit 1 +fi + +if [ -e $key_path ]; then + echo "key file already exists" + exit 1 +fi + +if [ -n "$CA_FINGERPRINT" ]; then + # trust the CA. 
+ step ca bootstrap --ca-url $CA_ADDRESS:$CA_PORT \ + --fingerprint $CA_FINGERPRINT + ROOT=$STEPPATH/certs/root_ca.crt +else + ROOT=/etc/ssl/certs/ca-certificates.crt +fi + +# generate private key and ask for a certificate +step ca certificate --ca-url $CA_ADDRESS:$CA_PORT \ + --root $ROOT \ + --kty=RSA \ + $PROVISIONER_ARGS \ + $MEDPERF_INPUT_CN $cert_path $key_path + +EXITSTATUS="$?" +if [ $EXITSTATUS -ne "0" ]; then + echo "Failed to get the certificate" + # cleanup + rm -rf $STEPPATH + exit 1 +fi + +# cleanup +rm -rf $STEPPATH diff --git a/examples/fl/cert/project/trust.sh b/examples/fl/cert/project/trust.sh new file mode 100644 index 000000000..e6045a7d4 --- /dev/null +++ b/examples/fl/cert/project/trust.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# Read arguments +while [ "${1:-}" != "" ]; do + case "$1" in + "--ca_config"*) + ca_config="${1#*=}" + ;; + "--pki_assets"*) + pki_assets="${1#*=}" + ;; + *) + task=$1 + ;; + esac + shift +done + +# validate arguments +if [ -z "$ca_config" ]; then + echo "--ca_config is required" + exit 1 +fi + +if [ -z "$pki_assets" ]; then + echo "--pki_assets is required" + exit 1 +fi + +if [ "$task" != "trust" ]; then + echo "Invalid task: Task should be 'trust'" + exit 1 +fi + +export STEPPATH=$pki_assets/.step + +CA_ADDRESS=$(jq -r '.address' $ca_config) +CA_PORT=$(jq -r '.port' $ca_config) +CA_FINGERPRINT=$(jq -r '.fingerprint' $ca_config) + +rm -rf $pki_assets/root_ca.crt + +if [ -n "$CA_FINGERPRINT" ]; then + # trust the CA. + step ca root $pki_assets/root_ca.crt --ca-url $CA_ADDRESS:$CA_PORT \ + --fingerprint $CA_FINGERPRINT +else + curl -o $pki_assets/root_ca.crt $CA_ADDRESS:$CA_PORT/roots.pem +fi +EXITSTATUS="$?" +if [ $EXITSTATUS -ne "0" ]; then + echo "Failed to retrieve the root certificate" + # cleanup + rm -rf $STEPPATH + exit 1 +fi + +# cleanup +rm -rf $STEPPATH diff --git a/examples/fl/cert/test.sh b/examples/fl/cert/test.sh new file mode 100644 index 000000000..e4862c8b6 --- /dev/null +++ b/examples/fl/cert/test.sh @@ -0,0 +1,6 @@ +medperf mlcube run --mlcube ./mlcube --task get_client_cert -e MEDPERF_INPUT_CN=hasan.kassem@mlcommons.org +medperf mlcube run --mlcube ./mlcube --task get_server_cert -e MEDPERF_INPUT_CN=34.41.173.238 -P 80 +# medperf mlcube run --mlcube ./mlcube --task get_server_cert +medperf mlcube run --mlcube ./mlcube --task trust +# docker run -it --entrypoint=/bin/bash --env MEDPERF_INPUT_CN=col1@example.com --volume '/home/hasan/work/medperf_ws/medperf/examples/fl/cert/mlcube/workspace:/mlcube_io0:ro' --volume '/home/hasan/work/medperf_ws/medperf/examples/fl/cert/mlcube/workspace/pki_assets:/mlcube_io1' mlcommons/medperf-step-cli:0.0.0 +# bash /mlcube_project/get_cert.sh get_client_cert --ca_config=/mlcube_io0/ca_config.json --pki_assets=/mlcube_io1 diff --git a/examples/fl/fl/.gitignore b/examples/fl/fl/.gitignore new file mode 100644 index 000000000..167a3778d --- /dev/null +++ b/examples/fl/fl/.gitignore @@ -0,0 +1,4 @@ +mlcube_* +ca +quick* +for_admin diff --git a/examples/fl/fl/README.md b/examples/fl/fl/README.md new file mode 100644 index 000000000..918f483e3 --- /dev/null +++ b/examples/fl/fl/README.md @@ -0,0 +1,6 @@ +# How to run tests + +- Run `setup_test.sh` just once to create certs and download required data. +- Run `test.sh` to start the aggregator and three collaborators. +- Run `clean.sh` to be able to rerun `test.sh` freshly. +- Run `setup_clean.sh` to clear what has been generated in step 1. 
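The README above lists the test scripts in order but not how they chain together. Below is a minimal sketch of the full local test cycle, assuming the scripts live in `examples/fl/fl` as added by this PR and that Docker (or the configured runner) and the `medperf` CLI are available.

```bash
# Minimal local test cycle for examples/fl/fl, following the README steps above.
cd examples/fl/fl
bash setup_test.sh    # one-time: generate mock certs and download prepared data
bash test.sh          # start the aggregator and three collaborators in the background
bash clean.sh         # remove generated plan/logs so test.sh can be rerun
bash setup_clean.sh   # finally, drop the certs and data created by setup_test.sh
```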
diff --git a/examples/fl/fl/build.sh b/examples/fl/fl/build.sh new file mode 100644 index 000000000..08cdbb20c --- /dev/null +++ b/examples/fl/fl/build.sh @@ -0,0 +1,16 @@ +while getopts b flag; do + case "${flag}" in + b) BUILD_BASE="true" ;; + esac +done +BUILD_BASE="${BUILD_BASE:-false}" + +if ${BUILD_BASE}; then + git clone https://github.com/hasan7n/openfl.git + cd openfl + git checkout 7c9d4e7039f51014a4f7b3bedf5e2c7f1d353e68 + docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. + rm -rf openfl +fi +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/fl/clean.sh b/examples/fl/fl/clean.sh new file mode 100644 index 000000000..ce7879606 --- /dev/null +++ b/examples/fl/fl/clean.sh @@ -0,0 +1,9 @@ +rm -rf mlcube_agg/workspace/final_weights +rm -rf mlcube_agg/workspace/logs +rm -rf mlcube_col1/workspace/logs +rm -rf mlcube_col2/workspace/logs +rm -rf mlcube_col3/workspace/logs +rm -rf mlcube_agg/workspace/plan.yaml +rm -rf mlcube_col1/workspace/plan.yaml +rm -rf mlcube_col2/workspace/plan.yaml +rm -rf mlcube_col3/workspace/plan.yaml diff --git a/examples/fl/fl/csr.conf b/examples/fl/fl/csr.conf new file mode 100644 index 000000000..c3b2d0f0c --- /dev/null +++ b/examples/fl/fl/csr.conf @@ -0,0 +1,31 @@ +[ req ] +default_bits = 3072 +prompt = no +default_md = sha384 +distinguished_name = req_distinguished_name + +[ req_distinguished_name ] +commonName = hasan-hp-zbook-15-g3.home + +[ alt_names ] +DNS.1 = hasan-hp-zbook-15-g3.home + +[ v3_client ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names +extendedKeyUsage = critical,clientAuth + +[ v3_server ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names +extendedKeyUsage = critical,serverAuth + +[ v3_client_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names + +[ v3_server_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names diff --git a/examples/fl/fl/mlcube/mlcube.yaml b/examples/fl/fl/mlcube/mlcube.yaml new file mode 100644 index 000000000..65692efbb --- /dev/null +++ b/examples/fl/fl/mlcube/mlcube.yaml @@ -0,0 +1,46 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/medperf-fl:1.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. 
+ build_file: "Dockerfile" + +tasks: + train: + parameters: + inputs: + data_path: data/ + labels_path: labels/ + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + output_logs: logs/ + start_aggregator: + parameters: + inputs: + input_weights: additional_files/init_weights + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + collaborators: cols.yaml + outputs: + output_logs: logs/ + output_weights: final_weights/ + report_path: { type: "file", default: "report/report.yaml" } + generate_plan: + parameters: + inputs: + training_config_path: training_config.yaml + aggregator_config_path: aggregator_config.yaml + outputs: + plan_path: { type: "file", default: "plan/plan.yaml" } diff --git a/examples/fl/fl/mlcube/workspace/training_config.yaml b/examples/fl/fl/mlcube/workspace/training_config.yaml new file mode 100644 index 000000000..0b7c17aa5 --- /dev/null +++ b/examples/fl/fl/mlcube/workspace/training_config.yaml @@ -0,0 +1,176 @@ +aggregator: + settings: + best_state_path: save/classification_best.pbuf + db_store_rounds: 2 + init_state_path: save/classification_init.pbuf + last_state_path: save/classification_last.pbuf + rounds_to_train: 2 + write_logs: true + admins_endpoints_mapping: + testfladmin@example.com: + - GetExperimentStatus + - SetStragglerCuttoffTime + template: openfl.component.Aggregator +assigner: + settings: + template : openfl.component.assigner.DynamicRandomGroupedAssigner + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - locally_tuned_model_validation + template: openfl.component.RandomGroupedAssigner +collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator +compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline +data_loader: + settings: + feature_shape: + - 128 + - 128 + template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper +network: + settings: + cert_folder: cert + client_reconnect_interval: 5 + disable_client_auth: false + hash_salt: auto + tls: true + template: openfl.federation.Network +task_runner: + settings: + device: cpu + gandlf_config: + memory_save_mode: false # + batch_size: 16 + clip_grad: null + clip_mode: null + data_augmentation: {} + data_postprocessing: {} + data_preprocessing: + resize: + - 128 + - 128 + enable_padding: false + grid_aggregator_overlap: crop + in_memory: false + inference_mechanism: + grid_aggregator_overlap: crop + patch_overlap: 0 + learning_rate: 0.001 + loss_function: cel + medcam_enabled: false + metrics: + accuracy: + average: weighted + mdmc_average: samplewise + multi_class: true + subset_accuracy: false + threshold: 0.5 + balanced_accuracy: None + classification_accuracy: None + f1: + average: weighted + f1: + average: weighted + mdmc_average: samplewise + multi_class: true + threshold: 0.5 + modality: rad + model: + amp: false + architecture: resnet18 + base_filters: 32 + batch_norm: true + class_list: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + dimension: 2 + final_layer: sigmoid + ignore_label_validation: None + n_channels: 3 + norm_type: batch + num_channels: 3 + save_at_every_epoch: false + type: torch + nested_training: + testing: 1 + validation: -5 + num_epochs: 2 + opt: adam + optimizer: + type: adam + output_dir: . 
+ parallel_compute_command: "" + patch_sampler: uniform + patch_size: + - 128 + - 128 + - 1 + patience: 1 + pin_memory_dataloader: false + print_rgb_label_warning: true + q_max_length: 5 + q_num_workers: 0 + q_samples_per_volume: 1 + q_verbose: false + save_masks: false + save_output: false + save_training: false + scaling_factor: 1 + scheduler: + step_size: 0.0002 + type: triangle + track_memory_usage: false + verbose: false + version: + maximum: 0.0.20-dev + minimum: 0.0.20-dev + weighted_loss: true + train_csv: train_path_full.csv + val_csv: val_path_full.csv + template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner +tasks: + aggregated_model_validation: + function: validate + kwargs: + apply: global + metrics: + - valid_loss + - valid_accuracy + locally_tuned_model_validation: + function: validate + kwargs: + apply: local + metrics: + - valid_loss + - valid_accuracy + settings: {} + train: + function: train + kwargs: + epochs: 1 + metrics: + - loss + - train_accuracy + +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 600 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl/fl/project/Dockerfile b/examples/fl/fl/project/Dockerfile new file mode 100644 index 000000000..0d0da62a5 --- /dev/null +++ b/examples/fl/fl/project/Dockerfile @@ -0,0 +1,22 @@ +FROM local/openfl:local + +ENV LANG C.UTF-8 +ENV CUDA_VISIBLE_DEVICES="0" + + +# install project dependencies +RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ + +RUN pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 && \ + pip install --no-cache-dir openvino-dev==2023.0.1 && \ + git clone https://github.com/mlcommons/GaNDLF.git && \ + cd GaNDLF && git checkout dd88b8883cb0e57a0ac615e9cb5be7416d0dada4 && \ + pip install --no-cache-dir -e . + +COPY ./requirements.txt /mlcube_project/requirements.txt +RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Copy mlcube workspace +COPY . /mlcube_project + +ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] \ No newline at end of file diff --git a/examples/fl/fl/project/README.md b/examples/fl/fl/project/README.md new file mode 100644 index 000000000..1e348651b --- /dev/null +++ b/examples/fl/fl/project/README.md @@ -0,0 +1,38 @@ +# How to configure container build for your application + +- List your pip requirements in `requirements.txt` +- List your software requirements in `Dockerfile` +- Modify the functions in `hooks.py` as needed. (Explanation TBD) + +# How to configure container for custom FL software + +- Change the base Docker image as needed. +- modify `aggregator.py` and `collaborator.py` as needed. Follow the implemented schema steps. + +# How to build + +- Build the openfl base image: + +```bash +git clone https://github.com/securefederatedai/openfl.git +cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. +rm -rf openfl +``` + +- Build the MLCube + +```bash +cd .. +bash build.sh +``` + +# NOTE + +for local experiments, internal IP address or localhost will not work. Use internal fqdn. 
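Since localhost and bare internal IPs will not work for the aggregator address, a sketch of how the test setup scripts in this PR resolve and record the FQDN may help; the commands mirror `setup_test.sh`, and the `address`/`port` keys are the ones `plan.py` reads when generating the plan.

```bash
# Sketch: resolve this machine's FQDN the same way setup_test*.sh does,
# then record it in the aggregator config consumed by the generate_plan task.
HOSTNAME_=$(hostname -A | cut -d " " -f 1)
echo "address: $HOSTNAME_" >> mlcube_agg/workspace/aggregator_config.yaml
echo "port: 50273"         >> mlcube_agg/workspace/aggregator_config.yaml
```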
+ +# How to customize + +TBD diff --git a/examples/fl/fl/project/aggregator.py b/examples/fl/fl/project/aggregator.py new file mode 100644 index 000000000..8a1e7f283 --- /dev/null +++ b/examples/fl/fl/project/aggregator.py @@ -0,0 +1,32 @@ +import os +import shutil +from subprocess import check_call +from distutils.dir_util import copy_tree + + +def start_aggregator(workspace_folder, output_logs, output_weights, report_path): + + check_call(["fx", "aggregator", "start"], cwd=workspace_folder) + + # TODO: check how to copy logs during runtime. + # perhaps investigate overriding plan entries? + + # NOTE: logs and weights are copied, even if target folders are not empty + if os.path.exists(os.path.join(workspace_folder, "logs")): + copy_tree(os.path.join(workspace_folder, "logs"), output_logs) + + # NOTE: conversion fails since openfl needs sample data... + # weights_paths = get_weights_path(fl_workspace) + # out_best = os.path.join(output_weights, "best") + # out_last = os.path.join(output_weights, "last") + # check_call( + # ["fx", "model", "save", "-i", weights_paths["best"], "-o", out_best], + # cwd=workspace_folder, + # ) + copy_tree(os.path.join(workspace_folder, "save"), output_weights) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) + + with open(report_path, "w") as f: + f.write("IsDone: 1") diff --git a/examples/fl/fl/project/collaborator.py b/examples/fl/fl/project/collaborator.py new file mode 100644 index 000000000..fb4cdd1c2 --- /dev/null +++ b/examples/fl/fl/project/collaborator.py @@ -0,0 +1,29 @@ +import os +from utils import get_collaborator_cn +import shutil +from subprocess import check_call + + +def start_collaborator(workspace_folder): + cn = get_collaborator_cn() + check_call( + [os.environ.get("OPENFL_EXECUTABLE", "fx"), "collaborator", "start", "-n", cn], + cwd=workspace_folder, + ) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) + + +def check_connectivity(workspace_folder): + cn = get_collaborator_cn() + check_call( + [ + os.environ.get("OPENFL_EXECUTABLE", "fx"), + "collaborator", + "connectivity_check", + "-n", + cn, + ], + cwd=workspace_folder, + ) diff --git a/examples/fl/fl/project/entrypoint.sh b/examples/fl/fl/project/entrypoint.sh new file mode 100644 index 000000000..7a5e15306 --- /dev/null +++ b/examples/fl/fl/project/entrypoint.sh @@ -0,0 +1 @@ +python /mlcube_project/mlcube.py $@ diff --git a/examples/fl/fl/project/hooks.py b/examples/fl/fl/project/hooks.py new file mode 100644 index 000000000..9dc59582f --- /dev/null +++ b/examples/fl/fl/project/hooks.py @@ -0,0 +1,102 @@ +import os +import pandas as pd +from utils import get_collaborator_cn + + +def __modify_df(df): + # gandlf convention: labels columns could be "target", "label", "mask" + # subject id column is subjectid. data columns are Channel_0. + # Others could be scalars. 
# TODO + labels_columns = ["target", "label", "mask"] + data_columns = ["channel_0"] + subject_id_column = "subjectid" + for column in df.columns: + if column.lower() == subject_id_column: + continue + if column.lower() in labels_columns: + prepend_str = "labels/" + elif column.lower() in data_columns: + prepend_str = "data/" + else: + continue + + df[column] = prepend_str + df[column].astype(str) + + +def collaborator_pre_training_hook( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + workspace_folder, +): + cn = get_collaborator_cn() + + target_data_folder = os.path.join(workspace_folder, "data", cn) + os.makedirs(target_data_folder, exist_ok=True) + target_data_data_folder = os.path.join(target_data_folder, "data") + target_data_labels_folder = os.path.join(target_data_folder, "labels") + target_train_csv = os.path.join(target_data_folder, "train.csv") + target_valid_csv = os.path.join(target_data_folder, "valid.csv") + + os.symlink(data_path, target_data_data_folder) + os.symlink(labels_path, target_data_labels_folder) + train_csv = os.path.join(data_path, "train.csv") + valid_csv = os.path.join(data_path, "valid.csv") + + train_df = pd.read_csv(train_csv) + __modify_df(train_df) + train_df.to_csv(target_train_csv, index=False) + + valid_df = pd.read_csv(valid_csv) + __modify_df(valid_df) + valid_df.to_csv(target_valid_csv, index=False) + + data_config = f"{cn},data/{cn}" + plan_folder = os.path.join(workspace_folder, "plan") + os.makedirs(plan_folder, exist_ok=True) + data_config_path = os.path.join(plan_folder, "data.yaml") + with open(data_config_path, "w") as f: + f.write(data_config) + + +def collaborator_post_training_hook( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + workspace_folder, +): + pass + + +def aggregator_pre_training_hook( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + pass + + +def aggregator_post_training_hook( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + pass diff --git a/examples/fl/fl/project/mlcube.py b/examples/fl/fl/project/mlcube.py new file mode 100644 index 000000000..064440e95 --- /dev/null +++ b/examples/fl/fl/project/mlcube.py @@ -0,0 +1,127 @@ +"""MLCube handler file""" + +import typer +from collaborator import start_collaborator, check_connectivity +from aggregator import start_aggregator +from plan import generate_plan +from hooks import ( + aggregator_pre_training_hook, + aggregator_post_training_hook, + collaborator_pre_training_hook, + collaborator_post_training_hook, +) +from utils import generic_setup, generic_teardown, setup_collaborator, setup_aggregator + +app = typer.Typer() + + +@app.command("train") +def train( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + output_logs: str = typer.Option(..., "--output_logs"), +): + workspace_folder = generic_setup(output_logs) + setup_collaborator( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + 
workspace_folder=workspace_folder, + ) + check_connectivity(workspace_folder) + collaborator_pre_training_hook( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + workspace_folder=workspace_folder, + ) + start_collaborator(workspace_folder=workspace_folder) + collaborator_post_training_hook( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + workspace_folder=workspace_folder, + ) + generic_teardown(output_logs) + + +@app.command("start_aggregator") +def start_aggregator_( + input_weights: str = typer.Option(..., "--input_weights"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + output_logs: str = typer.Option(..., "--output_logs"), + output_weights: str = typer.Option(..., "--output_weights"), + plan_path: str = typer.Option(..., "--plan_path"), + collaborators: str = typer.Option(..., "--collaborators"), + report_path: str = typer.Option(..., "--report_path"), +): + workspace_folder = generic_setup(output_logs) + setup_aggregator( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + workspace_folder=workspace_folder, + ) + aggregator_pre_training_hook( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + workspace_folder=workspace_folder, + ) + start_aggregator( + workspace_folder=workspace_folder, + output_logs=output_logs, + output_weights=output_weights, + report_path=report_path, + ) + aggregator_post_training_hook( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + workspace_folder=workspace_folder, + ) + generic_teardown(output_logs) + + +@app.command("generate_plan") +def generate_plan_( + training_config_path: str = typer.Option(..., "--training_config_path"), + aggregator_config_path: str = typer.Option(..., "--aggregator_config_path"), + plan_path: str = typer.Option(..., "--plan_path"), +): + # no _setup here since there is no writable output mounted volume. + # later if need this we think of a solution. Currently the create_plam + # logic is assumed to not write within the container. + generate_plan(training_config_path, aggregator_config_path, plan_path) + + +if __name__ == "__main__": + app() diff --git a/examples/fl/fl/project/plan.py b/examples/fl/fl/project/plan.py new file mode 100644 index 000000000..2feb1bf52 --- /dev/null +++ b/examples/fl/fl/project/plan.py @@ -0,0 +1,16 @@ +import yaml + + +def generate_plan(training_config_path, aggregator_config_path, plan_path): + with open(training_config_path) as f: + training_config = yaml.safe_load(f) + with open(aggregator_config_path) as f: + aggregator_config = yaml.safe_load(f) + + # TODO: key checks. Also, define what should be considered aggregator_config + # (e.g., tls=true, reconnect_interval, ...) 
+ training_config["network"]["settings"]["agg_addr"] = aggregator_config["address"] + training_config["network"]["settings"]["agg_port"] = aggregator_config["port"] + + with open(plan_path, "w") as f: + yaml.dump(training_config, f) diff --git a/examples/fl/fl/project/requirements.txt b/examples/fl/fl/project/requirements.txt new file mode 100644 index 000000000..c7ec9886d --- /dev/null +++ b/examples/fl/fl/project/requirements.txt @@ -0,0 +1,2 @@ +onnx==1.13.0 +typer==0.9.0 \ No newline at end of file diff --git a/examples/fl/fl/project/utils.py b/examples/fl/fl/project/utils.py new file mode 100644 index 000000000..a0da69a16 --- /dev/null +++ b/examples/fl/fl/project/utils.py @@ -0,0 +1,175 @@ +import yaml +import os +import shutil + + +def generic_setup(output_logs): + tmpfolder = os.path.join(output_logs, ".tmp") + os.makedirs(tmpfolder, exist_ok=True) + # NOTE: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmpfolder + workspace_folder = os.path.join(output_logs, "workspace") + os.makedirs(workspace_folder, exist_ok=True) + create_workspace(workspace_folder) + return workspace_folder + + +def setup_collaborator( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + cn = get_collaborator_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def setup_aggregator( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + prepare_cols_list(collaborators, workspace_folder) + prepare_init_weights(input_weights, workspace_folder) + fqdn = get_aggregator_fqdn(workspace_folder) + prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def generic_teardown(output_logs): + tmp_folder = os.path.join(output_logs, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + +def create_workspace(fl_workspace): + plan_folder = os.path.join(fl_workspace, "plan") + workspace_config = os.path.join(fl_workspace, ".workspace") + defaults_file = os.path.join(plan_folder, "defaults") + + os.makedirs(plan_folder, exist_ok=True) + with open(defaults_file, "w") as f: + f.write("../../workspace/plan/defaults\n\n") + with open(workspace_config, "w") as f: + f.write("current_plan_name: default\n\n") + + +def get_aggregator_fqdn(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return plan["network"]["settings"]["agg_addr"].lower() + + +def get_collaborator_cn(): + return os.environ["MEDPERF_PARTICIPANT_LABEL"] + + +def get_weights_path(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return { + "init": plan["aggregator"]["settings"]["init_state_path"], + "best": plan["aggregator"]["settings"]["best_state_path"], + "last": plan["aggregator"]["settings"]["last_state_path"], + } + + +def prepare_plan(plan_path, fl_workspace): + target_plan_folder = os.path.join(fl_workspace, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + + target_plan_file = os.path.join(target_plan_folder, "plan.yaml") + shutil.copyfile(plan_path, target_plan_file) + + +def prepare_cols_list(collaborators_file, fl_workspace): + with 
open(collaborators_file) as f: + cols_dict = yaml.safe_load(f) + cn_different = False + for col_label in cols_dict.keys(): + cn = cols_dict[col_label] + if cn != col_label: + cn_different = True + if not cn_different: + # quick hack to support old and new openfl versions + cols_dict = list(cols_dict.keys()) + + target_plan_folder = os.path.join(fl_workspace, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + target_plan_file = os.path.join(target_plan_folder, "cols.yaml") + with open(target_plan_file, "w") as f: + yaml.dump({"collaborators": cols_dict}, f) + + +def prepare_init_weights(input_weights, fl_workspace): + error_msg = f"{input_weights} should contain only one file: *.pbuf" + + files = os.listdir(input_weights) + file = files[0] # TODO: this may cause failure in MAC OS + if len(files) != 1 or not file.endswith(".pbuf"): + raise RuntimeError(error_msg) + + file = os.path.join(input_weights, file) + + target_weights_subpath = get_weights_path(fl_workspace)["init"] + target_weights_path = os.path.join(fl_workspace, target_weights_subpath) + target_weights_folder = os.path.dirname(target_weights_path) + os.makedirs(target_weights_folder, exist_ok=True) + os.symlink(file, target_weights_path) + + +def prepare_node_cert( + node_cert_folder, target_cert_folder_name, target_cert_name, fl_workspace +): + error_msg = f"{node_cert_folder} should contain only two files: *.crt and *.key" + + files = os.listdir(node_cert_folder) + file_extensions = [file.split(".")[-1] for file in files] + if len(files) != 2 or sorted(file_extensions) != ["crt", "key"]: + raise RuntimeError(error_msg) + + if files[0].endswith(".crt") and files[1].endswith(".key"): + cert_file = files[0] + key_file = files[1] + else: + key_file = files[0] + cert_file = files[1] + + key_file = os.path.join(node_cert_folder, key_file) + cert_file = os.path.join(node_cert_folder, cert_file) + + target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) + os.makedirs(target_cert_folder, exist_ok=True) + target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") + target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") + + os.symlink(key_file, target_key_file) + os.symlink(cert_file, target_cert_file) + + +def prepare_ca_cert(ca_cert_folder, fl_workspace): + error_msg = f"{ca_cert_folder} should contain only one file: *.crt" + + files = os.listdir(ca_cert_folder) + file = files[0] + if len(files) != 1 or not file.endswith(".crt"): + raise RuntimeError(error_msg) + + file = os.path.join(ca_cert_folder, file) + + target_ca_cert_folder = os.path.join(fl_workspace, "cert") + os.makedirs(target_ca_cert_folder, exist_ok=True) + target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") + + os.symlink(file, target_ca_cert_file) diff --git a/examples/fl/fl/setup_clean.sh b/examples/fl/fl/setup_clean.sh new file mode 100644 index 000000000..6615c2968 --- /dev/null +++ b/examples/fl/fl/setup_clean.sh @@ -0,0 +1,6 @@ +rm -rf ./mlcube_agg +rm -rf ./mlcube_col1 +rm -rf ./mlcube_col2 +rm -rf ./mlcube_col3 +rm -rf ./ca +rm -rf ./for_admin diff --git a/examples/fl/fl/setup_test.sh b/examples/fl/fl/setup_test.sh new file mode 100644 index 000000000..542dd7164 --- /dev/null +++ b/examples/fl/fl/setup_test.sh @@ -0,0 +1,93 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" 
+COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org2@example.com" # in this case this var is not used actually. it's OK +fi + +cp -r ./mlcube ./mlcube_agg +cp -r ./mlcube ./mlcube_col1 +cp -r ./mlcube ./mlcube_col2 +cp -r ./mlcube ./mlcube_col3 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert +mkdir ./ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) + +medperf mlcube run --mlcube ../mock_cert/mlcube --task trust +mv ../mock_cert/mlcube/workspace/pki_assets/* ./ca + +# col1 +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL1_CN +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col1/workspace/node_cert +cp -r ./ca/* ./mlcube_col1/workspace/ca_cert + +# col2 +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL2_CN +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col2/workspace/node_cert +cp -r ./ca/* ./mlcube_col2/workspace/ca_cert + +# col3 +if ${TWO_COL_SAME_CERT}; then + cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert +else + medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL3_CN + mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col3/workspace/node_cert + cp -r ./ca/* ./mlcube_col3/workspace/ca_cert +fi + +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_server_cert -e MEDPERF_INPUT_CN=$HOSTNAME_ +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_agg/workspace/node_cert +cp -r ./ca/* ./mlcube_agg/workspace/ca_cert + +# aggregator_config +echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + +# cols file +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml + +# data download +cd mlcube_col1/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz +tar -xf col1_prepared.tar.gz +rm col1_prepared.tar.gz +cd ../.. + +cd mlcube_col2/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz +tar -xf col2_prepared.tar.gz +rm col2_prepared.tar.gz +cd ../.. + +cp -r mlcube_col2/workspace/data mlcube_col3/workspace +cp -r mlcube_col2/workspace/labels mlcube_col3/workspace + +# weights download +cd mlcube_agg/workspace/ +mkdir additional_files +cd additional_files +wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz +tar -xf init_weights_miccai.tar.gz +rm init_weights_miccai.tar.gz +cd ../../.. 
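After `setup_test.sh` finishes, the workspaces should contain the certs, configs, and weights that `test.sh` expects. A quick, hedged way to sanity-check the generated assets before launching the nodes (paths taken from the script above; exact CN values depend on the `-t` flag):

```bash
# Sketch: inspect what setup_test.sh generated before running test.sh.
cat mlcube_agg/workspace/aggregator_config.yaml   # address: <fqdn>, port: 50273
cat mlcube_agg/workspace/cols.yaml                # collaborator label -> cert CN mapping
ls  mlcube_col1/workspace/node_cert               # crt.crt and key.key issued for col1
ls  mlcube_agg/workspace/additional_files         # extracted init weights archive (*.pbuf expected)
```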
diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh new file mode 100644 index 000000000..606847a54 --- /dev/null +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -0,0 +1,142 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" +COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org2@example.com" # in this case this var is not used actually. it's OK +fi + +cp -r ./mlcube ./mlcube_agg +cp -r ./mlcube ./mlcube_col1 +cp -r ./mlcube ./mlcube_col2 +cp -r ./mlcube ./mlcube_col3 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert +mkdir ./ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) + +# root ca +openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 +openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ + -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" + +# col1 +sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf +cd mlcube_col1/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col2 +sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf +cd mlcube_col2/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col3 +if ${TWO_COL_SAME_CERT}; then + cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert +else + sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf + cd mlcube_col3/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd ../../../ +fi + +# agg +sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf +cd mlcube_agg/workspace/node_cert +openssl genpkey -algorithm RSA -out 
key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_server_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# aggregator_config +echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + +# cols file +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml + +# data download +cd mlcube_col1/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz +tar -xf col1_prepared.tar.gz +rm col1_prepared.tar.gz +cd ../.. + +cd mlcube_col2/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz +tar -xf col2_prepared.tar.gz +rm col2_prepared.tar.gz +cd ../.. + +cp -r mlcube_col2/workspace/data mlcube_col3/workspace +cp -r mlcube_col2/workspace/labels mlcube_col3/workspace + +# weights download +cd mlcube_agg/workspace/ +mkdir additional_files +cd additional_files +wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz +tar -xf init_weights_miccai.tar.gz +rm init_weights_miccai.tar.gz +cd ../../.. + +# for admin +ADMIN_CN="testfladmin@example.com" + +mkdir ./for_admin +mkdir ./for_admin/node_cert + +sed -i "/^commonName = /c\commonName = $ADMIN_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $ADMIN_CN" csr.conf +cd for_admin/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf +rm csr.csr +mkdir ../ca_cert +cp -r ../../ca/root.crt ../ca_cert/root.crt +cd ../.. 
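Since this no-docker variant signs certificates directly with openssl, a useful consistency check (a hedged sketch, not part of this PR) is that the aggregator certificate's SAN matches the FQDN written into `aggregator_config.yaml`; a mismatch there makes the TLS handshake fail.

```bash
# Hedged check: compare the aggregator cert's SAN with the configured address.
openssl x509 -in mlcube_agg/workspace/node_cert/crt.crt -noout -text | grep -A1 "Subject Alternative Name"
grep '^address:' mlcube_agg/workspace/aggregator_config.yaml
```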
diff --git a/examples/fl/fl/sync.sh b/examples/fl/fl/sync.sh new file mode 100644 index 000000000..a5375ce54 --- /dev/null +++ b/examples/fl/fl/sync.sh @@ -0,0 +1,6 @@ +cp mlcube/workspace/training_config.yaml mlcube_agg/workspace/training_config.yaml + +cp mlcube/mlcube.yaml mlcube_agg/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col1/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col2/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col3/mlcube.yaml diff --git a/examples/fl/fl/test.sh b/examples/fl/fl/test.sh new file mode 100644 index 000000000..ae856d794 --- /dev/null +++ b/examples/fl/fl/test.sh @@ -0,0 +1,31 @@ +# generate plan and copy it to each node +medperf mlcube run --mlcube ./mlcube_agg --task generate_plan +mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace +rm -r ./mlcube_agg/workspace/plan +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace +cp ./mlcube_agg/workspace/plan.yaml ./for_admin + +# Run nodes +AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" + +# gnome-terminal -- bash -c "$AGG; bash" +# gnome-terminal -- bash -c "$COL1; bash" +# gnome-terminal -- bash -c "$COL2; bash" +# gnome-terminal -- bash -c "$COL3; bash" +rm agg.log col1.log col2.log col3.log +$AGG >>agg.log & +sleep 6 +$COL1 >>col1.log & +sleep 6 +$COL2 >>col2.log & +sleep 6 +$COL3 >>col3.log & +wait + +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl/fl_admin/.gitignore b/examples/fl/fl_admin/.gitignore new file mode 100644 index 000000000..6bd8bf2e2 --- /dev/null +++ b/examples/fl/fl_admin/.gitignore @@ -0,0 +1,3 @@ +mlcube_* +ca +quick* diff --git a/examples/fl/fl_admin/README.md b/examples/fl/fl_admin/README.md new file mode 100644 index 000000000..918f483e3 --- /dev/null +++ b/examples/fl/fl_admin/README.md @@ -0,0 +1,6 @@ +# How to run tests + +- Run `setup_test.sh` just once to create certs and download required data. +- Run `test.sh` to start the aggregator and three collaborators. +- Run `clean.sh` to be able to rerun `test.sh` freshly. +- Run `setup_clean.sh` to clear what has been generated in step 1. 
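For orientation, the admin tasks introduced later in this diff are driven through environment variables. A typical invocation (a hedged sketch mirroring `test.sh` below, assuming the chosen CN is listed in the aggregator plan's `admins_endpoints_mapping`) looks like:

```bash
medperf mlcube run --mlcube ./mlcube_admin --task get_experiment_status \
    -e MEDPERF_ADMIN_PARTICIPANT_CN=testfladmin@example.com
```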
diff --git a/examples/fl/fl_admin/build.sh b/examples/fl/fl_admin/build.sh new file mode 100644 index 000000000..28e76c014 --- /dev/null +++ b/examples/fl/fl_admin/build.sh @@ -0,0 +1,16 @@ +while getopts b flag; do + case "${flag}" in + b) BUILD_BASE="true" ;; + esac +done +BUILD_BASE="${BUILD_BASE:-false}" + +if ${BUILD_BASE}; then + git clone https://github.com/hasan7n/openfl.git + cd openfl + git checkout 9467f829687b6284a6e380d31f90d31bc9de023f + docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. + rm -rf openfl +fi +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/fl_admin/clean.sh b/examples/fl/fl_admin/clean.sh new file mode 100644 index 000000000..582a3249e --- /dev/null +++ b/examples/fl/fl_admin/clean.sh @@ -0,0 +1,2 @@ +rm -rf mlcube_admin/workspace/status +rm -rf mlcube_admin/workspace/tmp diff --git a/examples/fl/fl_admin/mlcube/mlcube.yaml b/examples/fl/fl_admin/mlcube/mlcube.yaml new file mode 100644 index 000000000..7e0394c59 --- /dev/null +++ b/examples/fl/fl_admin/mlcube/mlcube.yaml @@ -0,0 +1,50 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/medperf-fl-admin:1.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. + build_file: "Dockerfile" + +tasks: + get_experiment_status: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + output_status_file: { type: "file", default: "status/status.yaml" } + temp_dir: tmp/ + add_collaborator: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + temp_dir: tmp/ + remove_collaborator: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + temp_dir: tmp/ + update_plan: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + temp_dir: tmp/ \ No newline at end of file diff --git a/examples/fl/fl_admin/mlcube/workspace/plan.yaml b/examples/fl/fl_admin/mlcube/workspace/plan.yaml new file mode 100644 index 000000000..202a615ff --- /dev/null +++ b/examples/fl/fl_admin/mlcube/workspace/plan.yaml @@ -0,0 +1,177 @@ +aggregator: + settings: + best_state_path: save/classification_best.pbuf + db_store_rounds: 2 + init_state_path: save/classification_init.pbuf + last_state_path: save/classification_last.pbuf + rounds_to_train: 2 + write_logs: true + admins_endpoints_mapping: + testfladmin@example.com: + - GetExperimentStatus + - SetStragglerCuttoffTime + + template: openfl.component.Aggregator +assigner: + settings: + template : openfl.component.assigner.DynamicRandomGroupedAssigner + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - locally_tuned_model_validation + template: openfl.component.RandomGroupedAssigner +collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator +compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline +data_loader: + settings: + feature_shape: + - 128 + - 128 + template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper +network: + settings: + cert_folder: cert + 
client_reconnect_interval: 5 + disable_client_auth: false + hash_salt: auto + tls: true + template: openfl.federation.Network +task_runner: + settings: + device: cpu + gandlf_config: + memory_save_mode: false # + batch_size: 16 + clip_grad: null + clip_mode: null + data_augmentation: {} + data_postprocessing: {} + data_preprocessing: + resize: + - 128 + - 128 + enable_padding: false + grid_aggregator_overlap: crop + in_memory: false + inference_mechanism: + grid_aggregator_overlap: crop + patch_overlap: 0 + learning_rate: 0.001 + loss_function: cel + medcam_enabled: false + metrics: + accuracy: + average: weighted + mdmc_average: samplewise + multi_class: true + subset_accuracy: false + threshold: 0.5 + balanced_accuracy: None + classification_accuracy: None + f1: + average: weighted + f1: + average: weighted + mdmc_average: samplewise + multi_class: true + threshold: 0.5 + modality: rad + model: + amp: false + architecture: resnet18 + base_filters: 32 + batch_norm: true + class_list: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + dimension: 2 + final_layer: sigmoid + ignore_label_validation: None + n_channels: 3 + norm_type: batch + num_channels: 3 + save_at_every_epoch: false + type: torch + nested_training: + testing: 1 + validation: -5 + num_epochs: 2 + opt: adam + optimizer: + type: adam + output_dir: . + parallel_compute_command: "" + patch_sampler: uniform + patch_size: + - 128 + - 128 + - 1 + patience: 1 + pin_memory_dataloader: false + print_rgb_label_warning: true + q_max_length: 5 + q_num_workers: 0 + q_samples_per_volume: 1 + q_verbose: false + save_masks: false + save_output: false + save_training: false + scaling_factor: 1 + scheduler: + step_size: 0.0002 + type: triangle + track_memory_usage: false + verbose: false + version: + maximum: 0.0.20-dev + minimum: 0.0.20-dev + weighted_loss: true + train_csv: train_path_full.csv + val_csv: val_path_full.csv + template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner +tasks: + aggregated_model_validation: + function: validate + kwargs: + apply: global + metrics: + - valid_loss + - valid_accuracy + locally_tuned_model_validation: + function: validate + kwargs: + apply: local + metrics: + - valid_loss + - valid_accuracy + settings: {} + train: + function: train + kwargs: + epochs: 1 + metrics: + - loss + - train_accuracy + +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 600 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl/fl_admin/project/Dockerfile b/examples/fl/fl_admin/project/Dockerfile new file mode 100644 index 000000000..fc7afaf2e --- /dev/null +++ b/examples/fl/fl_admin/project/Dockerfile @@ -0,0 +1,11 @@ +FROM local/openfl:local + +ENV LANG C.UTF-8 + +COPY ./requirements.txt /mlcube_project/requirements.txt +RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Copy mlcube workspace +COPY . 
/mlcube_project + +ENTRYPOINT ["python", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/fl_admin/project/README.md b/examples/fl/fl_admin/project/README.md new file mode 100644 index 000000000..f9ee6768d --- /dev/null +++ b/examples/fl/fl_admin/project/README.md @@ -0,0 +1,35 @@ +# How to configure container build for your application + +- (Explanation TBD) + +# How to configure container for custom FL software + +- (Explanation TBD) + +# How to build + +- Build the openfl base image: + +```bash +git clone https://github.com/securefederatedai/openfl.git +cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. +rm -rf openfl +``` + +- Build the MLCube + +```bash +cd .. +bash build.sh +``` + +# NOTE + +for local experiments, internal IP address or localhost will not work. Use internal fqdn. + +# How to customize + +TBD diff --git a/examples/fl/fl_admin/project/admin.py b/examples/fl/fl_admin/project/admin.py new file mode 100644 index 000000000..fd3aa85b2 --- /dev/null +++ b/examples/fl/fl_admin/project/admin.py @@ -0,0 +1,83 @@ +from subprocess import check_call +from utils import ( + get_col_label_to_add, + get_col_cn_to_add, + get_col_label_to_remove, + get_col_cn_to_remove, + get_update_field_name, + get_update_value_name, +) +from update_plan import set_straggler_cutoff_time, set_dynamic_task_arg + + +def get_experiment_status(workspace_folder, admin_cn, output_status_file): + check_call( + [ + "fx", + "admin", + "get_experiment_status", + "-n", + admin_cn, + "--output_file", + output_status_file, + ], + cwd=workspace_folder, + ) + + +def add_collaborator(workspace_folder, admin_cn): + col_label = get_col_label_to_add() + col_cn = get_col_cn_to_add() + check_call( + [ + "fx", + "admin", + "add_collaborator", + "-n", + admin_cn, + "--col_label", + col_label, + "--col_cn", + col_cn, + ], + cwd=workspace_folder, + ) + + +def remove_collaborator(workspace_folder, admin_cn): + col_label = get_col_label_to_remove() + col_cn = get_col_cn_to_remove() + check_call( + [ + "fx", + "admin", + "remove_collaborator", + "-n", + admin_cn, + "--col_label", + col_label, + "--col_cn", + col_cn, + ], + cwd=workspace_folder, + ) + + +def update_plan(workspace_folder, admin_cn): + field_name = get_update_field_name() + field_value = get_update_value_name() + if field_name == "straggler_handling_policy.settings.straggler_cutoff_time": + set_straggler_cutoff_time(workspace_folder, admin_cn, field_value) + elif field_name.startswith("dynamictaskargs"): + assert field_name in [ + "dynamictaskargs.train.train_cutoff_time", + "dynamictaskargs.train.val_cutoff_time", + "dynamictaskargs.train.train_completion_dampener", + "dynamictaskargs.aggregated_model_validation.val_cutoff_time", + ] + _, task_name, arg_name = field_name.strip().split(".") + set_dynamic_task_arg( + workspace_folder, admin_cn, task_name, arg_name, field_value + ) + else: + raise ValueError(f"Unsupported field name: {field_name}") diff --git a/examples/fl/fl_admin/project/mlcube.py b/examples/fl/fl_admin/project/mlcube.py new file mode 100644 index 000000000..7e412f743 --- /dev/null +++ b/examples/fl/fl_admin/project/mlcube.py @@ -0,0 +1,92 @@ +"""MLCube handler file""" + +import os +import shutil +import typer +from utils import setup_ws +from admin import ( + get_experiment_status, + add_collaborator, + remove_collaborator, + update_plan, +) + +app = typer.Typer() + + +def _setup(temp_dir): + tmp_folder = 
os.path.join(temp_dir, ".tmp") + os.makedirs(tmp_folder, exist_ok=True) + # TODO: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmp_folder + os.environ["GRPC_VERBOSITY"] = "ERROR" + + +def _teardown(temp_dir): + tmp_folder = os.path.join(temp_dir, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + +@app.command("get_experiment_status") +def get_experiment_status_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + output_status_file: str = typer.Option(..., "--output_status_file"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + get_experiment_status(workspace_folder, admin_cn, output_status_file) + _teardown(temp_dir) + + +@app.command("add_collaborator") +def add_collaborator_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + add_collaborator(workspace_folder, admin_cn) + _teardown(temp_dir) + + +@app.command("remove_collaborator") +def remove_collaborator_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + remove_collaborator(workspace_folder, admin_cn) + _teardown(temp_dir) + + +@app.command("update_plan") +def update_plan_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + update_plan(workspace_folder, admin_cn) + _teardown(temp_dir) + + +if __name__ == "__main__": + app() diff --git a/examples/fl/fl_admin/project/requirements.txt b/examples/fl/fl_admin/project/requirements.txt new file mode 100644 index 000000000..92c979407 --- /dev/null +++ b/examples/fl/fl_admin/project/requirements.txt @@ -0,0 +1 @@ +typer==0.9.0 \ No newline at end of file diff --git a/examples/fl/fl_admin/project/update_plan.py b/examples/fl/fl_admin/project/update_plan.py new file mode 100644 index 000000000..4c3c1b9c8 --- /dev/null +++ b/examples/fl/fl_admin/project/update_plan.py @@ -0,0 +1,44 @@ +from subprocess import check_call + + +def set_straggler_cutoff_time(workspace_folder, admin_cn, field_value): + if not field_value.isnumeric(): + raise TypeError( + f"Expected an integer for straggler cutoff time, got {field_value}" + ) + check_call( + [ + "fx", + "admin", + "set_straggler_cutoff_time", + "-n", + admin_cn, + "--timeout_in_seconds", + field_value, + ], + cwd=workspace_folder, + ) + + +def set_dynamic_task_arg(workspace_folder, admin_cn, task_name, arg_name, field_value): + try: + float(field_value) + except ValueError: + TypeError(f"Expected a float for dynamic task arg, got 
{field_value}") + + check_call( + [ + "fx", + "admin", + "set_dynamic_task_arg", + "-n", + admin_cn, + "--task_name", + task_name, + "--arg_name", + arg_name, + "--value", + field_value, + ], + cwd=workspace_folder, + ) diff --git a/examples/fl/fl_admin/project/utils.py b/examples/fl/fl_admin/project/utils.py new file mode 100644 index 000000000..eaa39dc4c --- /dev/null +++ b/examples/fl/fl_admin/project/utils.py @@ -0,0 +1,106 @@ +import os +import shutil + + +def setup_ws(node_cert_folder, ca_cert_folder, plan_path, temp_dir): + workspace_folder = os.path.join(temp_dir, "workspace") + create_workspace(workspace_folder) + prepare_plan(plan_path, workspace_folder) + cn = get_admin_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + return workspace_folder, cn + + +def create_workspace(fl_workspace): + plan_folder = os.path.join(fl_workspace, "plan") + workspace_config = os.path.join(fl_workspace, ".workspace") + defaults_file = os.path.join(plan_folder, "defaults") + + os.makedirs(plan_folder, exist_ok=True) + with open(defaults_file, "w") as f: + f.write("../../workspace/plan/defaults\n\n") + with open(workspace_config, "w") as f: + f.write("current_plan_name: default\n\n") + + +def get_admin_cn(): + return os.environ["MEDPERF_ADMIN_PARTICIPANT_CN"] + + +def get_col_label_to_add(): + return os.environ["MEDPERF_COLLABORATOR_LABEL_TO_ADD"] + + +def get_col_cn_to_add(): + return os.environ["MEDPERF_COLLABORATOR_CN_TO_ADD"] + + +def get_col_label_to_remove(): + return os.environ["MEDPERF_COLLABORATOR_LABEL_TO_REMOVE"] + + +def get_col_cn_to_remove(): + return os.environ["MEDPERF_COLLABORATOR_CN_TO_REMOVE"] + + +def get_update_field_name(): + return os.environ["MEDPERF_UPDATE_FIELD_NAME"] + + +def get_update_value_name(): + return os.environ["MEDPERF_UPDATE_FIELD_VALUE"] + + +def prepare_plan(plan_path, fl_workspace): + target_plan_folder = os.path.join(fl_workspace, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + + target_plan_file = os.path.join(target_plan_folder, "plan.yaml") + shutil.copyfile(plan_path, target_plan_file) + + +def prepare_node_cert( + node_cert_folder, target_cert_folder_name, target_cert_name, fl_workspace +): + error_msg = f"{node_cert_folder} should contain only two files: *.crt and *.key" + + files = os.listdir(node_cert_folder) + file_extensions = [file.split(".")[-1] for file in files] + if len(files) != 2 or sorted(file_extensions) != ["crt", "key"]: + raise RuntimeError(error_msg) + + if files[0].endswith(".crt") and files[1].endswith(".key"): + cert_file = files[0] + key_file = files[1] + else: + key_file = files[0] + cert_file = files[1] + + key_file = os.path.join(node_cert_folder, key_file) + cert_file = os.path.join(node_cert_folder, cert_file) + + target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) + os.makedirs(target_cert_folder, exist_ok=True) + target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") + target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") + + os.symlink(key_file, target_key_file) + os.symlink(cert_file, target_cert_file) + + +def prepare_ca_cert(ca_cert_folder, fl_workspace): + error_msg = f"{ca_cert_folder} should contain only one file: *.crt" + + files = os.listdir(ca_cert_folder) + file = files[0] + if len(files) != 1 or not file.endswith(".crt"): + raise RuntimeError(error_msg) + + file = os.path.join(ca_cert_folder, file) + + target_ca_cert_folder = 
os.path.join(fl_workspace, "cert") + os.makedirs(target_ca_cert_folder, exist_ok=True) + target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") + + os.symlink(file, target_ca_cert_file) diff --git a/examples/fl/fl_admin/setup_clean.sh b/examples/fl/fl_admin/setup_clean.sh new file mode 100644 index 000000000..b82c06ab4 --- /dev/null +++ b/examples/fl/fl_admin/setup_clean.sh @@ -0,0 +1 @@ +rm -rf ./mlcube_admin diff --git a/examples/fl/fl_admin/setup_test_no_docker.sh b/examples/fl/fl_admin/setup_test_no_docker.sh new file mode 100644 index 000000000..183d0f368 --- /dev/null +++ b/examples/fl/fl_admin/setup_test_no_docker.sh @@ -0,0 +1,8 @@ +cp -r ./mlcube ./mlcube_admin + +# Get your node cert folder and ca cert folder from the aggregator setup. Modify paths as needed. +cp -r ../../fl_post/fl/for_admin/node_cert ./mlcube_admin/workspace/node_cert +cp -r ../../fl_post/fl/for_admin/ca_cert ./mlcube_admin/workspace/ca_cert + +# Note that you should use the same plan used in the federation +cp ../../fl_post/fl/for_admin/plan.yaml ./mlcube_admin/workspace/plan.yaml diff --git a/examples/fl/fl_admin/test.sh b/examples/fl/fl_admin/test.sh new file mode 100644 index 000000000..93412bb50 --- /dev/null +++ b/examples/fl/fl_admin/test.sh @@ -0,0 +1,44 @@ +# Make sure an aggregator is up somewhere, and it is configured to +# accept admin@example.com as an admin and to allow any endpoints you are willing to test + +# Uncommend and test + +# GET EXPERIMENT STATUS +env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=col1@example.com" +env_args=$env_arg1 +medperf mlcube run --mlcube ./mlcube_admin --task get_experiment_status \ + -e $env_args + +## ADD COLLABORATOR +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" +# env_arg2="MEDPERF_COLLABORATOR_LABEL_TO_ADD=col3@example.com" +# env_arg3="MEDPERF_COLLABORATOR_CN_TO_ADD=col3@example.com" +# env_args="$env_arg1,$env_arg2,$env_arg3" +# medperf mlcube run --mlcube ./mlcube_admin --task add_collaborator \ +# -e $env_args + +## REMOVE COLLABORATOR +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" +# env_arg2="MEDPERF_COLLABORATOR_LABEL_TO_REMOVE=col3@example.com" +# env_arg3="MEDPERF_COLLABORATOR_CN_TO_REMOVE=col3@example.com" +# env_args="$env_arg1,$env_arg2,$env_arg3" +# medperf mlcube run --mlcube ./mlcube_admin --task remove_collaborator \ +# -e $env_args + +# # SET STRAGGLER CUTOFF +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" +# env_arg2="MEDPERF_UPDATE_FIELD_NAME=straggler_handling_policy.settings.straggler_cutoff_time" +# env_arg3="MEDPERF_UPDATE_FIELD_VALUE=1200" + +# env_args="$env_arg1,$env_arg2,$env_arg3" +# medperf mlcube run --mlcube ./mlcube_admin --task update_plan \ +# -e $env_args + +## SET DYNAMIC TASK ARG +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=col1@example.com" +# env_arg2="MEDPERF_UPDATE_FIELD_NAME=dynamictaskargs.train.train_cutoff_time" +# env_arg3="MEDPERF_UPDATE_FIELD_VALUE=20" + +# env_args="$env_arg1,$env_arg2,$env_arg3" +# medperf mlcube run --mlcube ./mlcube_admin --task update_plan \ +# -e $env_args diff --git a/examples/fl/mock_cert/build.sh b/examples/fl/mock_cert/build.sh new file mode 100644 index 000000000..d56304274 --- /dev/null +++ b/examples/fl/mock_cert/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/mock_cert/clean.sh b/examples/fl/mock_cert/clean.sh new file mode 100644 index 000000000..fa72ddd6d --- /dev/null +++ b/examples/fl/mock_cert/clean.sh @@ -0,0 +1 @@ +rm -rf mlcube/workspace/pki_assets diff 
--git a/examples/fl/mock_cert/mlcube/mlcube.yaml b/examples/fl/mock_cert/mlcube/mlcube.yaml new file mode 100644 index 000000000..8019d3579 --- /dev/null +++ b/examples/fl/mock_cert/mlcube/mlcube.yaml @@ -0,0 +1,35 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/medperf-test-ca:0.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. + build_file: "Dockerfile" + +tasks: + trust: + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_client_cert: + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_server_cert: + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ diff --git a/examples/fl/mock_cert/mlcube/workspace/ca_config.json b/examples/fl/mock_cert/mlcube/workspace/ca_config.json new file mode 100644 index 000000000..bcf246a03 --- /dev/null +++ b/examples/fl/mock_cert/mlcube/workspace/ca_config.json @@ -0,0 +1,7 @@ +{ + "address": "https://127.0.0.1", + "port": 443, + "fingerprint": "fingerprint", + "client_provisioner": "auth0", + "server_provisioner": "acme" +} \ No newline at end of file diff --git a/examples/fl/mock_cert/project/Dockerfile b/examples/fl/mock_cert/project/Dockerfile new file mode 100644 index 000000000..cf625ca6b --- /dev/null +++ b/examples/fl/mock_cert/project/Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.9.16-slim + +COPY ./requirements.txt /mlcube_project/requirements.txt + +RUN pip3 install --no-cache-dir -r /mlcube_project/requirements.txt + +ENV LANG C.UTF-8 + +COPY . 
/mlcube_project + +RUN chmod a+r /mlcube_project/ca/root.key + +ENTRYPOINT ["python3", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/mock_cert/project/ca/cert/root.crt b/examples/fl/mock_cert/project/ca/cert/root.crt new file mode 100644 index 000000000..813cd7165 --- /dev/null +++ b/examples/fl/mock_cert/project/ca/cert/root.crt @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEyzCCAzOgAwIBAgIUd8btUDxu7RR87iJZhUjzturqti8wDQYJKoZIhvcNAQEM +BQAwdDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBnNpbXBs +ZTEXMBUGA1UEAwwOU2ltcGxlIFJvb3QgQ0ExEzARBgNVBAoMClNpbXBsZSBJbmMx +FzAVBgNVBAsMDlNpbXBsZSBSb290IENBMCAXDTI0MDQyOTEzNTYxMloYDzIxMjQw +NDA1MTM1NjEyWjB0MRMwEQYKCZImiZPyLGQBGRYDb3JnMRYwFAYKCZImiZPyLGQB +GRYGc2ltcGxlMRcwFQYDVQQDDA5TaW1wbGUgUm9vdCBDQTETMBEGA1UECgwKU2lt +cGxlIEluYzEXMBUGA1UECwwOU2ltcGxlIFJvb3QgQ0EwggGiMA0GCSqGSIb3DQEB +AQUAA4IBjwAwggGKAoIBgQDMqoMT6iE/qRFJ/X+N9pp/WUoYaDDliUYuDdgb0pqq +6uqA50tfYmCOWal1K1Gq/4Hgi0OKsyj0bMemtRNXXH8r8qtjLNNmGmyeZICDe9FT +37gNr9uYWtVuwWpTI9bksxGVg9E0qx0U6fo+Puiu5ImDF/iYy1931ghijbOj0qWQ +M2dQi85baF/6uEHZ18b+c7K/toXCNhzJWrpw88DUyPerhkoe/JTI2kSNKZwULuan +VKazUZ4JIPF6NWhQGb/hcI+tTBkJXlETjrpN8A3hqVp6vpZTiZfXfy5eGmSyypwE +Z1gnSBOuh1EQxOLPXhykMeHaPZ5lZMeAprD/eHzWqw3lgTrcFkPBQMTAUWhHY4Wp +DKKRWZa2gbWf5peYzbpQtL0vgyDnDKgDyMkiXJksf97ITfbbP+VCSxLjYUaPWGmL +w4Ik6hzQmSSdSF/Va364W1tUY5D8D+DClrwg97K73nODifmeenwYHIn9Amdzt9Mh +cdQnbgZFJYDFEU1ZbJSlZfUCAwEAAaNTMFEwHQYDVR0OBBYEFCMAiSxHknRxYn6+ +SGNOAxhxDGbVMB8GA1UdIwQYMBaAFCMAiSxHknRxYn6+SGNOAxhxDGbVMA8GA1Ud +EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggGBAEwfiT0gFgWoQ2NNWCL/bUlN +WrBYlwy6ixL8V97LoJMmiLVq/2fxNm66r6AgKKGJLEXPtEQIT6dwpHIlVYvlvhCw +zC079E4/vNXEldoEiHWxzkeQL1hKFFm36Vm5k8hVoHPIeSgJyGO/e4x10m6mYKqZ +iSKcaSCdx8x7B8EHnve1E3H6LQjSilIfdg80niFdKzjE6v7zrTYDyGUz0Shasnms +oO9KtkhRFe04Fm495EmieoaMA7eT7ojucoZ5dwgMf+wDgfXsZ+hIiCeC9nVB1Waq +Biwv9sIJ8tI19Y5FPDXnaP6alkDm05u4PkE9l+BJ5Ky5VQcjLhyZ4X37W33aUU+z +ng3/c27j8MzXtkNIhO2yUX3Lr2ExJqaLB/hrWlWDJH/yRG3hGNGVl3RdCL82hHD+ +TMd6MKKM3XTqnwvHbUiJw63Xa5upOcoXvcoS6/sDvpJQKnuNB3DZ8LBsvnWPlucP +Ctnnve8XKCvXfNVrV0uXB5rjWIvNZ5eiNJsml8e4tQ== +-----END CERTIFICATE----- diff --git a/examples/fl/mock_cert/project/ca/root.key b/examples/fl/mock_cert/project/ca/root.key new file mode 100644 index 000000000..25617bed6 --- /dev/null +++ b/examples/fl/mock_cert/project/ca/root.key @@ -0,0 +1,40 @@ +-----BEGIN PRIVATE KEY----- +MIIG/wIBADANBgkqhkiG9w0BAQEFAASCBukwggblAgEAAoIBgQDMqoMT6iE/qRFJ +/X+N9pp/WUoYaDDliUYuDdgb0pqq6uqA50tfYmCOWal1K1Gq/4Hgi0OKsyj0bMem +tRNXXH8r8qtjLNNmGmyeZICDe9FT37gNr9uYWtVuwWpTI9bksxGVg9E0qx0U6fo+ +Puiu5ImDF/iYy1931ghijbOj0qWQM2dQi85baF/6uEHZ18b+c7K/toXCNhzJWrpw +88DUyPerhkoe/JTI2kSNKZwULuanVKazUZ4JIPF6NWhQGb/hcI+tTBkJXlETjrpN +8A3hqVp6vpZTiZfXfy5eGmSyypwEZ1gnSBOuh1EQxOLPXhykMeHaPZ5lZMeAprD/ +eHzWqw3lgTrcFkPBQMTAUWhHY4WpDKKRWZa2gbWf5peYzbpQtL0vgyDnDKgDyMki +XJksf97ITfbbP+VCSxLjYUaPWGmLw4Ik6hzQmSSdSF/Va364W1tUY5D8D+DClrwg +97K73nODifmeenwYHIn9Amdzt9MhcdQnbgZFJYDFEU1ZbJSlZfUCAwEAAQKCAYAK +0M/7KMT3tPA29XCHiLVYGYMy5alVLVCfdRfV5eaf4FONhauUBNeOw5ToSOZt9PFg +yRCZWdJw6EwSC+upEuLy6EYVCEoQ++sq4QDG8gTOkToMGckEX3h7++NUisZRG61y +4J5uUW9Iqvy7IV6b2h6c5j1lmwnBLsxj+Oe6C+himx97QDLNiHyEprbEKQUDxArO +2s8YGP1NyWpPjHIaTJcvYfoKHSr3r6EucePPpT2HMOVbqz/WF6mV3btU0FI1kFnP +KvUYJy+qEhZGgDHBGm80Y7MAjV/7Iu34oikQTv42QgBd8CwPODZAJs/VRW5U2OMS +DOj/quLSChknofVb5rEcjz2HXVmilsGoLAjdbt19r16XlwlFSihn73zZ/kfJWud+ +IATJ5FW1A9B64QZGhB45hGJcESqHWYq+x4i9puRL1XtuuV0uJQq79w0SKDSQXlrs +AZ1OzEaRdubFE7M49BU8MoSza9QvNhzPADOewVWlpkKrVPYRYXYgouq4FKpRpFkC 
+gcEA8vTVWh7m9S48W2qnaUGu4EhY3X0PuHUuH/muWEXC4n36xBRNb7jZkcC4qO1Z +B9bznOWJcVKr5AQ2Cq61DlQesltMBMvo8LCdyyV+9XgBpVtFJ4OlOVex7xsR6/Lp +gVO6H9SC0ej1LAHAK37tTLfejPvuMQSnAYuvmDfLoB8nZX4xaRuRz4iHMxi9ED6f +N2Zdhtutp84DypxFQWxeo8SZshAF5i77L9wEkCeA6JjPMwv2SknEuIl+oX2kyNw7 +0oxXAoHBANena6pl0NHqZV0GlrdeVYC4nQrKtLq9cHs5E/8nLmPH29b3T25wPKGU +jP75S0DBb21LC00slVvp8aNMZIv4WBHcAlwUUj6rutJ0Bvm7ZrPZEAZrfPil7BRY +QG2x7lrI+biLj/7hmNKjflfIi1XSlxlfP99Wy37ImIoebZdKEOqP4M7E8NlhK0Lt +YPGg3qxA+0NQsqk+XKrls3AK1pVq2aZTsfAjH+Z0wSqmLSM4tXT7v0rQFYTCe06E ++NB5TfWwkwKBwQCyArV3zICIUBIlIOX8dwW8iwWhcwpbqm/bOcOGJcb+0DM2C3IZ +U6UF5+Dk1NKQrevcn0mu4FXVQUifVxaNoxDCuaXfNdA82gsjVxvImt8J2u+2Xfxn +IVvbx0fAS0DPYxtSSxB24GsSjU3SELOprGbBga0p+TCsLz6/FtJ5RZpGAMoPKwYQ +uwXkaFHOXzOlEbmhH8AC3S1l/E25+77z2w6JqrfHydB9ZoVpYahPw/a8fh08nQQn ++YXwqPBdww+J2w0CgcEAnFaUOBDd5QBPgbQgGUk7JTkxKDyx7tsdK0fC1mv6Nn4S +QvJBVGfrnJwL52ClDInvFMWdqNIUaXDdK6xbDBn7Bt9/mm9k/GgU5TMWR39zQhiv +hGfyTnRDBLDB7IRcrtYaK46J0paL6tB57HvHf21O+ybRMEFE/2G/LApJGq+oOdQa +fuvJS14lNbzPVfxw0WG+hht/mjBKj948SpKg4+t1ZB4y1ksweirUSu3ztSAMdIV5 +NWxK3Vb8e3zswH3gZagfAoHBAJI6LPi7K8POviGrV2Aw0EHv/Fs3GY+zgiphqf69 +pR7fZqcKYOSwPugw5gYR1l9REpWyD9qfNzKXeuQbwwqDaW7VZlUEpp3mndDYG4tT +63W5h4O/vAmtAdF8oasyDv2iSni6xppI2QmlAoSrDIyCM3HyNOM0l4pHRQ+V1ncE +JXwiXwePyt/lx30ua7VWU442ZfnVfA78cK1AaTl2KRV00veHR1KLBvIkyVE5AHD8 +Ynsi1KB4GfSKoLUr9t/n0hqZ8g== +-----END PRIVATE KEY----- diff --git a/examples/fl/mock_cert/project/csr.conf b/examples/fl/mock_cert/project/csr.conf new file mode 100644 index 000000000..c3b2d0f0c --- /dev/null +++ b/examples/fl/mock_cert/project/csr.conf @@ -0,0 +1,31 @@ +[ req ] +default_bits = 3072 +prompt = no +default_md = sha384 +distinguished_name = req_distinguished_name + +[ req_distinguished_name ] +commonName = hasan-hp-zbook-15-g3.home + +[ alt_names ] +DNS.1 = hasan-hp-zbook-15-g3.home + +[ v3_client ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names +extendedKeyUsage = critical,clientAuth + +[ v3_server ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names +extendedKeyUsage = critical,serverAuth + +[ v3_client_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names + +[ v3_server_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names diff --git a/examples/fl/mock_cert/project/mlcube.py b/examples/fl/mock_cert/project/mlcube.py new file mode 100644 index 000000000..0fffa51ab --- /dev/null +++ b/examples/fl/mock_cert/project/mlcube.py @@ -0,0 +1,49 @@ +"""MLCube handler file""" + +import typer +import shutil +import json +import os + +app = typer.Typer() + + +def asserts(ca_config): + with open(ca_config) as f: + config = json.load(f) + assert config["address"] == "https://127.0.0.1" + assert config["port"] == 443 + assert config["fingerprint"] == "fingerprint" + assert config["client_provisioner"] == "auth0" + assert config["server_provisioner"] == "acme" + + +@app.command("trust") +def trust( + ca_config: str = typer.Option(..., "--ca_config"), + pki_assets: str = typer.Option(..., "--pki_assets"), +): + asserts(ca_config) + shutil.copytree("/mlcube_project/ca/cert", pki_assets, dirs_exist_ok=True) + + +@app.command("get_client_cert") +def get_client_cert( + ca_config: str = typer.Option(..., "--ca_config"), + pki_assets: str = typer.Option(..., "--pki_assets"), +): + asserts(ca_config) + os.system(f"sh /mlcube_project/sign.sh -o {pki_assets}") + + +@app.command("get_server_cert") +def get_server_cert( + 
ca_config: str = typer.Option(..., "--ca_config"), + pki_assets: str = typer.Option(..., "--pki_assets"), +): + asserts(ca_config) + os.system(f"sh /mlcube_project/sign.sh -o {pki_assets} -s") + + +if __name__ == "__main__": + app() diff --git a/examples/fl/mock_cert/project/requirements.txt b/examples/fl/mock_cert/project/requirements.txt new file mode 100644 index 000000000..a1662dd93 --- /dev/null +++ b/examples/fl/mock_cert/project/requirements.txt @@ -0,0 +1,2 @@ +typer==0.9.0 +PyYAML==6.0 \ No newline at end of file diff --git a/examples/fl/mock_cert/project/sign.sh b/examples/fl/mock_cert/project/sign.sh new file mode 100644 index 000000000..4295351df --- /dev/null +++ b/examples/fl/mock_cert/project/sign.sh @@ -0,0 +1,36 @@ +while getopts so: flag; do + case "${flag}" in + o) OUT=${OPTARG} ;; + s) EXT="v3_server" ;; + esac +done + +EXT="${EXT:-v3_client}" + +if [ -z "$OUT" ]; then + echo "-o is required" + exit 1 +fi + +if [ -z "$MEDPERF_INPUT_CN" ]; then + echo "MEDPERF_INPUT_CN env var is required" + exit 1 +fi + +mkdir -p $OUT +cp /mlcube_project/csr.conf $OUT/ +cp -r /mlcube_project/ca $OUT/ +CSR_CONF=$OUT/csr.conf +CA_KEY=$OUT/ca/root.key +CA_CERT=$OUT/ca/cert/root.crt + +sed -i "/^commonName = /c\commonName = $MEDPERF_INPUT_CN" $CSR_CONF +sed -i "/^DNS\.1 = /c\DNS.1 = $MEDPERF_INPUT_CN" $CSR_CONF + +openssl genpkey -algorithm RSA -out $OUT/key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key $OUT/key.key -out $OUT/csr.csr -config $CSR_CONF -extensions $EXT +openssl x509 -req -in $OUT/csr.csr -CA $CA_CERT -CAkey $CA_KEY \ + -CAcreateserial -out $OUT/crt.crt -days 36500 -sha384 -extensions ${EXT}_crt -extfile $CSR_CONF +rm $OUT/csr.csr +rm -r $OUT/ca +rm -r $OUT/csr.conf diff --git a/examples/fl/mock_cert/test.sh b/examples/fl/mock_cert/test.sh new file mode 100644 index 000000000..79eaa584f --- /dev/null +++ b/examples/fl/mock_cert/test.sh @@ -0,0 +1,6 @@ +medperf mlcube run --mlcube ./mlcube --task trust +sh clean.sh +medperf mlcube run --mlcube ./mlcube --task get_client_cert -e MEDPERF_INPUT_CN=user@example.com +sh clean.sh +medperf mlcube run --mlcube ./mlcube --task get_server_cert -e MEDPERF_INPUT_CN=https://example.com +sh clean.sh diff --git a/examples/fl/prep/README.md b/examples/fl/prep/README.md new file mode 100644 index 000000000..b8bbdffad --- /dev/null +++ b/examples/fl/prep/README.md @@ -0,0 +1,10 @@ +# How to test + +1. download a dataset: + + - train: + - test: + +2. Extract the dataset. Place the folder `col1` under the workspace folder and rename it to `input_data`. +3. Create an empty folder named `input_labels` under the workspace folder. +4. 
Run `test.sh` diff --git a/examples/fl/prep/build.sh b/examples/fl/prep/build.sh new file mode 100644 index 000000000..d56304274 --- /dev/null +++ b/examples/fl/prep/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/prep/clean.sh b/examples/fl/prep/clean.sh new file mode 100644 index 000000000..68beabb21 --- /dev/null +++ b/examples/fl/prep/clean.sh @@ -0,0 +1,3 @@ +rm -rf mlcube/workspace/data +rm -rf mlcube/workspace/labels +rm -rf mlcube/workspace/statistics.yaml diff --git a/examples/fl/prep/mlcube/.gitignore b/examples/fl/prep/mlcube/.gitignore new file mode 100644 index 000000000..f1981605f --- /dev/null +++ b/examples/fl/prep/mlcube/.gitignore @@ -0,0 +1 @@ +workspace \ No newline at end of file diff --git a/examples/fl/prep/mlcube/mlcube.yaml b/examples/fl/prep/mlcube/mlcube.yaml new file mode 100644 index 000000000..15c0133b9 --- /dev/null +++ b/examples/fl/prep/mlcube/mlcube.yaml @@ -0,0 +1,40 @@ +name: pathmnist data preparation MLCube +description: pathmnist data preparation MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/fl-test-prep:0.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. + build_file: "Dockerfile" + +tasks: + prepare: + parameters: + inputs: + { + data_path: input_data, + labels_path: input_labels, + } + outputs: { output_path: data/, output_labels_path: labels/ } + sanity_check: + parameters: + inputs: + { + data_path: data/, + labels_path: labels/, + } + statistics: + parameters: + inputs: + { + data_path: data/, + labels_path: labels/, + } + outputs: { output_path: { type: file, default: statistics.yaml } } diff --git a/examples/fl/prep/project/Dockerfile b/examples/fl/prep/project/Dockerfile new file mode 100644 index 000000000..91c477415 --- /dev/null +++ b/examples/fl/prep/project/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.9.16-slim + +COPY ./requirements.txt /mlcube_project/requirements.txt + +RUN pip3 install --no-cache-dir -r /mlcube_project/requirements.txt + +ENV LANG C.UTF-8 + +COPY . 
/mlcube_project + +ENTRYPOINT ["python3", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/prep/project/mlcube.py b/examples/fl/prep/project/mlcube.py new file mode 100644 index 000000000..2e3f03556 --- /dev/null +++ b/examples/fl/prep/project/mlcube.py @@ -0,0 +1,38 @@ +"""MLCube handler file""" +import typer +from prepare import prepare_dataset +from sanity_check import perform_sanity_checks +from stats import generate_statistics + +app = typer.Typer() + + +@app.command("prepare") +def prepare( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + output_path: str = typer.Option(..., "--output_path"), + output_labels_path: str = typer.Option(..., "--output_labels_path"), +): + prepare_dataset(data_path, labels_path, output_path, output_labels_path) + + +@app.command("sanity_check") +def sanity_check( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), +): + perform_sanity_checks(data_path, labels_path) + + +@app.command("statistics") +def statistics( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + out_path: str = typer.Option(..., "--output_path"), +): + generate_statistics(data_path, labels_path, out_path) + + +if __name__ == "__main__": + app() diff --git a/examples/fl/prep/project/prepare.py b/examples/fl/prep/project/prepare.py new file mode 100644 index 000000000..9bb44f418 --- /dev/null +++ b/examples/fl/prep/project/prepare.py @@ -0,0 +1,50 @@ +import os +import numpy as np +import pandas as pd +import shutil +from tqdm import tqdm +from PIL import Image + + +def prepare_split(split, arrays, output_path, output_labels_path): + subfolder = os.path.join(output_path, "pathmnist") + os.makedirs(subfolder, exist_ok=True) + + arrs = arrays[f"{split}_images"] + labels = arrays[f"{split}_labels"] + csv_data = [] + for i in tqdm(range(arrs.shape[0])): + name = f"{split}_{i}.png" + out_path = os.path.join(subfolder, name) + Image.fromarray(arrs[i]).save(out_path) + record = { + "SubjectID": str(i), + "Channel_0": os.path.join("pathmnist", name), + "valuetopredict": labels[i][0], + } + csv_data.append(record) + + if split == "train": + csv_file = os.path.join(output_path, "train.csv") + if split == "val": + csv_file = os.path.join(output_path, "valid.csv") + if split == "test": + csv_file = os.path.join(output_path, "data.csv") + + pd.DataFrame(csv_data).to_csv(csv_file, index=False) + + if split == "test": + csv_file_in_labels = os.path.join(output_labels_path, "data.csv") + shutil.copyfile(csv_file, csv_file_in_labels) + + +def prepare_dataset(data_path, labels_path, output_path, output_labels_path): + os.makedirs(output_path, exist_ok=True) + os.makedirs(output_labels_path, exist_ok=True) + + file_path = os.path.join(data_path, "pathmnist.npz") + arrays = np.load(file_path) + for key in arrays.keys(): + if key.endswith("images"): + split = key.split("_")[0] + prepare_split(split, arrays, output_path, output_labels_path) diff --git a/examples/fl/prep/project/requirements.txt b/examples/fl/prep/project/requirements.txt new file mode 100644 index 000000000..751a77d17 --- /dev/null +++ b/examples/fl/prep/project/requirements.txt @@ -0,0 +1,6 @@ +typer==0.9.0 +numpy==1.26.0 +PyYAML==6.0 +Pillow==10.2.0 +pandas==2.2.1 +tqdm \ No newline at end of file diff --git a/examples/fl/prep/project/sanity_check.py b/examples/fl/prep/project/sanity_check.py new file mode 100644 index 
000000000..5fb23d72c --- /dev/null +++ b/examples/fl/prep/project/sanity_check.py @@ -0,0 +1,11 @@ +import os + + +def perform_sanity_checks(data_path, labels_path): + images_files = os.listdir(os.path.join(data_path, "pathmnist")) + + assert all( + [image.endswith(".png") for image in images_files] + ), "images should be .png" + + print("Sanity checks ran successfully.") diff --git a/examples/fl/prep/project/stats.py b/examples/fl/prep/project/stats.py new file mode 100644 index 000000000..872ed6f02 --- /dev/null +++ b/examples/fl/prep/project/stats.py @@ -0,0 +1,23 @@ +import os +import yaml + + +def generate_statistics(data_path, labels_path, out_path): + # number of cases + cases = os.listdir(os.path.join(data_path, "pathmnist")) + if cases[0].startswith("test"): + statistics = { + "num_cases": len(cases), + } + else: + num_train_cases = len([file for file in cases if file.startswith("train")]) + num_val_cases = len([file for file in cases if file.startswith("val")]) + statistics = { + "num_train_cases": num_train_cases, + "num_val_cases": num_val_cases, + } + + + # write statistics + with open(out_path, "w") as f: + yaml.safe_dump(statistics, f) diff --git a/examples/fl/prep/test.sh b/examples/fl/prep/test.sh new file mode 100644 index 000000000..c16ee1ca0 --- /dev/null +++ b/examples/fl/prep/test.sh @@ -0,0 +1,7 @@ +# mlcube run --mlcube ./mlcube --task prepare +# mlcube run --mlcube ./mlcube --task sanity_check +# mlcube run --mlcube ./mlcube --task statistics + +medperf mlcube run --mlcube ./mlcube --task prepare -o ./logs_prep.log +medperf mlcube run --mlcube ./mlcube --task sanity_check -o ./logs_sanity.log +medperf mlcube run --mlcube ./mlcube --task statistics -o ./logs_stats.log diff --git a/examples/fl_post/fl/.gitignore b/examples/fl_post/fl/.gitignore new file mode 100644 index 000000000..13ab94d3a --- /dev/null +++ b/examples/fl_post/fl/.gitignore @@ -0,0 +1,6 @@ +mlcube_* +ca +quick* +mlcube/workspace/additional_files/init_nnunet/* +mlcube/workspace/additional_files/init_weights/* +for_admin diff --git a/examples/fl_post/fl/README.md b/examples/fl_post/fl/README.md new file mode 100644 index 000000000..a890dff42 --- /dev/null +++ b/examples/fl_post/fl/README.md @@ -0,0 +1,34 @@ +# How to run tests (see next section for a detailed guide) + +- Run `setup_test_no_docker.sh` just once to create certs and download required data. +- Run `test.sh ` to start the aggregator and three collaborators. (requires BASH, see next section) +- Run `clean.sh` to be able to rerun `test.sh` freshly. +- Run `setup_clean.sh` to clear what has been generated in step 1. + +## Detailed Guide + +- Go to your medperf repo and checkout the required branch. +- Have medperf virtual environment activated (and medperf installed) +- run: `sh setup_test_no_docker.sh` to setup the test (you should `sh setup_clean.sh` if you already ran this before you run it again). +- run: `bash test.sh --d1 absolute_path --l2 absolute_path ...` to run the test + - data paths can be specified in the command. --dn is for data path of collaborator n, --ln is for labels_path of collaborator n. + - make sure gpu IDs are set as expected in `test.sh` script. +- to stop: `CTRL+C` in the terminal where you ran `test.sh`, then, `docker container ls`, then take the container IDs, then `docker container stop `, to stop relevant running containers (to identify containers to stop, they should have an IMAGE field same name as the one configured in docker image field in `mlcube.yaml`). 
You can at the end use `docker container prune` to delete all stopped containers if you want (not necessary). +- To rerun: you should first run `sh clean.sh`, then `bash test.sh ...` again. + +## What to do when you want to + +- change port: either change `setup_test_no_docker.sh` then clean setup and run setup again, or, go to `mlcube_agg/workspace/aggregator_config.yaml` and modify the file directly. +- Change address: change `setup_test_no_docker.sh` then clean setup and run setup again. (since the cert needs to be generated) +- change training_config: modify `mlcube/workspace/training_config.yaml` then run `sync.sh`. +- use custom data paths: pass data paths when running `test.sh` (`--d1, --d2, --l1, ...`) +- change weights: modify `mlcube/workspace/additional_files` then run `sync.sh`. +- fl_admin? connect to container and run fx commands. make sure a colab is an admin (to be detailed later) + +- to use three collaborators instead of two: + - go to `mlcube_agg/workspace/cols.yaml` and modify the list by adding col3. + - in `test.sh`, uncomment col3's run command. + +## to rebuild + +sh build.sh (or with -b if you want to rebuild the openfl base as well. Configure `build.sh` to change how openfl base is built) diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh new file mode 100755 index 000000000..28e76c014 --- /dev/null +++ b/examples/fl_post/fl/build.sh @@ -0,0 +1,16 @@ +while getopts b flag; do + case "${flag}" in + b) BUILD_BASE="true" ;; + esac +done +BUILD_BASE="${BUILD_BASE:-false}" + +if ${BUILD_BASE}; then + git clone https://github.com/hasan7n/openfl.git + cd openfl + git checkout 9467f829687b6284a6e380d31f90d31bc9de023f + docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. + rm -rf openfl +fi +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl_post/fl/clean.sh b/examples/fl_post/fl/clean.sh new file mode 100755 index 000000000..789f36d07 --- /dev/null +++ b/examples/fl_post/fl/clean.sh @@ -0,0 +1,5 @@ +rm -rf mlcube_agg/workspace/final_weights +rm -rf mlcube_agg/workspace/logs +rm -rf mlcube_agg/workspace/plan.yaml +rm -rf mlcube_col*/workspace/logs +rm -rf mlcube_col*/workspace/plan.yaml diff --git a/examples/fl_post/fl/csr.conf b/examples/fl_post/fl/csr.conf new file mode 100644 index 000000000..c3b2d0f0c --- /dev/null +++ b/examples/fl_post/fl/csr.conf @@ -0,0 +1,31 @@ +[ req ] +default_bits = 3072 +prompt = no +default_md = sha384 +distinguished_name = req_distinguished_name + +[ req_distinguished_name ] +commonName = hasan-hp-zbook-15-g3.home + +[ alt_names ] +DNS.1 = hasan-hp-zbook-15-g3.home + +[ v3_client ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names +extendedKeyUsage = critical,clientAuth + +[ v3_server ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names +extendedKeyUsage = critical,serverAuth + +[ v3_client_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names + +[ v3_server_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names diff --git a/examples/fl_post/fl/mlcube/mlcube.yaml b/examples/fl_post/fl/mlcube/mlcube.yaml new file mode 100644 index 000000000..835e39ea3 --- /dev/null +++ b/examples/fl_post/fl/mlcube/mlcube.yaml @@ -0,0 +1,56 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + gpu_args: 
"--shm-size 12g" + # Image name + image: mlcommons/rano-fl:30-oct-2024 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. + build_file: "Dockerfile" + +tasks: + train: + parameters: + inputs: + data_path: data/ + labels_path: labels/ + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + init_nnunet_directory: additional_files/init_nnunet/ + outputs: + output_logs: logs/ + start_aggregator: + parameters: + inputs: + input_weights: additional_files/init_weights + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + collaborators: cols.yaml + outputs: + output_logs: logs/ + output_weights: final_weights/ + report_path: { type: "file", default: "report/report.yaml" } + generate_plan: + parameters: + inputs: + training_config_path: training_config.yaml + aggregator_config_path: aggregator_config.yaml + outputs: + plan_path: { type: "file", default: "plan/plan.yaml" } + train_initial_model: + parameters: + inputs: + data_path: data/ + labels_path: labels/ + outputs: + output_logs: logs/ + init_nnunet_directory: init_nnunet_directory/ diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml new file mode 100644 index 000000000..3a1e357df --- /dev/null +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -0,0 +1,119 @@ +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/fl_post_two_init.pbuf + best_state_path : save/fl_post_two_best.pbuf + last_state_path : save/fl_post_two_last.pbuf + rounds_to_train : &rounds_to_train 20 + admins_endpoints_mapping: + col1@example.com: + - GetExperimentStatus + - SetStragglerCuttoffTime + - SetDynamicTaskArg + - GetDynamicTaskArg + + dynamictaskargs: &dynamictaskargs + train: + train_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 300 # one day + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 20 # one day + train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) NOTE: Value close to zero zero shifts non 0.0 completion rates much closer to 1.0 + admin_settable: True + min: -1.0 # inverts train_comleted, so this would be a way to have checkpoint_weighting = train_completed * data_size (as opposed to data_size / train_completed) + max: 1.0 # leaves completion rates as is + value: 0.0 + + aggregated_model_validation: + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 20 # one day + + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : false + opt_treatment : CONTINUE_LOCAL + dynamictaskargs: *dynamictaskargs + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : src.nnunet_dummy_dataloader.NNUNetDummyDataLoader + settings : + p_train : 0.8 + +# TODO: make checkpoint-only truly generic and create the task runner within src +task_runner : + defaults : plan/defaults/task_runner.yaml + template : src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner + settings : + device : cuda + gpu_num_string : '0' + nnunet_task : Task537_FLPost + actual_max_num_epochs : *rounds_to_train + +network : + defaults : plan/defaults/network.yaml + settings: {} + +assigner 
: + defaults : plan/defaults/assigner.yaml + template : openfl.component.assigner.DynamicRandomGroupedAssigner + settings : + task_groups : + - name : train_and_validate + percentage : 1.0 + tasks : + - aggregated_model_validation + - train + - locally_tuned_model_validation + +tasks : + defaults : plan/defaults/tasks_torch.yaml + aggregated_model_validation: + function : validate + kwargs : + metrics : + - val_eval + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 + apply : global + train: + function : train + kwargs : + metrics : + - train_loss + epochs : 1 + locally_tuned_model_validation: + function : validate + kwargs : + metrics : + - val_eval + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 + apply : local + from_checkpoint: true + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml + +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 1200 + minimum_reporting : 2 diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile new file mode 100644 index 000000000..7a196cf41 --- /dev/null +++ b/examples/fl_post/fl/project/Dockerfile @@ -0,0 +1,26 @@ +FROM local/openfl:local + +ENV LANG C.UTF-8 +ENV CUDA_VISIBLE_DEVICES="0" + + +# install project dependencies +RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ + +RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 + +COPY ./requirements.txt /mlcube_project/requirements.txt +RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Create similar env with cuda118 +RUN apt-get update && apt-get install python3.10-venv -y +RUN python -m venv /cuda118 +RUN /cuda118/bin/pip install --no-cache-dir --upgrade pip setuptools && /cuda118/bin/pip install --no-cache-dir wheel +RUN /cuda118/bin/pip install --no-cache-dir /openfl +RUN /cuda118/bin/pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 +RUN /cuda118/bin/pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Copy mlcube project folder +COPY . /mlcube_project + +ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] \ No newline at end of file diff --git a/examples/fl_post/fl/project/README.md b/examples/fl_post/fl/project/README.md new file mode 100644 index 000000000..1e348651b --- /dev/null +++ b/examples/fl_post/fl/project/README.md @@ -0,0 +1,38 @@ +# How to configure container build for your application + +- List your pip requirements in `requirements.txt` +- List your software requirements in `Dockerfile` +- Modify the functions in `hooks.py` as needed. (Explanation TBD) + +# How to configure container for custom FL software + +- Change the base Docker image as needed. +- modify `aggregator.py` and `collaborator.py` as needed. Follow the implemented schema steps. + +# How to build + +- Build the openfl base image: + +```bash +git clone https://github.com/securefederatedai/openfl.git +cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. +rm -rf openfl +``` + +- Build the MLCube + +```bash +cd .. +bash build.sh +``` + +# NOTE + +for local experiments, internal IP address or localhost will not work. Use internal fqdn. 
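For reference, the setup scripts in this PR resolve that FQDN with `hostname`, and the same value ends up in both the aggregator certificate and `aggregator_config.yaml`:

```bash
HOSTNAME_=$(hostname -A | cut -d " " -f 1)
echo "address: $HOSTNAME_" >> mlcube_agg/workspace/aggregator_config.yaml
```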
+ +# How to customize + +TBD diff --git a/examples/fl_post/fl/project/aggregator.py b/examples/fl_post/fl/project/aggregator.py new file mode 100644 index 000000000..8a1e7f283 --- /dev/null +++ b/examples/fl_post/fl/project/aggregator.py @@ -0,0 +1,32 @@ +import os +import shutil +from subprocess import check_call +from distutils.dir_util import copy_tree + + +def start_aggregator(workspace_folder, output_logs, output_weights, report_path): + + check_call(["fx", "aggregator", "start"], cwd=workspace_folder) + + # TODO: check how to copy logs during runtime. + # perhaps investigate overriding plan entries? + + # NOTE: logs and weights are copied, even if target folders are not empty + if os.path.exists(os.path.join(workspace_folder, "logs")): + copy_tree(os.path.join(workspace_folder, "logs"), output_logs) + + # NOTE: conversion fails since openfl needs sample data... + # weights_paths = get_weights_path(fl_workspace) + # out_best = os.path.join(output_weights, "best") + # out_last = os.path.join(output_weights, "last") + # check_call( + # ["fx", "model", "save", "-i", weights_paths["best"], "-o", out_best], + # cwd=workspace_folder, + # ) + copy_tree(os.path.join(workspace_folder, "save"), output_weights) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) + + with open(report_path, "w") as f: + f.write("IsDone: 1") diff --git a/examples/fl_post/fl/project/collaborator.py b/examples/fl_post/fl/project/collaborator.py new file mode 100644 index 000000000..fb4cdd1c2 --- /dev/null +++ b/examples/fl_post/fl/project/collaborator.py @@ -0,0 +1,29 @@ +import os +from utils import get_collaborator_cn +import shutil +from subprocess import check_call + + +def start_collaborator(workspace_folder): + cn = get_collaborator_cn() + check_call( + [os.environ.get("OPENFL_EXECUTABLE", "fx"), "collaborator", "start", "-n", cn], + cwd=workspace_folder, + ) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) + + +def check_connectivity(workspace_folder): + cn = get_collaborator_cn() + check_call( + [ + os.environ.get("OPENFL_EXECUTABLE", "fx"), + "collaborator", + "connectivity_check", + "-n", + cn, + ], + cwd=workspace_folder, + ) diff --git a/examples/fl_post/fl/project/entrypoint.sh b/examples/fl_post/fl/project/entrypoint.sh new file mode 100644 index 000000000..14d0056da --- /dev/null +++ b/examples/fl_post/fl/project/entrypoint.sh @@ -0,0 +1,24 @@ +PYTHONSCRIPT="import torch; torch.tensor([1.0, 2.0, 3.0, 4.0]).to('cuda')" + +if [ "$1" = "start_aggregator" ] || [ "$1" = "generate_plan" ]; then + # no need for gpu, don't test cuda + python /mlcube_project/mlcube.py $@ +else + echo "Testing which cuda version to use" + python -c "$PYTHONSCRIPT" + if [ "$?" -ne "0" ]; then + echo "cuda 12 failed. Trying with cuda 11.8" + /cuda118/bin/python -c "$PYTHONSCRIPT" + if [ "$?" -ne "0" ]; then + echo "No suppored cuda version satisfies the machine driver. Exiting." + exit 1 + else + echo "cuda 11.8 seems to be working. Will use cuda 11.8" + export OPENFL_EXECUTABLE="/cuda118/bin/fx" + /cuda118/bin/python /mlcube_project/mlcube.py $@ + fi + else + echo "cuda 12 seems to be working. 
Will use cuda 12" + python /mlcube_project/mlcube.py $@ + fi +fi diff --git a/examples/fl_post/fl/project/hooks.py b/examples/fl_post/fl/project/hooks.py new file mode 100644 index 000000000..516853743 --- /dev/null +++ b/examples/fl_post/fl/project/hooks.py @@ -0,0 +1,90 @@ +import os +import shutil +from utils import get_collaborator_cn + + +def collaborator_pre_training_hook( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + init_nnunet_directory, + workspace_folder, +): + import nnunet_setup + + cn = get_collaborator_cn() + + os.symlink(data_path, f"{workspace_folder}/data", target_is_directory=True) + os.symlink(labels_path, f"{workspace_folder}/labels", target_is_directory=True) + + # this function returns metadata (model weights and config file) to be distributed out of band + # evan should use this without stuff to overwrite/sync so that it produces the correct metdata + # when evan runs, init_model_path, init_model_info_path should be None + # plans_path should also be None (the returned thing will point to where it lives so that it will be synced with others) + + nnunet_setup.main( + postopp_pardir=workspace_folder, + three_digit_task_num=537, # FIXME: does this need to be set in any particular way? + init_model_path=f"{init_nnunet_directory}/model_initial_checkpoint.model", + init_model_info_path=f"{init_nnunet_directory}/model_initial_checkpoint.model.pkl", + task_name="FLPost", + percent_train=0.8, + split_logic="by_subject_time_pair", + network="3d_fullres", + network_trainer="nnUNetTrainerV2", + fold="0", + plans_path=f"{init_nnunet_directory}/nnUNetPlans_pretrained_POSTOPP_plans_3D.pkl", # NOTE: IT IS NOT AN OPENFL PLAN + cuda_device="0", + verbose=False, + ) + + data_config = f"{cn},Task537_FLPost" + plan_folder = os.path.join(workspace_folder, "plan") + os.makedirs(plan_folder, exist_ok=True) + data_config_path = os.path.join(plan_folder, "data.yaml") + with open(data_config_path, "w") as f: + f.write(data_config) + shutil.copytree("/mlcube_project/src", os.path.join(workspace_folder, "src")) + + +def collaborator_post_training_hook( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + workspace_folder, +): + pass + + +def aggregator_pre_training_hook( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + pass + + +def aggregator_post_training_hook( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + pass diff --git a/examples/fl_post/fl/project/init_model.py b/examples/fl_post/fl/project/init_model.py new file mode 100644 index 000000000..c2106b505 --- /dev/null +++ b/examples/fl_post/fl/project/init_model.py @@ -0,0 +1,35 @@ +import os +import shutil + + +def train_initial_model( + data_path, labels_path, init_nnunet_directory, workspace_folder +): + import nnunet_setup + + os.symlink(data_path, f"{workspace_folder}/data", target_is_directory=True) + os.symlink(labels_path, f"{workspace_folder}/labels", target_is_directory=True) + + res = nnunet_setup.main( + postopp_pardir=workspace_folder, + three_digit_task_num=537, # FIXME: does this need to be set in any particular way? 
+ init_model_path=None, + init_model_info_path=None, + task_name="FLPost", + percent_train=0.8, + split_logic="by_subject_time_pair", + network="3d_fullres", + network_trainer="nnUNetTrainerV2", + fold="0", + plans_path=None, + cuda_device="0", + verbose=False, + ) + + initial_model_path = res["initial_model_path"] + initial_model_info_path = res["initial_model_info_path"] + plans_path = res["plans_path"] + + shutil.move(initial_model_path, init_nnunet_directory) + shutil.move(initial_model_info_path, init_nnunet_directory) + shutil.move(plans_path, init_nnunet_directory) diff --git a/examples/fl_post/fl/project/mlcube.py b/examples/fl_post/fl/project/mlcube.py new file mode 100644 index 000000000..14694df94 --- /dev/null +++ b/examples/fl_post/fl/project/mlcube.py @@ -0,0 +1,147 @@ +"""MLCube handler file""" + +import typer +from collaborator import start_collaborator, check_connectivity +from aggregator import start_aggregator +from plan import generate_plan +from hooks import ( + aggregator_pre_training_hook, + aggregator_post_training_hook, + collaborator_pre_training_hook, + collaborator_post_training_hook, +) +from utils import generic_setup, generic_teardown, setup_collaborator, setup_aggregator +from init_model import train_initial_model + +app = typer.Typer() + + +@app.command("train") +def train( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + output_logs: str = typer.Option(..., "--output_logs"), + init_nnunet_directory: str = typer.Option(..., "--init_nnunet_directory"), +): + workspace_folder = generic_setup(output_logs) + setup_collaborator( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + workspace_folder=workspace_folder, + ) + check_connectivity(workspace_folder) + collaborator_pre_training_hook( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + init_nnunet_directory=init_nnunet_directory, + workspace_folder=workspace_folder, + ) + start_collaborator(workspace_folder=workspace_folder) + collaborator_post_training_hook( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + workspace_folder=workspace_folder, + ) + generic_teardown(output_logs) + + +@app.command("start_aggregator") +def start_aggregator_( + input_weights: str = typer.Option(..., "--input_weights"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + output_logs: str = typer.Option(..., "--output_logs"), + output_weights: str = typer.Option(..., "--output_weights"), + plan_path: str = typer.Option(..., "--plan_path"), + collaborators: str = typer.Option(..., "--collaborators"), + report_path: str = typer.Option(..., "--report_path"), +): + workspace_folder = generic_setup(output_logs) + setup_aggregator( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + 
report_path=report_path, + workspace_folder=workspace_folder, + ) + aggregator_pre_training_hook( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + workspace_folder=workspace_folder, + ) + start_aggregator( + workspace_folder=workspace_folder, + output_logs=output_logs, + output_weights=output_weights, + report_path=report_path, + ) + aggregator_post_training_hook( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + workspace_folder=workspace_folder, + ) + generic_teardown(output_logs) + + +@app.command("generate_plan") +def generate_plan_( + training_config_path: str = typer.Option(..., "--training_config_path"), + aggregator_config_path: str = typer.Option(..., "--aggregator_config_path"), + plan_path: str = typer.Option(..., "--plan_path"), +): + # no _setup here since there is no writable output mounted volume. + # later if need this we think of a solution. Currently the create_plam + # logic is assumed to not write within the container. + generate_plan(training_config_path, aggregator_config_path, plan_path) + + +@app.command("train_initial_model") +def train_initial_model_( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + output_logs: str = typer.Option(..., "--output_logs"), + init_nnunet_directory: str = typer.Option(..., "--init_nnunet_directory"), +): + workspace_folder = generic_setup(output_logs) + train_initial_model( + data_path=data_path, + labels_path=labels_path, + init_nnunet_directory=init_nnunet_directory, + workspace_folder=workspace_folder, + ) + generic_teardown(output_logs) + + +if __name__ == "__main__": + app() diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py new file mode 100644 index 000000000..6db3663d6 --- /dev/null +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -0,0 +1,421 @@ + +import os +import subprocess +import pickle as pkl +import shutil +import numpy as np + +from collections import OrderedDict + +from nnunet.dataset_conversion.utils import generate_dataset_json + +from nnunet_model_setup import trim_data_and_setup_model + + +num_to_modality = {'_0000': '_brain_t1n.nii.gz', + '_0001': '_brain_t2w.nii.gz', + '_0002': '_brain_t2f.nii.gz', + '_0003': '_brain_t1c.nii.gz'} + +def get_subdirs(parent_directory): + subjects = os.listdir(parent_directory) + subjects = [p for p in subjects if os.path.isdir(os.path.join(parent_directory, p)) and not p.startswith(".")] + return subjects + + +def subject_time_to_mask_path(pardir, subject, timestamp): + mask_fname = f'{subject}_{timestamp}_tumorMask_model_0.nii.gz' + return os.path.join(pardir, 'labels', '.tumor_segmentation_backup', subject, timestamp,'TumorMasksForQC', mask_fname) + + +def create_task_folders(task_num, task_name, overwrite_nnunet_datadirs): + task = f'Task{str(task_num)}_{task_name}' + + # The NNUnet data path is obtained from an environmental variable + nnunet_dst_pardir = os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data', f'{task}') + + nnunet_images_train_pardir = os.path.join(nnunet_dst_pardir, 'imagesTr') + nnunet_labels_train_pardir = 
os.path.join(nnunet_dst_pardir, 'labelsTr') + + task_cropped_pardir = os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_cropped_data', f'{task}') + task_preprocessed_pardir = os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_preprocessed', f'{task}') + + if not overwrite_nnunet_datadirs: + if os.path.exists(nnunet_images_train_pardir) and os.path.exists(nnunet_labels_train_pardir): + raise ValueError(f"Train images pardirs: {nnunet_images_train_pardir} and {nnunet_labels_train_pardir} both already exist. Please move them both and rerun to prevent overwriting.") + elif os.path.exists(nnunet_images_train_pardir): + raise ValueError(f"Train images pardir: {nnunet_images_train_pardir} already exists, please move and run again to prevent overwriting.") + elif os.path.exists(nnunet_labels_train_pardir): + raise ValueError(f"Train labels pardir: {nnunet_labels_train_pardir} already exists, please move and run again to prevent overwriting.") + + if os.path.exists(task_cropped_pardir): + raise ValueError(f"Cropped data pardir: {task_cropped_pardir} already exists, please move and run again to prevent overwriting.") + if os.path.exists(task_preprocessed_pardir): + raise ValueError(f"Preprocessed data pardir: {task_preprocessed_pardir} already exists, please move and run again to prevent overwriting.") + else: + if os.path.exists(task_cropped_pardir): + shutil.rmtree(task_cropped_pardir) + if os.path.exists(task_preprocessed_pardir): + shutil.rmtree(task_preprocessed_pardir) + if os.path.exists(nnunet_images_train_pardir): + shutil.rmtree(nnunet_images_train_pardir) + if os.path.exists(nnunet_labels_train_pardir): + shutil.rmtree(nnunet_labels_train_pardir) + + + os.makedirs(nnunet_images_train_pardir, exist_ok=False) + os.makedirs(nnunet_labels_train_pardir, exist_ok=False) + + return task, nnunet_dst_pardir, nnunet_images_train_pardir, nnunet_labels_train_pardir + + +def symlink_one_subject(postopp_subject_dir, postopp_data_dirpath, postopp_labels_dirpath, nnunet_images_train_pardir, nnunet_labels_train_pardir, timestamp_selection, verbose=False): + if verbose: + print(f"\n#######\nsymlinking subject: {postopp_subject_dir}\n########\nPostopp_data_dirpath: {postopp_data_dirpath}\n\n\n\n") + postopp_subject_dirpath = os.path.join(postopp_data_dirpath, postopp_subject_dir) + all_timestamps = sorted(list(get_subdirs(postopp_subject_dirpath))) + if timestamp_selection == 'latest': + timestamps = all_timestamps[-1:] + elif timestamp_selection == 'earliest': + timestamps = all_timestamps[0:1] + elif timestamp_selection == 'all': + timestamps = all_timestamps + else: + raise ValueError(f"timestamp_selection currently only supports 'latest', 'earliest', and 'all', but you have requested: '{timestamp_selection}'") + + for timestamp in timestamps: + postopp_subject_timestamp_dirpath = os.path.join(postopp_subject_dirpath, timestamp) + postopp_subject_timestamp_label_dirpath = os.path.join(postopp_labels_dirpath, postopp_subject_dir, timestamp) + if not os.path.exists(postopp_subject_timestamp_label_dirpath): + raise ValueError(f"Subject label file for data at: {postopp_subject_timestamp_dirpath} was not found in the expected location: {postopp_subject_timestamp_label_dirpath}") + + timed_subject = postopp_subject_dir + '_' + timestamp + + # Symlink label first + label_src_path = os.path.join(postopp_subject_timestamp_label_dirpath, timed_subject + '_final_seg.nii.gz') + label_dst_path = os.path.join(nnunet_labels_train_pardir, timed_subject + '.nii.gz') + os.symlink(src=label_src_path, 
dst=label_dst_path) + + # Symlink images + for num in num_to_modality: + src_path = os.path.join(postopp_subject_timestamp_dirpath, timed_subject + num_to_modality[num]) + dst_path = os.path.join(nnunet_images_train_pardir,timed_subject + num + '.nii.gz') + os.symlink(src=src_path, dst=dst_path) + + return timestamps + + +def doublecheck_postopp_pardir(postopp_pardir, verbose=False): + if verbose: + print(f"Checking postopp_pardir: {postopp_pardir}") + postopp_subdirs = list(get_subdirs(postopp_pardir)) + if 'data' not in postopp_subdirs: + raise ValueError(f"'data' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") + if 'labels' not in postopp_subdirs: + raise ValueError(f"'labels' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") + + +def split_by_subject(subject_to_timestamps, percent_train, split_seed, verbose=False): + """ + NOTE: An attempt is made to put percent_train of the total subjects into train (as opposed to val) regardless of how many timestamps there are for each subject. + No subject is allowed to have samples in both train and val. + """ + + subjects = list(subject_to_timestamps.keys()) + # create a random number generator with our seed + rng = np.random.default_rng(split_seed) + rng.shuffle(subjects) + + train_cutoff = int(len(subjects) * percent_train) + + train_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in subjects[:train_cutoff] } + val_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in subjects[train_cutoff:]} + + return train_subject_to_timestamps, val_subject_to_timestamps + + +def split_by_timed_subjects(subject_to_timestamps, percent_train, split_seed, random_tries=30, verbose=False): + """ + NOTE: An attempt is made to put percent_train of the subject timestamp combinations into train (as opposed to val) regardless of what that does to the subject ratios. + No subject is allowed to have samples in both train and val. 
+ """ + def percent_train_for_split(train_subjects, grand_total): + sub_total = 0 + for subject in train_subjects: + sub_total += subject_counts[subject] + return sub_total/grand_total + + def shuffle_and_cut(subject_counts, grand_total, percent_train, seed, verbose=False): + subjects = list(subject_counts.keys()) + # create a random number generator with our seed + rng = np.random.default_rng(seed) + rng.shuffle(subjects) + for idx in range(2,len(subjects)+1): + train_subjects = subjects[:idx-1] + val_subjects = subjects[idx-1:] + percent_train_estimate = percent_train_for_split(train_subjects=train_subjects, grand_total=grand_total) + if percent_train_estimate >= percent_train: + """ + if verbose: + print(f"SPLIT COMPUTE - Found one split with percent_train of: {percent_train_estimate}") + """ + break + return train_subjects, val_subjects, percent_train_estimate + # above should return by end of loop as percent_train_estimate should be strictly increasing with final value 1.0 + + + subject_counts = {subject: len(subject_to_timestamps[subject]) for subject in subject_to_timestamps} + subjects_copy = list(subject_counts.keys()).copy() + grand_total = 0 + for subject in subject_counts: + grand_total += subject_counts[subject] + + # create a valid split of counts for comparison + best_train_subjects = subjects_copy[:1] + best_val_subjects = subjects_copy[1:] + best_percent_train = percent_train_for_split(train_subjects=best_train_subjects, grand_total=grand_total) + + # random shuffle times in order to find the closest we can get to honoring the percent_train requirement (train and val both need to be non-empty) + for _try in range(random_tries): + seed = split_seed + _try + train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, seed=seed, verbose=verbose) + if abs(percent_train_estimate - percent_train) < abs(best_percent_train - percent_train): + best_train_subjects = train_subjects + best_val_subjects = val_subjects + best_percent_train = percent_train_estimate + if verbose: + print(f"\n#########\n Split was performed by timed subject and an error of {abs(best_percent_train - percent_train)} was acheived in the percent train target.") + train_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in best_train_subjects} + val_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in best_val_subjects} + return train_subject_to_timestamps, val_subject_to_timestamps + + +def write_splits_file(subject_to_timestamps, percent_train, split_logic, split_seed, fold, task, splits_fname='splits_final.pkl', verbose=False): + # double check we are in the right folder to modify the splits file + splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', splits_fname) + POSTOPP_splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', 'POSTOPP_BACKUP_' + splits_fname) + + # now split + if split_logic == 'by_subject': + train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose) + elif split_logic == 'by_subject_time_pair': + train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose) + else: + raise ValueError(f"Split logic of 'by_subject' and 'by_subject_time_pair' are 
the only ones supported, whereas a split_logic value of {split_logic} was provided.") + + # Now construct the list of subjects + train_subjects_list = [] + val_subjects_list = [] + for subject in train_subject_to_timestamps: + for timestamp in train_subject_to_timestamps[subject]: + train_subjects_list.append(subject + '_' + timestamp) + for subject in val_subject_to_timestamps: + for timestamp in val_subject_to_timestamps[subject]: + val_subjects_list.append(subject + '_' + timestamp) + + # Now write the splits file (note None is put into the folds that we don't use as a safety measure so that no unintended folds are used) + new_folds = [None, None, None, None, None] + new_folds[int(fold)] = OrderedDict({'train': np.array(train_subjects_list), 'val': np.array(val_subjects_list)}) + + with open(splits_fpath, 'wb') as f: + pkl.dump(new_folds, f) + + # Making an extra copy to test that things are not overwriten later + with open(POSTOPP_splits_fpath, 'wb') as f: + pkl.dump(new_folds, f) + + +def setup_fl_data(postopp_pardir, + three_digit_task_num, + task_name, + percent_train, + split_logic, + fold, + timestamp_selection, + network, + network_trainer, + local_plans_identifier, + shared_plans_identifier, + init_model_path, + init_model_info_path, + cuda_device, + overwrite_nnunet_datadirs, + split_seed=7777777, + plans_path=None, + verbose=False): + """ + Generates symlinks to be used for NNUnet training, assuming we already have a + dataset on file coming from MLCommons RANO experiment data prep. + + Also creates the json file for the data, as well as runs nnunet preprocessing. + + should be run using a virtual environment that has nnunet version 1 installed. + + args: + postopp_pardir(str) : Parent directory for postopp data. + This directory should have 'data' and 'labels' subdirectories, with structure: + ├── data + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ ├── AAAC_0_2008.03.30_brain_t1c.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t1n.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t2f.nii.gz + │ │ │ └── AAAC_0_2008.03.30_brain_t2w.nii.gz + │ │ └── 2008.12.17 + │ │ ├── AAAC_0_2008.12.17_brain_t1c.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t1n.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t2f.nii.gz + │ │ └── AAAC_0_2008.12.17_brain_t2w.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1c.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1n.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t2f.nii.gz + │ │ │ └── AAAC_1_2008.03.30_duplicate_brain_t2w.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1c.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1n.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t2f.nii.gz + │ │ └── AAAC_1_2008.12.17_duplicate_brain_t2w.nii.gz + │ ├── AAAC_extra + │ │ └── 2008.12.10 + │ │ ├── AAAC_extra_2008.12.10_brain_t1c.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t1n.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t2f.nii.gz + │ │ └── AAAC_extra_2008.12.10_brain_t2w.nii.gz + │ ├── data.csv + │ └── splits.csv + ├── labels + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ └── AAAC_0_2008.03.30_final_seg.nii.gz + │ │ └── 2008.12.17 + │ │ └── AAAC_0_2008.12.17_final_seg.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ └── AAAC_1_2008.03.30_duplicate_final_seg.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ └── AAAC_1_2008.12.17_duplicate_final_seg.nii.gz + │ └── AAAC_extra + │ └── 2008.12.10 + │ └── AAAC_extra_2008.12.10_final_seg.nii.gz + └── report.yaml + + three_digit_task_num(str): 
Should start with '5'. + task_name(str) : Any string task name. + network(str) : Which network is being used for NNUnet + network_trainer(str) : Which network trainer class is being used for NNUnet + local_plans_identifier(str) : Used in the plans file name for a collaborator that will be performing local training to produce an initial model + shared_plans_identifier(str) : Used in the plans file name for creation and dissemination of the shared plan to be used in the federation + init_model_path(str) : Path to the initial model + init_model_info_path(str) : Path to the initial model info (pkl) file + cuda_device(str) : Device to perform training ('cpu' or 'cuda') + overwrite_nnunet_datadirs(bool) : Allows for overwriting past instances of NNUnet data directories using the task numbers from first_three_digit_task_num to that plus one less than number of insitutions. + split_seed (int) : Base seed for seeds used for the random number generators within the split logic + plans_path(str) : Path to the training plans (pkl) + percent_train(float) : What percentage of timestamped subjects to attempt dedicate to train versus val. Will be only approximately acheived in general since + all timestamps associated with the same subject need to land exclusively in either train or val. + split_logic(str) : Determines how the percent_train is computed. Choices are: 'by_subject' and 'by_subject_time_pair'. + fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' + timestamp_selection(str) : Determines which timestamps are used for each subject. Can be 'earliest', 'latest', or 'all' + verbose(bool) : Debugging output if True. + + Returns: + task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs + """ + + task, nnunet_dst_pardir, nnunet_images_train_pardir, nnunet_labels_train_pardir = \ + create_task_folders(task_num=three_digit_task_num, task_name=task_name, overwrite_nnunet_datadirs=overwrite_nnunet_datadirs) + + doublecheck_postopp_pardir(postopp_pardir, verbose=verbose) + postopp_data_dirpath = os.path.join(postopp_pardir, 'data') + postopp_labels_dirpath = os.path.join(postopp_pardir, 'labels') + + all_subjects = list(get_subdirs(postopp_data_dirpath)) + + # Track the subjects and timestamps for each shard + subject_to_timestamps = {} + + for postopp_subject_dir in all_subjects: + subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, + postopp_data_dirpath=postopp_data_dirpath, + postopp_labels_dirpath=postopp_labels_dirpath, + nnunet_images_train_pardir=nnunet_images_train_pardir, + nnunet_labels_train_pardir=nnunet_labels_train_pardir, + timestamp_selection=timestamp_selection, + verbose=verbose) + + # Generate json file for the dataset + print(f"\n######### GENERATING DATA JSON FILE #########\n") + json_path = os.path.join(nnunet_dst_pardir, 'dataset.json') + labels = {0: 'Background', 1: 'Necrosis', 2: 'Edema', 3: 'Enhancing Tumor', 4: 'Cavity'} + generate_dataset_json(output_file=json_path, imagesTr_dir=nnunet_images_train_pardir, imagesTs_dir=None, modalities=tuple(num_to_modality.keys()), + labels=labels, dataset_name='RANO Postopp') + + # Now call the os process to preprocess the data + print(f"\n######### OS CALL TO PREPROCESS DATA #########\n") + if plans_path: + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl2d", "None", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", 
"ExperimentPlanner3D_v21_Pretrained", "-pl2d", "None", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) + plans_identifier_for_model_writing = shared_plans_identifier + else: + # this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity", "-pl2d", "None"]) + plans_identifier_for_model_writing = local_plans_identifier + + # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val + write_splits_file(subject_to_timestamps=subject_to_timestamps, + percent_train=percent_train, + split_logic=split_logic, + split_seed=split_seed, + fold=fold, + task=task, + verbose=verbose) + + # trim 2d data if not working with 2d model, then train an initial model if needed (initial_model_path is None) or write in provided model otherwise + col_paths = {} + col_paths['initial_model_path'], \ + col_paths['final_model_path'], \ + col_paths['initial_model_info_path'], \ + col_paths['final_model_info_path'], \ + col_paths['plans_path'] = trim_data_and_setup_model(task=task, + network=network, + network_trainer=network_trainer, + plans_identifier=plans_identifier_for_model_writing, + fold=fold, + init_model_path=init_model_path, + init_model_info_path=init_model_info_path, + plans_path=plans_path, + cuda_device=cuda_device) + + if not plans_path: + # In this case we have created an initial model with this data, so running preprocesssing again in order to create a 'pretrained' plan similar to what other collaborators will create with our initial plan + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{col_paths['plans_path']}", "-overwrite_plans_identifier", "POSTOPP", "--verify_dataset_integrity", "-no_pp"]) + # Now coying the collaborator paths above to a new location that uses the pretrained planner that will be shared across federation + new_col_paths = {} + new_col_paths['initial_model_path'], \ + new_col_paths['final_model_path'], \ + new_col_paths['initial_model_info_path'], \ + new_col_paths['final_model_info_path'], \ + new_col_paths['plans_path'] = trim_data_and_setup_model(task=task, + network=network, + network_trainer=network_trainer, + plans_identifier=shared_plans_identifier, + fold=fold, + init_model_path=col_paths['initial_model_path'], + init_model_info_path=col_paths['initial_model_info_path'], + plans_path=col_paths['plans_path'], + cuda_device=cuda_device) + + col_paths = new_col_paths + + print(f"\n### ### ### ### ### ### ###\n") + print(f"A MODEL HAS TRAINED. 
HERE ARE PATHS WHERE FILES CAN BE OBTAINED:\n") + print(f"initial_model_path: {col_paths['initial_model_path']}") + print(f"initial_model_info_path: {col_paths['initial_model_info_path']}") + print(f"final_model_path: {col_paths['final_model_path']}") + print(f"final_model_info_path: {col_paths['final_model_info_path']}") + print(f"plans_path: {col_paths['plans_path']}") + print(f"\n### ### ### ### ### ### ###\n") + + return col_paths diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py new file mode 100644 index 000000000..a647a2f44 --- /dev/null +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -0,0 +1,105 @@ +import os +import pickle as pkl +import shutil + +from src.nnunet_v1 import train_nnunet +from nnunet.paths import default_plans_identifier + +def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0): + os.environ['CUDA_VISIBLE_DEVICES']=cuda_device + print(f"###########\nStarting training a single epoch for task: {task}\n") + # Function below is now hard coded for a single epoch of training. + train_nnunet(actual_max_num_epochs=1000, + fl_round=current_epoch, + network=network, + task=task, + network_trainer=network_trainer, + fold=fold, + continue_training=continue_training, + p=plans_identifier) + + +def get_model_folder(network, task, network_trainer, plans_identifier, fold, results_folder=os.environ['RESULTS_FOLDER']): + return os.path.join(results_folder, 'nnUNet',network, task, network_trainer + '__' + plans_identifier, f'fold_{fold}') + + +def get_col_model_paths(model_folder): + return {'initial_model_path': os.path.join(model_folder, 'model_initial_checkpoint.model'), + 'final_model_path': os.path.join(model_folder, 'model_final_checkpoint.model'), + 'initial_model_info_path': os.path.join(model_folder, 'model_initial_checkpoint.model.pkl'), + 'final_model_info_path': os.path.join(model_folder, 'model_final_checkpoint.model.pkl')} + + +def get_col_plans_path(network, task, plans_identifier): + # returning a dictionary in ordre to incorporate it more easily into another paths dict + preprocessed_path = os.environ['nnUNet_preprocessed'] + plans_write_dirpath = os.path.join(preprocessed_path, task) + plans_write_path_2d = os.path.join(plans_write_dirpath, plans_identifier + "_plans_2D.pkl") + plans_write_path_3d = os.path.join(plans_write_dirpath, plans_identifier + "_plans_3D.pkl") + + if network =='2d': + plans_write_path = plans_write_path_2d + else: + plans_write_path = plans_write_path_3d + + return {'plans_path': plans_write_path} + +def delete_2d_data(network, task, plans_identifier): + if network == '2d': + raise ValueError(f"2D data should not be deleted when performing 2d training.") + else: + preprocessed_path = os.environ['nnUNet_preprocessed'] + plan_dirpath = os.path.join(preprocessed_path, task) + plan_path_2d = os.path.join(plan_dirpath, "nnUNetPlansv2.1_plans_2D.pkl") + + if os.path.exists(plan_dirpath): + # load 2d plan to help construct 2D data directory + with open(plan_path_2d, 'rb') as _file: + plan_2d = pkl.load(_file) + data_dir_2d = os.path.join(plan_dirpath, plan_2d['data_identifier'] + '_stage' + str(list(plan_2d['plans_per_stage'].keys())[-1])) + if os.path.exists(data_dir_2d): + print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") + shutil.rmtree(data_dir_2d) + + + +def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, fold, init_model_path, 
init_model_info_path, plans_path, cuda_device='0'): + """ + Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one + """ + + # get or create architecture info + + model_folder = get_model_folder(network=network, + task=task, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold) + if not os.path.exists(model_folder): + os.makedirs(model_folder, exist_ok=False) + + col_paths = get_col_model_paths(model_folder=get_model_folder(network=network, + task=task, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold)) + col_paths.update(get_col_plans_path(network=network, task=task, plans_identifier=plans_identifier)) + + if not init_model_path: + if plans_path: + raise ValueError(f"If the initial model is not provided then we do not expect the plans_path to be provided either (plans file and initial model are sourced the same way).") + # train for a single epoch to get an initial model (this uses the default plans identifier) + train_on_task(task=task, network=network, network_trainer=network_trainer, fold=fold, cuda_device=cuda_device, plans_identifier=default_plans_identifier) + # now copy the trained final model and info into the initial paths + shutil.copyfile(src=col_paths['final_model_path'],dst=col_paths['initial_model_path']) + shutil.copyfile(src=col_paths['final_model_info_path'],dst=col_paths['initial_model_info_path']) + else: + print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\n") + shutil.copy(src=plans_path,dst=col_paths['plans_path']) + shutil.copyfile(src=init_model_path,dst=col_paths['initial_model_path']) + shutil.copyfile(src=init_model_info_path,dst=col_paths['initial_model_info_path']) + # now copy these files also into the final paths + shutil.copyfile(src=col_paths['initial_model_path'],dst=col_paths['final_model_path']) + shutil.copyfile(src=col_paths['initial_model_info_path'],dst=col_paths['final_model_info_path']) + + return col_paths['initial_model_path'], col_paths['final_model_path'], col_paths['initial_model_info_path'], col_paths['final_model_info_path'], col_paths['plans_path'] \ No newline at end of file diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py new file mode 100644 index 000000000..e9b23f814 --- /dev/null +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -0,0 +1,236 @@ +import argparse + +# We will be syncing training across many nodes who independently preprocess data +# In order to do this we will need to sync the training plans (defining the model architecture etc.) 
+# NNUnet does this by overwriting the plans file which includes a unique alternative plans identifier other than the default one + +from nnunet.paths import default_plans_identifier + +from nnunet_data_setup import setup_fl_data + +local_plans_identifier = default_plans_identifier +shared_plans_identifier = 'nnUNetPlans_pretrained_POSTOPP' + + +def list_of_strings(arg): + return arg.split(',') + +def main(postopp_pardir, + three_digit_task_num, + task_name, + percent_train=0.8, + split_logic='by_subject_time_pair', + split_seed=7777777, + network='3d_fullres', + network_trainer='nnUNetTrainerV2', + fold='0', + init_model_path=None, + init_model_info_path=None, + plans_path=None, + local_plans_identifier=local_plans_identifier, + shared_plans_identifier=shared_plans_identifier, + overwrite_nnunet_datadirs=True, + timestamp_selection='all', + cuda_device='0', + verbose=False): + """ + Generates symlinks to be used for NNUnet training, assuming we already have a + dataset on file coming from MLCommons RANO experiment data prep. + + Also creates the json file for the data, as well as runs nnunet preprocessing. + + should be run using a virtual environment that has nnunet version 1 installed. + + args: + postopp_pardir(str) : Parent directory for postopp data, which should contain 'data' and 'labels' subdirectories with structure: + ├── data + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ ├── AAAC_0_2008.03.30_brain_t1c.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t1n.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t2f.nii.gz + │ │ │ └── AAAC_0_2008.03.30_brain_t2w.nii.gz + │ │ └── 2008.12.17 + │ │ ├── AAAC_0_2008.12.17_brain_t1c.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t1n.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t2f.nii.gz + │ │ └── AAAC_0_2008.12.17_brain_t2w.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1c.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1n.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t2f.nii.gz + │ │ │ └── AAAC_1_2008.03.30_duplicate_brain_t2w.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1c.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1n.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t2f.nii.gz + │ │ └── AAAC_1_2008.12.17_duplicate_brain_t2w.nii.gz + │ ├── AAAC_extra + │ │ └── 2008.12.10 + │ │ ├── AAAC_extra_2008.12.10_brain_t1c.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t1n.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t2f.nii.gz + │ │ └── AAAC_extra_2008.12.10_brain_t2w.nii.gz + │ ├── data.csv + │ └── splits.csv + ├── labels + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ └── AAAC_0_2008.03.30_final_seg.nii.gz + │ │ └── 2008.12.17 + │ │ └── AAAC_0_2008.12.17_final_seg.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ └── AAAC_1_2008.03.30_duplicate_final_seg.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ └── AAAC_1_2008.12.17_duplicate_final_seg.nii.gz + │ └── AAAC_extra + │ └── 2008.12.10 + │ └── AAAC_extra_2008.12.10_final_seg.nii.gz + └── report.yaml + + three_digit_task_num(str) : Should start with '5' and not collide with other NNUnet task nums on your system. + init_model_path (str) : path to initial (pretrained) model file [default None] - must be provided if init_model_info_path is. + [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.] + init_model_info_path(str) : path to initial (pretrained) model info pikle file [default None]- must be provided if init_model_path is. 
+ [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.] + plans_path(str) : Path the the NNUnet plan file + [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.] + task_name(str) : Name of task that is part of the task name + percent_train(float) : The percentage of samples to split into the train portion for the fold specified below (NNUnet makes its own folds but we overwrite + all with None except the fold indicated below and put in our own split instead determined by a hard coded split logic default) + split_logic(str) : Determines how the percent_train is computed. Choices are: 'by_subject' and 'by_subject_time_pair' (see inner function docstring) + split_seed(int) : base rng seed used in split logic + network(str) : NNUnet network to be used + network_trainer(str) : NNUnet network trainer to be used + fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' + local_plans_identifier(str) : Used in the plans file naming for collaborators that will be performing local training to produce a pretrained model. + shared_plans_identifier(str) : Used in the plans file naming for the shared plan distributed across the federation. + overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories for given task number and name. + task_name(str) : Any string task name. + timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported. + for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far + verbose(bool) : If True, print debugging information. + """ + + # some argument inspection + if str(three_digit_task_num)[0] != '5': + raise ValueError(f"The three digit task number: {three_digit_task_num} should start with 5 to avoid NNUnet repository tasks, but it starts with {three_digit_task_num[0]}") + if init_model_path or init_model_info_path: + if not init_model_path or not init_model_info_path: + raise ValueError(f"If either init_model_path or init_model_info_path are provided, they both must be.") + if init_model_path: + if not init_model_path.endswith('.model'): + raise ValueError(f"Initial model file should end with, '.model'") + if not init_model_info_path.endswith('.model.pkl'): + raise ValueError(f"Initial model info file should end with, 'model.pkl'") + + # task_folder_info is a zipped lists indexed over tasks (collaborators) + # zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs) + + col_paths = setup_fl_data(postopp_pardir=postopp_pardir, + three_digit_task_num=three_digit_task_num, + task_name=task_name, + percent_train=percent_train, + split_logic=split_logic, + split_seed=split_seed, + fold=fold, + timestamp_selection=timestamp_selection, + network=network, + network_trainer=network_trainer, + local_plans_identifier=local_plans_identifier, + shared_plans_identifier=shared_plans_identifier, + init_model_path=init_model_path, + init_model_info_path=init_model_info_path, + plans_path=plans_path, + cuda_device=cuda_device, + overwrite_nnunet_datadirs=overwrite_nnunet_datadirs, + verbose=verbose) + + return col_paths + +if __name__ == '__main__': + + argparser = argparse.ArgumentParser() + argparser.add_argument( + '--postopp_pardir', + type=str, + help="Parent directory to postopp data.") + argparser.add_argument( + '--three_digit_task_num', + type=int, + help="Should start with '5'. 
If fedsim == N, all N task numbers starting with this number will be used.") + argparser.add_argument( + '--init_model_path', + type=str, + default=None, + help="Path to initial (pretrained) model file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.].") + argparser.add_argument( + '--init_model_info_path', + type=str, + default=None, + help="Path to initial (pretrained) model info file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.].") + argparser.add_argument( + '--plans_path', + type=str, + default=None, + help="Path to the training plan file[ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.].") + argparser.add_argument( + '--task_name', + type=str, + help="Part of the NNUnet data task directory name. With 'first_three_digit_task_num being 'XXX', the directory name becomes: .../nnUNet_raw_data_base/nnUNet_raw_data/TaskXXX_.") + argparser.add_argument( + '--percent_train', + type=float, + default=0.8, + help="The percentage of samples to split into the train portion for the fold specified below (NNUnet makes its own folds but we overwrite) - see docstring in main") + argparser.add_argument( + '--split_logic', + type=str, + default='by_subject_time_pair', + help="Determines how the percent_train is computed. Choices are: 'by_subject' and 'by_subject_time_pair' (see inner function docstring)") + argparser.add_argument( + '--split_seed', + type=int, + default=7777777, + help="base rng seed used in split logic") + argparser.add_argument( + '--network', + type=str, + default='3d_fullres', + help="NNUnet network to be used.") + argparser.add_argument( + '--network_trainer', + type=str, + default='nnUNetTrainerV2', + help="NNUnet network trainer to be used.") + argparser.add_argument( + '--fold', + type=str, + default='0', + help="Fold to train on, can be a sting indicating an int, or can be 'all'.") + argparser.add_argument( + '--timestamp_selection', + type=str, + default='all', + help="Indicates how to determine the timestamp to pick for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far.") + argparser.add_argument( + '--cuda_device', + type=str, + default='0', + help="Used for the setting of os.environ['CUDA_VISIBLE_DEVICES']") + argparser.add_argument( + '--overwrite_nnunet_datadirs', + action='store_true', + help="Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one les than number of insitutions.") + argparser.add_argument( + '--verbose', + action='store_true', + help="Print debuging information.") + + args = argparser.parse_args() + + kwargs = vars(args) + + main(**kwargs) diff --git a/examples/fl_post/fl/project/plan.py b/examples/fl_post/fl/project/plan.py new file mode 100644 index 000000000..2feb1bf52 --- /dev/null +++ b/examples/fl_post/fl/project/plan.py @@ -0,0 +1,16 @@ +import yaml + + +def generate_plan(training_config_path, aggregator_config_path, plan_path): + with open(training_config_path) as f: + training_config = yaml.safe_load(f) + with open(aggregator_config_path) as f: + aggregator_config = yaml.safe_load(f) + + # TODO: key checks. Also, define what should be considered aggregator_config + # (e.g., tls=true, reconnect_interval, ...) 
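+    # Currently only two keys are read from aggregator_config: "address" and "port";
+    # anything else in that file is ignored. They are written into the generated plan as
+    # network.settings.agg_addr / network.settings.agg_port, which is how collaborators
+    # are pointed at the aggregator.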
+ training_config["network"]["settings"]["agg_addr"] = aggregator_config["address"] + training_config["network"]["settings"]["agg_port"] = aggregator_config["port"] + + with open(plan_path, "w") as f: + yaml.dump(training_config, f) diff --git a/examples/fl_post/fl/project/requirements.txt b/examples/fl_post/fl/project/requirements.txt new file mode 100644 index 000000000..2533d25c4 --- /dev/null +++ b/examples/fl_post/fl/project/requirements.txt @@ -0,0 +1,4 @@ +onnx==1.13.0 +typer==0.9.0 +git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@328e7669cc371a7603835bbb42fec6e12e62f092#egg=nnunet +numpy==1.26.4 diff --git a/examples/fl_post/fl/project/src/__init__.py b/examples/fl_post/fl/project/src/__init__.py new file mode 100644 index 000000000..f1410b129 --- /dev/null +++ b/examples/fl_post/fl/project/src/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""You may copy this file as the starting point of your own model.""" diff --git a/examples/fl_post/fl/project/src/launch_collaborator_with_env_vars.sh b/examples/fl_post/fl/project/src/launch_collaborator_with_env_vars.sh new file mode 100644 index 000000000..b85de2bb2 --- /dev/null +++ b/examples/fl_post/fl/project/src/launch_collaborator_with_env_vars.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=$1 fx collaborator start -n $2 \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py new file mode 100644 index 000000000..68cbbbc40 --- /dev/null +++ b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py @@ -0,0 +1,36 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Micah Sheller, Brandon Edwards + +""" + +"""You may copy this file as the starting point of your own model.""" + +import json +import os + +class NNUNetDummyDataLoader(): + def __init__(self, data_path, p_train): + self.task_name = data_path + data_base_path = os.path.join(os.environ['nnUNet_preprocessed'], self.task_name) + with open(f'{data_base_path}/dataset.json', 'r') as f: + data_config = json.load(f) + data_size = data_config['numTraining'] + + # TODO: determine how nnunet validation splits round + self.train_data_size = int(p_train * data_size) + self.valid_data_size = data_size - self.train_data_size + + def get_feature_shape(self): + return [1,1,1] + + def get_train_data_size(self): + return self.train_data_size + + def get_valid_data_size(self): + return self.valid_data_size + + def get_task_name(self): + return self.task_name \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py new file mode 100644 index 000000000..b162eb1e0 --- /dev/null +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -0,0 +1,294 @@ + + + + +# The following was copied and modified from the source: +# https://github.com/kaapana/kaapana/blob/26d71920d53c3110e2494cbb2ddb0cbb996b880a/data-processing/base-images/base-nnunet/files/patched/run_training.py#L213 + + +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import numpy as np +import torch +import random +from batchgenerators.utilities.file_and_folder_operations import * +from nnunet.run.default_configuration import get_default_configuration +from nnunet.run.load_pretrained_weights import load_pretrained_weights +from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage +from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer +from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import ( + nnUNetTrainerCascadeFullRes, +) +from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import ( + nnUNetTrainerV2CascadeFullRes, +) +from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name + + +# We will be syncing training across many nodes who independently preprocess data +# In order to do this we will need to sync the training plans (defining the model architecture etc.) +# NNUnet does this by overwriting the plans file which includes a unique alternative plans identifier other than the default one +plans_param = 'nnUNetPlans_pretrained_POSTOPP' +#from nnunet.paths import default_plans_identifier + +def seed_everything(seed=1234): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def train_nnunet(actual_max_num_epochs, + fl_round, + val_epoch=True, + train_epoch=True, + train_cutoff=np.inf, + val_cutoff=np.inf, + network='3d_fullres', + network_trainer='nnUNetTrainerV2', + task='Task543_FakePostOpp_More', + fold='0', + continue_training=True, + c=False, + p=plans_param, + use_compressed_data=False, + deterministic=False, + npz=False, + find_lr=False, + valbest=False, + fp32=False, + val_folder='validation_raw', + disable_saving=False, + disable_postprocessing_on_folds=True, + val_disable_overwrite=True, + disable_next_stage_pred=False, + pretrained_weights=None): + + """ + actual_max_num_epochs (int): Provides the number of epochs intended to be trained over the course of the whole federation (for lr scheduling) + (this needs to be held constant outside of individual calls to this function so that the lr is consistetly scheduled) + fl_round (int): Federated round, equal to the epoch used for the model (in lr scheduling) + val_epoch (bool) : Will validation be performed + train_epoch (bool) : Will training run (rather than val only) + task (int): can be task name or task id + fold: "0, 1, ..., 5 or 'all'" + c: use this if you want to continue a training + p: plans identifier. Only change this if you created a custom experiment planner + use_compressed_data: "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data " + "is much more CPU and RAM intensive and should only be used if you know what you are " + "doing" + deterministic: "Makes training deterministic, but reduces training speed substantially. I (Fabian) think " + "this is not necessary. Deterministic training will make you overfit to some random seed. " + "Don't use that." 
+ npz: "if set then nnUNet will " + "export npz files of " + "predicted segmentations " + "in the validation as well. " + "This is needed to run the " + "ensembling step so unless " + "you are developing nnUNet " + "you should enable this" + find_lr: not used here, just for fun + valbest: hands off. This is not intended to be used + fp32: disable mixed precision training and run old school fp32 + val_folder: name of the validation folder. No need to use this for most people + disable_saving: If set nnU-Net will not save any parameter files (except a temporary checkpoint that " + "will be removed at the end of the training). Useful for development when you are " + "only interested in the results and want to save some disk space + disable_postprocessing_on_folds: Running postprocessing on each fold only makes sense when developing with nnU-Net and " + "closely observing the model performance on specific configurations. You do not need it " + "when applying nnU-Net because the postprocessing for this will be determined only once " + "all five folds have been trained and nnUNet_find_best_configuration is called. Usually " + "running postprocessing on each fold is computationally cheap, but some users have " + "reported issues with very large images. If your images are large (>600x600x600 voxels) " + "you should consider setting this flag. + val_disable_overwrite: If True, validation does not overwrite existing segmentations + pretrained_wieghts: path to nnU-Net checkpoint file to be used as pretrained model (use .model " + "file, for example model_final_checkpoint.model). Will only be used when actually training. " + "Optional. Beta. Use with caution." + disable_next_stage_pred: If True, do not predict next stage + """ + + class Arguments(): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + args = Arguments(**locals()) + + if args.deterministic: + seed_everything() + + task = args.task + fold = args.fold + network = args.network + network_trainer = args.network_trainer + plans_identifier = args.p + find_lr = args.find_lr + disable_postprocessing_on_folds = args.disable_postprocessing_on_folds + + use_compressed_data = args.use_compressed_data + decompress_data = not use_compressed_data + + deterministic = args.deterministic + valbest = args.valbest + + fp32 = args.fp32 + run_mixed_precision = not fp32 + + val_folder = args.val_folder + # interp_order = args.interp_order + # interp_order_z = args.interp_order_z + # force_separate_z = args.force_separate_z + + if not task.startswith("Task"): + task_id = int(task) + task = convert_id_to_task_name(task_id) + + if fold == "all": + pass + else: + fold = int(fold) + + # if force_separate_z == "None": + # force_separate_z = None + # elif force_separate_z == "False": + # force_separate_z = False + # elif force_separate_z == "True": + # force_separate_z = True + # else: + # raise ValueError("force_separate_z must be None, True or False. 
Given: %s" % force_separate_z) + ( + plans_file, + output_folder_name, + dataset_directory, + batch_dice, + stage, + trainer_class, + ) = get_default_configuration(network, task, network_trainer, plans_identifier) + + if trainer_class is None: + raise RuntimeError( + "Could not find trainer class in nnunet.training.network_training" + ) + + if network == "3d_cascade_fullres": + assert issubclass( + trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes) + ), ( + "If running 3d_cascade_fullres then your " + "trainer class must be derived from " + "nnUNetTrainerCascadeFullRes" + ) + else: + assert issubclass( + trainer_class, nnUNetTrainer + ), "network_trainer was found but is not derived from nnUNetTrainer" + + trainer = trainer_class( + plans_file, + fold, + actual_max_num_epochs=actual_max_num_epochs, + output_folder=output_folder_name, + dataset_directory=dataset_directory, + batch_dice=batch_dice, + stage=stage, + unpack_data=decompress_data, + deterministic=deterministic, + fp16=run_mixed_precision, + ) + + + trainer.initialize(True) + + if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": + trainer.save_checkpoint( + join(trainer.output_folder, "model_final_checkpoint.model") + ) + print("Preparation round: Model-averaging") + return + + if find_lr: + trainer.find_lr(num_iters=self.actual_max_num_epochs) + else: + if args.continue_training: + # -c was set, continue a previous training and ignore pretrained weights + trainer.load_latest_checkpoint() + elif (not args.continue_training) and (args.pretrained_weights is not None): + # we start a new training. If pretrained_weights are set, use them + load_pretrained_weights(trainer.network, args.pretrained_weights) + else: + # new training without pretraine weights, do nothing + pass + + # we want latest checkoint only (not best or any intermediate) + trainer.save_final_checkpoint = ( + True # whether or not to save the final checkpoint + ) + trainer.save_best_checkpoint = ( + False # whether or not to save the best checkpoint according to + ) + # self.best_val_eval_criterion_MA + trainer.save_intermediate_checkpoints = ( + False # whether or not to save checkpoint_latest. 
We need that in case + ) + # the training chashes + trainer.save_latest_only = ( + True # if false it will not store/overwrite _latest but separate files each + ) + + trainer.max_num_epochs = fl_round + 1 + trainer.epoch = fl_round + + # STAYing WITH NNUNET CONVENTION OF 50 AND 250 VAL AND TRAIN BATCHES RESPECTIVELY + # Note: This convention makes sense in combination with a train_completion_dampener of 0.0 + num_val_batches_per_epoch = 50 + num_train_batches_per_epoch = 250 + + # the nnunet trainer attributes have a different naming convention than I am using + trainer.num_batches_per_epoch = num_train_batches_per_epoch + trainer.num_val_batches_per_epoch = num_val_batches_per_epoch + + batches_applied_train, \ + batches_applied_val, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = trainer.run_training(train_cutoff=train_cutoff, + val_cutoff=val_cutoff, + val_epoch=val_epoch, + train_epoch=train_epoch) + + train_completed = batches_applied_train / float(num_train_batches_per_epoch) + val_completed = batches_applied_val / float(num_val_batches_per_epoch) + + return train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 + + diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py new file mode 100644 index 000000000..3191d9a57 --- /dev/null +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -0,0 +1,298 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards + +""" +# TODO: Clean up imports + +import os +import subprocess +import shutil +import time +import pickle as pkl +from copy import deepcopy +import hashlib +import yaml + +import numpy as np +import torch + +from openfl.utilities import TensorKey +from openfl.utilities.split import split_tensor_dict_for_holdouts + + +from .runner_pt_chkpt import PyTorchCheckpointTaskRunner +from .nnunet_v1 import train_nnunet + +shared_plans_identifier = 'nnUNetPlans_pretrained_POSTOPP' + +class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): + """An abstract class for PyTorch model based Tasks, where training, validation etc. are processes that + pull model state from a PyTorch checkpoint.""" + + def __init__(self, + nnunet_task=None, + config_path=None, + actual_max_num_epochs=1000, + **kwargs): + """Initialize. + + Args: + nnunet_task (str) : Task string used to identify the data and model folders + config_path(str) : Path to the configuration file used by the training and validation script. + actual_max_num_epochs (int) : Number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating + kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). 
+ TODO: + """ + + if 'nnUNet_raw_data_base' not in os.environ: + raise ValueError("NNUNet V1 requires that 'nnUNet_raw_data_base' be set either in the flplan or in the environment variables") + if 'nnUNet_preprocessed' not in os.environ: + raise ValueError("NNUNet V1 requires that 'nnUNet_preprocessed' be set either in the flplan or in the environment variables") + if 'RESULTS_FOLDER' not in os.environ: + raise ValueError("NNUNet V1 requires that 'RESULTS_FOLDER' be set either in the flplan or in the environment variables") + + super().__init__( + checkpoint_path_initial=os.path.join( + os.environ['RESULTS_FOLDER'], + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__{shared_plans_identifier}/fold_0/', + 'model_initial_checkpoint.model' + ), + checkpoint_path_save=os.path.join( + os.environ['RESULTS_FOLDER'], + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__{shared_plans_identifier}/fold_0/', + 'model_final_checkpoint.model' + ), + checkpoint_path_load=os.path.join( + os.environ['RESULTS_FOLDER'], + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__{shared_plans_identifier}/fold_0/', + 'model_final_checkpoint.model' + ), + **kwargs, + ) + + self.config_path = config_path + self.actual_max_num_epochs=actual_max_num_epochs + + # self.task_completed is a dictionary of task to amount completed as a float in [0,1] + # Values will be dynamically updated + # TODO: Tasks are hard coded for now + self.task_completed = {'aggregated_model_validation': 1.0, + 'train': 1.0, + 'locally_tuned_model_validation': 1.0} + + + def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): + """ + Save model state in tensor_dict to in a pickle file at self.checkpoint_out_path. Uses pt.save(). + All state in the checkpoint other than the model state will be kept as is in the file. + Note: Utilization of a with_opt_vars input will be needed (along with saving an initial state optimizer state on disk), + will be needed if a self.opt_treatement of 'RESET' or 'AGG' are to be used + + Here is an example of a dictionary NNUnet uses for its state: + save_this = + { + 'epoch': self.epoch + 1, + 'state_dict': state_dict, + 'optimizer_state_dict': optimizer_state_dict, + 'lr_scheduler_state_dict': lr_sched_state_dct, + 'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, + self.all_val_eval_metrics), + 'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA) + } + + + Args: + tensor_dict (dictionary) : Dictionary with keys + with_opt_vars (bool) : Whether or not to save the optimizer state as well (this info will be part of the tensor dict in this case - i.e. tensor_dict = {**model_state, **opt_state}) + kwargs : unused + + Returns: + epoch + """ + # TODO: For now leaving the lr_scheduler_state_dict unchanged (this may be best though) + # TODO: Do we want to test this for 'RESET', 'CONTINUE_GLOBAL'? 
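As a point of reference for how the aggregator-side tensor dictionary maps onto an NNUnet checkpoint, here is a minimal standalone sketch (toy model and temporary paths are illustrative only, not part of this PR) of the write path implemented below: numpy arrays keyed by layer name are converted back into torch tensors, swapped into the checkpoint's `state_dict`, and the checkpoint is re-saved.

```python
import torch

def write_numpy_into_checkpoint(tensor_dict, checkpoint_path_in, checkpoint_path_out):
    """Swap numpy weights (as sent by the aggregator) into a torch checkpoint."""
    checkpoint = torch.load(checkpoint_path_in, map_location='cpu')
    checkpoint['state_dict'] = {
        name: torch.from_numpy(array.copy()) for name, array in tensor_dict.items()
    }
    torch.save(checkpoint, checkpoint_path_out)
    return checkpoint['epoch']

# Toy round-trip with illustrative paths
model = torch.nn.Linear(4, 2)
torch.save({'epoch': 0, 'state_dict': model.state_dict()}, '/tmp/toy_checkpoint.model')
numpy_weights = {k: v.cpu().numpy() for k, v in model.state_dict().items()}
write_numpy_into_checkpoint(numpy_weights, '/tmp/toy_checkpoint.model', '/tmp/toy_checkpoint.model')
```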
+ + # get device for correct placement of tensors + device = self.device + + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load, map_location=device) + epoch = checkpoint_dict['epoch'] + new_state = {} + # grabbing keys from the checkpoint state dict, poping from the tensor_dict + seen_keys = [] + for k in checkpoint_dict['state_dict']: + if k not in seen_keys: + seen_keys.append(k) + else: + raise ValueError(f"\nKey {k} apears at least twice!!!!/n") + new_state[k] = torch.from_numpy(tensor_dict[k].copy()).to(device) + checkpoint_dict['state_dict'] = new_state + + if with_opt_vars: + # see if there is state to restore first + if tensor_dict.pop('__opt_state_needed') == 'true': + checkpoint_dict = self._set_optimizer_state(derived_opt_state_dict=tensor_dict, + checkpoint_dict=checkpoint_dict) + self.save_checkpoint(checkpoint_dict) + + # FIXME: this should be unnecessary now + # we may want to know epoch so that we can properly tell the training script to what epoch to train (NNUnet V1 only supports training with a max_num_epochs setting) + return epoch + + + def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time=np.inf, train_cutoff_time=np.inf, train_completion_dampener=0.0, **kwargs): + # TODO: Figure out the right name to use for this method and the default assigner + """Perform training for a specified number of epochs.""" + + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. Insert tensor_dict info into checkpoint + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Training for round:{round_num}") + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, + fl_round=round_num, + train_cutoff=train_cutoff_time, + val_cutoff = val_cutoff_time, + task=self.data_loader.get_task_name(), + val_epoch=True, + train_epoch=True) + + # dampen the train_completion + """ + values in range: (0, 1] with values near 0.0 making all train_completion rates shift nearer to 1.0, thus making the + trained model update weighting during aggregation stay closer to the plain data size weighting + specifically, update_weight = train_data_size / train_completed**train_completion_dampener + """ + train_completed = train_completed**train_completion_dampener + + # update amount of task completed + self.task_completed['train'] = train_completed + self.task_completed['locally_tuned_model_validation'] = val_completed + + # 3. Prepare metrics + metrics = {'train_loss': this_ave_train_loss} + + global_tensor_dict, local_tensor_dict = self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + return global_tensor_dict, local_tensor_dict + + + def validate(self, col_name, round_num, input_tensor_dict, val_cutoff_time=np.inf, from_checkpoint=False, **kwargs): + # TODO: Figure out the right name to use for this method and the default assigner + """Perform validation.""" + + if not from_checkpoint: + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. 
Insert tensor_dict info into checkpoint + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Validating for round:{round_num}") + # 2. Train/val function existing externally + # Some todo inside function below + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, + fl_round=round_num, + train_cutoff=0, + val_cutoff = val_cutoff_time, + task=self.data_loader.get_task_name(), + val_epoch=True, + train_epoch=False) + # double check + if train_completed != 0.0: + raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") + + # update amount of task completed + self.task_completed['aggregated_model_validation'] = val_completed + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + + # 3. Prepare metrics + metrics = {'val_eval': this_val_eval_metrics, + 'val_eval_C1': this_val_eval_metrics_C1, + 'val_eval_C2': this_val_eval_metrics_C2, + 'val_eval_C3': this_val_eval_metrics_C3, + 'val_eval_C4': this_val_eval_metrics_C4} + else: + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) + + all_tr_losses, \ + all_val_losses, \ + all_val_losses_tr_mode, \ + all_val_eval_metrics, \ + all_val_eval_metrics_C1, \ + all_val_eval_metrics_C2, \ + all_val_eval_metrics_C3, \ + all_val_eval_metrics_C4 = checkpoint_dict['plot_stuff'] + # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after + metrics = {'val_eval': all_val_eval_metrics[-1], + 'val_eval_C1': all_val_eval_metrics_C1[-1], + 'val_eval_C2': all_val_eval_metrics_C2[-1], + 'val_eval_C3': all_val_eval_metrics_C3[-1], + 'val_eval_C4': all_val_eval_metrics_C4[-1]} + + return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=False) + + + def load_metrics(self, filepath): + """ + Load metrics from file on disk + """ + raise NotImplementedError() + """ + with open(filepath) as json_file: + metrics = json.load(json_file) + return metrics + """ + + + def get_train_data_size(self, task_name=None): + """Get the number of training examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. + + Returns: + int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema + """ + if not task_name: + return self.data_loader.get_train_data_size() + else: + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_train_data_size())) + + + def get_valid_data_size(self, task_name=None): + """Get the number of training examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. 
+ + Returns: + int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema + """ + if not task_name: + return self.data_loader.get_valid_data_size() + else: + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_valid_data_size())) diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py new file mode 100644 index 000000000..a7fbd2056 --- /dev/null +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -0,0 +1,324 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards + +""" +# TODO: Clean up imports + +import os +import shutil +from copy import deepcopy + +import numpy as np +import torch + +from openfl.utilities import TensorKey +from openfl.utilities.split import split_tensor_dict_for_holdouts + +from openfl.federated.task.runner import TaskRunner +from .runner_pt_utils import rebuild_model_util, derive_opt_state_dict, expand_derived_opt_state_dict +from .runner_pt_utils import initialize_tensorkeys_for_functions_util, to_cpu_numpy + + +class PyTorchCheckpointTaskRunner(TaskRunner): + """An abstract class for PyTorch model based Tasks, where training, validation etc. are processes that + pull model state from a PyTorch checkpoint.""" + + def __init__(self, + device = 'cuda', + gpu_num_string = '0', + checkpoint_path_initial = None, + checkpoint_path_save = None, + checkpoint_path_load = None, + **kwargs): + """Initialize. + + Args: + device(str) : Device ('cpu' or 'cuda') to be used for training and validation script computations. + checkpoint_path_initial(str): Path to the model checkpoint that will be used to initialize this object and copied to the 'write' path to start. + checkpoint_path_save(str) : Path to the model checkpoint that will be saved and passed into the training function. + checkpoint_path_load(str) : Path to the model checkpoint that will be loaded. It is also the output file path for the training function. + kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). 
+ TODO: + """ + super().__init__(**kwargs) + + self.checkpoint_path_initial = checkpoint_path_initial + self.checkpoint_path_save = checkpoint_path_save + self.checkpoint_path_load = checkpoint_path_load + self.gpu_num_string = gpu_num_string + + # TODO: Understand why "weights-only" + + # TODO: Both 'CONTINUE_GLOBAL' and 'RESET' could be suported here too (currently RESET throws an exception related to a + # missmatch in size coming from the momentum buffer and other stuff either in the model or optimizer) + self.opt_treatment = 'CONTINUE_LOCAL' + + if device not in ['cpu', 'cuda']: + raise ValueError("Device argument must be 'cpu' or 'cuda', but {device} was used instead.") + self.device = device + + self.training_round_completed = False + + # enable GPUs if appropriate + if self.device == 'cuda' and not self.gpu_num_string: + raise ValueError(f"If device is 'cuda' then gpu_num must be set rather than allowing to be the default None.") + else: + os.environ['CUDA_VISIBLE_DEVICES']= self.gpu_num_string + + self.required_tensorkeys_for_function = {} + self.initialize_tensorkeys_for_functions() + + # overwrite attribute to account for one optimizer param (in every + # child model that does not overwrite get and set tensordict) that is + # not a numpy array + self.tensor_dict_split_fn_kwargs.update({ + 'holdout_tensor_names': ['__opt_state_needed'] + }) + + # Initialize model + self.replace_checkpoint(self.checkpoint_path_initial) + + + def load_checkpoint(self, checkpoint_path, map_location=None): + """ + Function used to load checkpoint from disk. + """ + checkpoint_dict = torch.load(checkpoint_path, map_location=map_location) + return checkpoint_dict + + def save_checkpoint(self, checkpoint_dict): + """ + Function to save checkpoint to disk. + """ + torch.save(checkpoint_dict, self.checkpoint_path_save) + + # defining some class methods using some util functions imported above + + def rebuild_model(self, input_tensor_dict, **kwargs): + rebuild_model_util(runner_class=self, input_tensor_dict=input_tensor_dict, **kwargs) + + def initialize_tensorkeys_for_functions(self, **kwargs): + initialize_tensorkeys_for_functions_util(runner_class=self, **kwargs) + + def get_required_tensorkeys_for_function(self, func_name, **kwargs): + """ + Get the required tensors for specified function that could be called \ + as part of a task. By default, this is just all of the layers and \ + optimizer of the model. + + Args: + func_name + + Returns: + list : [TensorKey] + """ + if func_name == 'validate': + local_model = 'apply=' + str(kwargs['apply']) + return self.required_tensorkeys_for_function[func_name][local_model] + else: + return self.required_tensorkeys_for_function[func_name] + + def reset_opt_vars(self): + current_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) + initial_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_initial) + derived_opt_state_dict = self._get_optimizer_state(checkpoint_dict=initial_checkpoint_dict) + self._set_optimizer_state(derived_opt_state_dict=derived_opt_state_dict, + checkpoint_dict=current_checkpoint_dict) + + def set_tensor_dict(self, tensor_dict, with_opt_vars=False): + """Set the tensor dictionary. 
+ + Args: + tensor_dict: The tensor dictionary + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + """ + return self.write_tensors_into_checkpoint(tensor_dict=tensor_dict, with_opt_vars=with_opt_vars) + + def replace_checkpoint(self, path_to_replacement): + checkpoint_dict = self.load_checkpoint(checkpoint_path=path_to_replacement) + self.save_checkpoint(checkpoint_dict) + # shutil.copyfile(src=path_to_replacement, dst=self.checkpoint_path_save) + + def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): + raise NotImplementedError + + def get_tensor_dict(self, with_opt_vars=False): + """Return the tensor dictionary. + + Args: + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + + Returns: + dict: Tensor dictionary {**dict, **optimizer_dict} + + """ + return self.read_tensors_from_checkpoint(with_opt_vars=with_opt_vars) + + def read_tensors_from_checkpoint(self, with_opt_vars): + """Return a tensor dictionary interpreted from a checkpoint. + + Args: + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + + Returns: + dict: Tensor dictionary {**dict, **optimizer_dict} + + """ + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) + state = to_cpu_numpy(checkpoint_dict['state_dict']) + if with_opt_vars: + opt_state = self._get_optimizer_state(checkpoint_dict=checkpoint_dict) + state = {**state, **opt_state} + return state + + def _get_weights_names(self, with_opt_vars=False): + """ + Gets model and potentially optimizer state dict key names + args: + with_opt_vars(bool) : Whether or not to get the optimizer key names + """ + state = self.get_tensor_dict(with_opt_vars=with_opt_vars) + return state.keys() + + def _set_optimizer_state(self, derived_opt_state_dict, checkpoint_dict): + """Set the optimizer state. + # TODO: Refactor this, we will separate the custom aspect of the checkpoint dict from the more general code + + Args: + derived_opt_state_dict(dict) : flattened optimizer state dict + checkpoint_dict(dict) : checkpoint dictionary + + """ + self._write_optimizer_state_into_checkpoint(derived_opt_state_dict=derived_opt_state_dict, + checkpoint_dict=checkpoint_dict, + checkpoint_path=self.checkpoint_path_save) + + def _write_optimizer_state_into_checkpoint(self, derived_opt_state_dict, checkpoint_dict, checkpoint_path): + """Write the optimizer state contained within the derived_opt_state_dict into the checkpoint_dict, + keeping some settings already contained within that checkpoint file the same, then write the resulting + checkpoint back to the checkpoint path. + TODO: Refactor this, we will separate the custom aspect of the checkpoint dict from the more general code + + Args: + derived_opt_state_dict(dict) : flattened optimizer state dict + checkpoint_path(str) : Path to the checkpoint file + + """ + temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, device=self.device) + # Note: The expansion above only populates the 'params' key of each param group under opt_state_dict['param_groups'] + # Therefore the default values under the additional keys such as: 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov', 'maximize', 'foreach', 'differentiable' + # need to be held over from their initial values. + # FIXME: Figure out whether or not this breaks learning rate scheduling and the like.
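To make the merge step in `_write_optimizer_state_into_checkpoint` concrete, the following is a hedged, self-contained illustration (toy SGD optimizer and illustrative names only, not the project's real model) of which parts of a torch optimizer state dict are overwritten: only the per-group `params` lists and the `state` entries are replaced, while hyperparameters such as `lr` and `momentum` already stored in the checkpoint stay as they were.

```python
from copy import deepcopy

import torch

model = torch.nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model(torch.randn(2, 3)).sum().backward()
opt.step()  # populates the momentum buffers under state_dict()['state']

checkpoint_opt_state = opt.state_dict()

# Pretend this arrived from the federation via expand_derived_opt_state_dict():
# it carries only 'state' and the per-group 'params' lists, no hyperparameters.
temp_state_dict = {
    'state': deepcopy(checkpoint_opt_state['state']),
    'param_groups': [{'params': list(g['params'])} for g in checkpoint_opt_state['param_groups']],
}

# Merge: overwrite only the 'params' lists and the whole 'state';
# 'lr', 'momentum', etc. already present in the checkpoint are left untouched.
for group_idx, group in enumerate(temp_state_dict['param_groups']):
    checkpoint_opt_state['param_groups'][group_idx]['params'] = group['params']
checkpoint_opt_state['state'] = temp_state_dict['state']

assert checkpoint_opt_state['param_groups'][0]['lr'] == 0.01
```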
+ + # Letting default values (everything under temp_state_dict['param_groups'] except the 'params' key) + # stay unchanged (these are not contained in the temp_state_dict) + # Assuming therefore that the optimizer.defaults (which hold this same info) are not changing over course of training. + # We only modify the 'state' key value pairs otherwise + for group_idx, group in enumerate(temp_state_dict['param_groups']): + checkpoint_dict['optimizer_state_dict']['param_groups'][group_idx]['params'] = group['params'] + checkpoint_dict['optimizer_state_dict']['state'] = temp_state_dict['state'] + + torch.save(checkpoint_dict, checkpoint_path) + + def _get_optimizer_state(self, checkpoint_dict): + """Get the optimizer state. + Args: + checkpoint_path(str) : path to the checkpoint + """ + return self._read_opt_state_from_checkpoint(checkpoint_dict) + + + def _read_opt_state_from_checkpoint(self, checkpoint_dict): + """Read the optimizer state from the checkpoint dict and put in tensor dict format. + # TODO: Refactor this, we will sparate the custom aspect of the checkpoint dict from the more general code + """ + + opt_state_dict = deepcopy(checkpoint_dict['optimizer_state_dict']) + + # Optimizer state might not have some parts representing frozen parameters + # So we do not synchronize them + param_keys_with_state = set(opt_state_dict['state'].keys()) + for group in opt_state_dict['param_groups']: + local_param_set = set(group['params']) + params_to_sync = local_param_set & param_keys_with_state + group['params'] = sorted(params_to_sync) + derived_opt_state_dict = derive_opt_state_dict(opt_state_dict) + + return derived_opt_state_dict + + + def convert_results_to_tensorkeys(self, col_name, round_num, metrics, insert_model): + # insert_model determined whether or not to include the model in the return dictionaries + + # 5. Convert to tensorkeys + + # output metric tensors (scalar) + origin = col_name + tags = ('trained',) + output_metric_dict = { + TensorKey( + metric_name, origin, round_num, True, ('metric',) + ): np.array( + metrics[metric_name] + ) for metric_name in metrics} + + if insert_model: + # output model tensors (Doesn't include TensorKey) + output_model_dict = self.get_tensor_dict(with_opt_vars=True) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, + tensor_dict=output_model_dict, + **self.tensor_dict_split_fn_kwargs) + else: + global_model_dict, local_model_dict = {}, {} + + # create global tensorkeys + global_tensorkey_model_dict = { + TensorKey( + tensor_name, origin, round_num, False, tags + ): nparray for tensor_name, nparray in global_model_dict.items() + } + # create tensorkeys that should stay local + local_tensorkey_model_dict = { + TensorKey( + tensor_name, origin, round_num, False, tags + ): nparray for tensor_name, nparray in local_model_dict.items() + } + # the train/validate aggregated function of the next round will look + # for the updated model parameters. 
+ # this ensures they will be resolved locally + next_local_tensorkey_model_dict = { + TensorKey( + tensor_name, origin, round_num + 1, False, ('model',) + ): nparray for tensor_name, nparray in local_model_dict.items() + } + + global_tensor_dict = { + **output_metric_dict, + **global_tensorkey_model_dict + } + local_tensor_dict = { + **local_tensorkey_model_dict, + **next_local_tensorkey_model_dict + } + + # update the required tensors if they need to be pulled from the + # aggregator + # TODO this logic can break if different collaborators have different + # roles between rounds. + # for example, if a collaborator only performs validation in the first + # round but training in the second, it has no way of knowing the + # optimizer state tensor names to request from the aggregator + # because these are only created after training occurs. + # A work around could involve doing a single epoch of training + # on random data to get the optimizer names, and then throwing away + # the model. + if self.opt_treatment == 'CONTINUE_GLOBAL': + self.initialize_tensorkeys_for_functions(with_opt_vars=True) + + return global_tensor_dict, local_tensor_dict diff --git a/examples/fl_post/fl/project/src/runner_pt_utils.py b/examples/fl_post/fl/project/src/runner_pt_utils.py new file mode 100644 index 000000000..28fa33de2 --- /dev/null +++ b/examples/fl_post/fl/project/src/runner_pt_utils.py @@ -0,0 +1,278 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities modeule for PyTorch related Task Runners""" + +# NOTE: this might want to be its own PR to openfl + +from copy import deepcopy +import torch as pt +import numpy as np + +from openfl.utilities.split import split_tensor_dict_for_holdouts +from openfl.utilities import TensorKey + + +def rebuild_model_util(runner_class, input_tensor_dict, testing_with_opt_setting=False, **kwargs): + """ + Parse tensor names and update weights of model. Assumes opt_treatement == CONTINUE_LOCAL, but + allows for writing in optimizer variables for testing purposes + + Returns: + None + """ + if testing_with_opt_setting: + with_opt_vars = True + else: + with_opt_vars = False + + runner_class.set_tensor_dict(input_tensor_dict, with_opt_vars=with_opt_vars) + + +def derive_opt_state_dict(opt_state_dict): + """Separate optimizer tensors from the tensor dictionary. + + Flattens the optimizer state dict so as to have key, value pairs with + values as numpy arrays. + The keys have sufficient info to restore opt_state_dict using + expand_derived_opt_state_dict. + + Args: + opt_state_dict: The optimizer state dictionary + + """ + derived_opt_state_dict = {} + + # Determine if state is needed for this optimizer. + if len(opt_state_dict['state']) == 0: + derived_opt_state_dict['__opt_state_needed'] = 'false' + return derived_opt_state_dict + + derived_opt_state_dict['__opt_state_needed'] = 'true' + + # Using one example state key, we collect keys for the corresponding + # dictionary value. + example_state_key = opt_state_dict['param_groups'][0]['params'][0] + example_state_subkeys = set( + opt_state_dict['state'][example_state_key].keys() + ) + + + + # We assume that the state collected for all params in all param groups is + # the same. + # We also assume that whether or not the associated values to these state + # subkeys is a tensor depends only on the subkey. + # Using assert statements to break the routine if these assumptions are + # incorrect. 
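For readers unfamiliar with the flattened representation, here is a small self-contained sketch (toy single-layer model, hypothetical usage, not part of this PR) of the `__opt_state_{group_idx}_{idx}_{tag}_{subkey}` key scheme that the loop below produces; `expand_derived_opt_state_dict` later inverts exactly this mapping.

```python
import numpy as np
import torch

model = torch.nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(2, 3)).sum().backward()
opt.step()  # populates momentum buffers in the optimizer state

sd = opt.state_dict()
flat = {'__opt_state_needed': 'true'}
for group_idx, group in enumerate(sd['param_groups']):
    for idx, param_id in enumerate(group['params']):
        for subkey, value in sd['state'][param_id].items():
            tag = 'istensor' if torch.is_tensor(value) else ''
            new_v = value.cpu().numpy() if tag else np.array([value])
            flat[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v
flat['__opt_group_lengths'] = np.array([len(g['params']) for g in sd['param_groups']])

# e.g. ['__opt_group_lengths', '__opt_state_0_0_istensor_momentum_buffer',
#       '__opt_state_0_1_istensor_momentum_buffer', '__opt_state_needed']
print(sorted(flat))
```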
+ for state_key in opt_state_dict['state'].keys(): + assert example_state_subkeys == set(opt_state_dict['state'][state_key].keys()) + for state_subkey in example_state_subkeys: + assert (isinstance( + opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor) + == isinstance( + opt_state_dict['state'][state_key][state_subkey], + pt.Tensor)) + + state_subkeys = list(opt_state_dict['state'][example_state_key].keys()) + + # Tags will record whether the value associated to the subkey is a + # tensor or not. + state_subkey_tags = [] + for state_subkey in state_subkeys: + if isinstance( + opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor + ): + state_subkey_tags.append('istensor') + else: + state_subkey_tags.append('') + state_subkeys_and_tags = list(zip(state_subkeys, state_subkey_tags)) + + # Forming the flattened dict, using a concatenation of group index, + # subindex, tag, and subkey inserted into the flattened dict key - + # needed for reconstruction. + nb_params_per_group = [] + for group_idx, group in enumerate(opt_state_dict['param_groups']): + for idx, param_id in enumerate(group['params']): + for subkey, tag in state_subkeys_and_tags: + if tag == 'istensor': + new_v = opt_state_dict['state'][param_id][ + subkey].cpu().numpy() + else: + new_v = np.array( + [opt_state_dict['state'][param_id][subkey]] + ) + derived_opt_state_dict[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v + nb_params_per_group.append(idx + 1) + # group lengths are also helpful for reconstructing + # original opt_state_dict structure + derived_opt_state_dict['__opt_group_lengths'] = np.array( + nb_params_per_group + ) + return derived_opt_state_dict + + +def expand_derived_opt_state_dict(derived_opt_state_dict, device): + """Expand the optimizer state dictionary. + + Takes a derived opt_state_dict and creates an opt_state_dict suitable as + input for load_state_dict for restoring optimizer state. + + Reconstructing state_subkeys_and_tags using the example key + prefix, "__opt_state_0_0_", certain to be present. + + Args: + derived_opt_state_dict: Optimizer state dictionary + + Returns: + dict: Optimizer state dictionary + """ + state_subkeys_and_tags = [] + for key in derived_opt_state_dict: + if key.startswith('__opt_state_0_0_'): + stripped_key = key[16:] + if stripped_key.startswith('istensor_'): + this_tag = 'istensor' + subkey = stripped_key[9:] + else: + this_tag = '' + subkey = stripped_key[1:] + state_subkeys_and_tags.append((subkey, this_tag)) + + opt_state_dict = {'param_groups': [], 'state': {}} + nb_params_per_group = list( + derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32) + ) + + # Construct the expanded dict. + for group_idx, nb_params in enumerate(nb_params_per_group): + these_group_ids = [f'{group_idx}_{idx}' for idx in range(nb_params)] + opt_state_dict['param_groups'].append({'params': these_group_ids}) + for this_id in these_group_ids: + opt_state_dict['state'][this_id] = {} + for subkey, tag in state_subkeys_and_tags: + flat_key = f'__opt_state_{this_id}_{tag}_{subkey}' + if tag == 'istensor': + new_v = pt.from_numpy(derived_opt_state_dict.pop(flat_key)) + else: + # Here (for currrently supported optimizers) the subkey + # should be 'step' and the length of array should be one. 
+ assert subkey == 'step' + assert len(derived_opt_state_dict[flat_key]) == 1 + new_v = int(derived_opt_state_dict.pop(flat_key)) + opt_state_dict['state'][this_id][subkey] = new_v + + # sanity check that we did not miss any optimizer state (after removing __opt_state_needed) + derived_opt_state_dict.pop('__opt_state_needed') + if len(derived_opt_state_dict) != 0: + raise ValueError(f"Opt state should have been exausted, but we have left: {derived_opt_state_dict}") + + return opt_state_dict + + +def initialize_tensorkeys_for_functions_util(runner_class, with_opt_vars=False): + """Set the required tensors for all publicly accessible task methods. + + By default, this is just all of the layers and optimizer of the model. + Custom tensors should be added to this function. + + Args: + None + + Returns: + None + """ + # TODO there should be a way to programmatically iterate through + # all of the methods in the class and declare the tensors. + # For now this is done manually + + output_model_dict = runner_class.get_tensor_dict(with_opt_vars=with_opt_vars) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( + runner_class.logger, output_model_dict, + **runner_class.tensor_dict_split_fn_kwargs + ) + if not with_opt_vars: + global_model_dict_val = global_model_dict + local_model_dict_val = local_model_dict + else: + output_model_dict = runner_class.get_tensor_dict(with_opt_vars=False) + global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( + runner_class.logger, + output_model_dict, + **runner_class.tensor_dict_split_fn_kwargs + ) + + runner_class.required_tensorkeys_for_function['train_batches'] = [ + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + for tensor_name in global_model_dict] + runner_class.required_tensorkeys_for_function['train_batches'] += [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + for tensor_name in local_model_dict] + + runner_class.required_tensorkeys_for_function['train'] = [ + TensorKey( + tensor_name, 'GLOBAL', 0, False, ('model',) + ) for tensor_name in global_model_dict + ] + runner_class.required_tensorkeys_for_function['train'] += [ + TensorKey( + tensor_name, 'LOCAL', 0, False, ('model',) + ) for tensor_name in local_model_dict + ] + + # Validation may be performed on local or aggregated (global) model, + # so there is an extra lookup dimension for kwargs + runner_class.required_tensorkeys_for_function['validate'] = {} + # TODO This is not stateless. The optimizer will not be + runner_class.required_tensorkeys_for_function['validate']['apply=local'] = [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) + for tensor_name in { + **global_model_dict_val, + **local_model_dict_val + }] + runner_class.required_tensorkeys_for_function['validate']['apply=global'] = [ + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + for tensor_name in global_model_dict_val + ] + runner_class.required_tensorkeys_for_function['validate']['apply=global'] += [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + for tensor_name in local_model_dict_val + ] + + +def to_cpu_numpy(state): + """Send data to CPU as Numpy array. + + Args: + state + + """ + # deep copy so as to decouple from active model + state = deepcopy(state) + + for k, v in state.items(): + # When restoring, we currently assume all values are tensors. 
+ if not pt.is_tensor(v): + raise ValueError('We do not currently support non-tensors ' + 'coming from model.state_dict()') + # get as a numpy array, making sure is on cpu + state[k] = v.cpu().numpy() + return state + + +class DummyDataLoader(): + def __init__(self, feature_shape, training_data_size, valid_data_size): + self.feature_shape = feature_shape + self.training_data_size = training_data_size + self.valid_data_size = valid_data_size + + def get_feature_shape(self): + return self.feature_shape + + def get_training_data_size(self): + return self.training_data_size + + def get_valid_data_size(self): + return self.valid_data_size diff --git a/examples/fl_post/fl/project/utils.py b/examples/fl_post/fl/project/utils.py new file mode 100644 index 000000000..c656f4d3c --- /dev/null +++ b/examples/fl_post/fl/project/utils.py @@ -0,0 +1,178 @@ +import yaml +import os +import shutil + + +def generic_setup(output_logs): + tmpfolder = os.path.join(output_logs, ".tmp") + os.makedirs(tmpfolder, exist_ok=True) + # NOTE: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmpfolder + os.environ["RESULTS_FOLDER"] = os.path.join(tmpfolder, "nnUNet_trained_models") + os.environ["nnUNet_raw_data_base"] = os.path.join(tmpfolder, "nnUNet_raw_data_base") + os.environ["nnUNet_preprocessed"] = os.path.join(tmpfolder, "nnUNet_preprocessed") + workspace_folder = os.path.join(output_logs, "workspace") + os.makedirs(workspace_folder, exist_ok=True) + create_workspace(workspace_folder) + return workspace_folder + + +def setup_collaborator( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + cn = get_collaborator_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def setup_aggregator( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + prepare_cols_list(collaborators, workspace_folder) + prepare_init_weights(input_weights, workspace_folder) + fqdn = get_aggregator_fqdn(workspace_folder) + prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def generic_teardown(output_logs): + tmp_folder = os.path.join(output_logs, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + +def create_workspace(fl_workspace): + plan_folder = os.path.join(fl_workspace, "plan") + workspace_config = os.path.join(fl_workspace, ".workspace") + defaults_file = os.path.join(plan_folder, "defaults") + + os.makedirs(plan_folder, exist_ok=True) + with open(defaults_file, "w") as f: + f.write("../../workspace/plan/defaults\n\n") + with open(workspace_config, "w") as f: + f.write("current_plan_name: default\n\n") + + +def get_aggregator_fqdn(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return plan["network"]["settings"]["agg_addr"].lower() + + +def get_collaborator_cn(): + return os.environ["MEDPERF_PARTICIPANT_LABEL"] + + +def get_weights_path(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return { + "init": plan["aggregator"]["settings"]["init_state_path"], + "best": plan["aggregator"]["settings"]["best_state_path"], + "last": 
plan["aggregator"]["settings"]["last_state_path"], + } + + +def prepare_plan(plan_path, fl_workspace): + target_plan_folder = os.path.join(fl_workspace, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + + target_plan_file = os.path.join(target_plan_folder, "plan.yaml") + shutil.copyfile(plan_path, target_plan_file) + + +def prepare_cols_list(collaborators_file, fl_workspace): + with open(collaborators_file) as f: + cols_dict = yaml.safe_load(f) + cn_different = False + for col_label in cols_dict.keys(): + cn = cols_dict[col_label] + if cn != col_label: + cn_different = True + if not cn_different: + # quick hack to support old and new openfl versions + cols_dict = list(cols_dict.keys()) + + target_plan_folder = os.path.join(fl_workspace, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + target_plan_file = os.path.join(target_plan_folder, "cols.yaml") + with open(target_plan_file, "w") as f: + yaml.dump({"collaborators": cols_dict}, f) + + +def prepare_init_weights(input_weights, fl_workspace): + error_msg = f"{input_weights} should contain only one file: *.pbuf" + + files = os.listdir(input_weights) + file = files[0] # TODO: this may cause failure in MAC OS + if len(files) != 1 or not file.endswith(".pbuf"): + raise RuntimeError(error_msg) + + file = os.path.join(input_weights, file) + + target_weights_subpath = get_weights_path(fl_workspace)["init"] + target_weights_path = os.path.join(fl_workspace, target_weights_subpath) + target_weights_folder = os.path.dirname(target_weights_path) + os.makedirs(target_weights_folder, exist_ok=True) + os.symlink(file, target_weights_path) + + +def prepare_node_cert( + node_cert_folder, target_cert_folder_name, target_cert_name, fl_workspace +): + error_msg = f"{node_cert_folder} should contain only two files: *.crt and *.key" + + files = os.listdir(node_cert_folder) + file_extensions = [file.split(".")[-1] for file in files] + if len(files) != 2 or sorted(file_extensions) != ["crt", "key"]: + raise RuntimeError(error_msg) + + if files[0].endswith(".crt") and files[1].endswith(".key"): + cert_file = files[0] + key_file = files[1] + else: + key_file = files[0] + cert_file = files[1] + + key_file = os.path.join(node_cert_folder, key_file) + cert_file = os.path.join(node_cert_folder, cert_file) + + target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) + os.makedirs(target_cert_folder, exist_ok=True) + target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") + target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") + + os.symlink(key_file, target_key_file) + os.symlink(cert_file, target_cert_file) + + +def prepare_ca_cert(ca_cert_folder, fl_workspace): + error_msg = f"{ca_cert_folder} should contain only one file: *.crt" + + files = os.listdir(ca_cert_folder) + file = files[0] + if len(files) != 1 or not file.endswith(".crt"): + raise RuntimeError(error_msg) + + file = os.path.join(ca_cert_folder, file) + + target_ca_cert_folder = os.path.join(fl_workspace, "cert") + os.makedirs(target_ca_cert_folder, exist_ok=True) + target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") + + os.symlink(file, target_ca_cert_file) diff --git a/examples/fl_post/fl/setup_clean.sh b/examples/fl_post/fl/setup_clean.sh new file mode 100644 index 000000000..05fc6a919 --- /dev/null +++ b/examples/fl_post/fl/setup_clean.sh @@ -0,0 +1,4 @@ +rm -rf ./mlcube_agg +rm -rf ./mlcube_col* +rm -rf ./ca +rm -rf ./for_admin diff --git a/examples/fl_post/fl/setup_test.sh 
b/examples/fl_post/fl/setup_test.sh new file mode 100644 index 000000000..72a1c55b9 --- /dev/null +++ b/examples/fl_post/fl/setup_test.sh @@ -0,0 +1,84 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" +COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org2@example.com" # in this case this var is not used actually. it's OK +fi + +cp -r ./mlcube ./mlcube_agg +cp -r ./mlcube ./mlcube_col1 +cp -r ./mlcube ./mlcube_col2 +cp -r ./mlcube ./mlcube_col3 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert +mkdir ./ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) + +medperf mlcube run --mlcube ../mock_cert/mlcube --task trust +mv ../mock_cert/mlcube/workspace/pki_assets/* ./ca + +# col1 +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL1_CN +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col1/workspace/node_cert +cp -r ./ca/* ./mlcube_col1/workspace/ca_cert + +# col2 +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL2_CN +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col2/workspace/node_cert +cp -r ./ca/* ./mlcube_col2/workspace/ca_cert + +# col3 +if ${TWO_COL_SAME_CERT}; then + cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert +else + medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL3_CN + mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col3/workspace/node_cert + cp -r ./ca/* ./mlcube_col3/workspace/ca_cert +fi + +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_server_cert -e MEDPERF_INPUT_CN=$HOSTNAME_ +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_agg/workspace/node_cert +cp -r ./ca/* ./mlcube_agg/workspace/ca_cert + +# aggregator_config +echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + +# cols file +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml + +# data download +cd mlcube_col1/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz +tar -xf col1_prepared.tar.gz +rm col1_prepared.tar.gz +cd ../.. + +cd mlcube_col2/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz +tar -xf col2_prepared.tar.gz +rm col2_prepared.tar.gz +cd ../.. 
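The cols.yaml entries written above are consumed by `prepare_cols_list()` in `project/utils.py`; a small hypothetical round-trip (not part of this PR) shows the two shapes it accepts: a `{label: common-name}` mapping, or a plain list of names when every label equals its CN (kept for compatibility with older OpenFL versions).

```python
import yaml

cols_yaml = """
col1@example.com: col1@example.com
col2@example.com: col2@example.com
col3@example.com: col3@example.com
"""
cols_dict = yaml.safe_load(cols_yaml)

# When every label matches its CN, fall back to the plain list form
# (the "quick hack to support old and new openfl versions" in prepare_cols_list).
if all(label == cn for label, cn in cols_dict.items()):
    cols_dict = list(cols_dict.keys())

print(yaml.dump({'collaborators': cols_dict}, default_flow_style=False))
```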
+ +cp -r mlcube_col2/workspace/data mlcube_col3/workspace +cp -r mlcube_col2/workspace/labels mlcube_col3/workspace diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh new file mode 100644 index 000000000..e8938f561 --- /dev/null +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -0,0 +1,119 @@ +while getopts n: flag; do + case "${flag}" in + n) NUM_COLS=${OPTARG} ;; + esac +done +NUM_COLS="${NUM_COLS:-3}" + +setupCA() { + mkdir ./ca + openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 + openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ + -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" + +} + +setupData() { + wget https://storage.googleapis.com/medperf-storage/fltest29July/small_test_data.tar.gz + tar -xf small_test_data.tar.gz + rm -rf small_test_data.tar.gz +} + +setupWeights() { + # weights setup + cd mlcube/workspace + mkdir additional_files + cd additional_files + wget https://storage.googleapis.com/medperf-storage/flpost_add9nov.tar.gz + tar -xf flpost_add9nov.tar.gz + rm flpost_add9nov.tar.gz + cd ../../../ +} + +setupAgg() { + HOSTNAME_=$(hostname -A | cut -d " " -f 1) + cp -r ./mlcube ./mlcube_agg + mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert + + # cert + sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf + cd mlcube_agg/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_server_crt -extfile ../../../csr.conf + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd ../../../ + + # aggregator_config + echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml + echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + + # weights + cp -r mlcube/workspace/additional_files mlcube_agg/workspace/additional_files + +} + +setupAdmin() { + ADMIN_CN="testfladmin@example.com" + + mkdir ./for_admin + mkdir ./for_admin/node_cert + + sed -i "/^commonName = /c\commonName = $ADMIN_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $ADMIN_CN" csr.conf + cd for_admin/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf + rm csr.csr + mkdir ../ca_cert + cp -r ../../ca/root.crt ../ca_cert/root.crt + cd ../.. 
+} + +setupCol() { + # $1 is ID + N=$1 + COL_CN="col$N@example.com" + COL_LABEL="col$N@example.com" + cp -r ./mlcube ./mlcube_col$N + mkdir ./mlcube_col$N/workspace/node_cert ./mlcube_col$N/workspace/ca_cert + + # cert + sed -i "/^commonName = /c\commonName = $COL_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $COL_CN" csr.conf + cd mlcube_col$N/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd ../../../ + + # add in cols file + echo "$COL_LABEL: $COL_CN" >>mlcube_agg/workspace/cols.yaml + + # weights + cp -r mlcube/workspace/additional_files mlcube_col$N/workspace/additional_files + + # data + mv small_test_data$N/* mlcube_col$N/workspace +} + +setupCA +setupData +setupWeights +setupAgg +setupAdmin + +a=0 +while [ $a -lt $NUM_COLS ]; do + a=$(expr $a + 1) + setupCol $a +done +rm -r small_test_data* diff --git a/examples/fl_post/fl/sync.sh b/examples/fl_post/fl/sync.sh new file mode 100755 index 000000000..f7a5803e6 --- /dev/null +++ b/examples/fl_post/fl/sync.sh @@ -0,0 +1,10 @@ +cp mlcube/workspace/training_config.yaml mlcube_agg/workspace/training_config.yaml +cp mlcube/mlcube.yaml mlcube_agg/mlcube.yaml + +for dir in mlcube_col*/; do + if [ -d "$dir" ]; then + cp mlcube/mlcube.yaml $dir/mlcube.yaml + rm -r $dir/workspace/additional_files + cp -r mlcube/workspace/additional_files $dir/workspace/additional_files + fi +done diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh new file mode 100755 index 000000000..205323557 --- /dev/null +++ b/examples/fl_post/fl/test.sh @@ -0,0 +1,99 @@ +COL1_DATA="" +COL1_LABELS="" +COL2_DATA="" +COL2_LABELS="" +COL3_DATA="" +COL3_LABELS="" +while [[ "$#" -gt 0 ]]; do + case $1 in + --d1) + COL1_DATA="$2" + shift + ;; + --d2) + COL2_DATA="$2" + shift + ;; + --d3) + COL3_DATA="$2" + shift + ;; + --l1) + COL1_LABELS="$2" + shift + ;; + --l2) + COL2_LABELS="$2" + shift + ;; + --l3) + COL3_LABELS="$2" + shift + ;; + --ncols) + NUM_COLS="$2" + shift + ;; + *) + echo "Unknown parameter: $1" + exit 1 + ;; + esac + shift +done + +COL1_DATA="${COL1_DATA:-$PWD/mlcube_col1/workspace/data}" +COL1_LABELS="${COL1_LABELS:-$PWD/mlcube_col1/workspace/labels}" +COL2_DATA="${COL2_DATA:-$PWD/mlcube_col2/workspace/data}" +COL2_LABELS="${COL2_LABELS:-$PWD/mlcube_col2/workspace/labels}" +COL3_DATA="${COL3_DATA:-$PWD/mlcube_col3/workspace/data}" +COL3_LABELS="${COL3_LABELS:-$PWD/mlcube_col3/workspace/labels}" +NUM_COLS="${NUM_COLS:-3}" + +# generate plan and copy it to each node +GENERATE_PLAN_PLATFORM="docker" +AGG_PLATFORM="docker" +COL1_PLATFORM="docker" +COL2_PLATFORM="docker" +COL3_PLATFORM="docker" + +medperf --platform $GENERATE_PLAN_PLATFORM mlcube run --mlcube ./mlcube_agg --task generate_plan +mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace +rm -r ./mlcube_agg/workspace/plan +cp ./mlcube_agg/workspace/plan.yaml ./for_admin + +for dir in mlcube_col*/; do + if [ -d "$dir" ]; then + cp ./mlcube_agg/workspace/plan.yaml $dir/workspace + fi +done + +# Run nodes +AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf --platform $COL1_PLATFORM --gpus=device=0 mlcube run --mlcube ./mlcube_col1 --task 
train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com --params data_path=$COL1_DATA,labels_path=$COL1_LABELS" +COL2="medperf --platform $COL2_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com --params data_path=$COL2_DATA,labels_path=$COL2_LABELS" +COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com --params data_path=$COL3_DATA,labels_path=$COL3_LABELS" + +$AGG >agg.log & +sleep 6 +$COL1 >col1.log & +sleep 6 +$COL2 >col2.log & +sleep 6 +$COL3 >>col3.log & + +a=3 +while [ $a -lt $NUM_COLS ]; do + a=$(expr $a + 1) + medperf --platform docker --gpus=device=1 mlcube run --mlcube ./mlcube_col$a --task train -e MEDPERF_PARTICIPANT_LABEL=col$a@example.com +done + +wait + +# gnome-terminal -- bash -c "$AGG; bash" +# gnome-terminal -- bash -c "$COL1; bash" +# gnome-terminal -- bash -c "$COL2; bash" +# gnome-terminal -- bash -c "$COL3; bash" + +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl_post/fl/test_init.sh b/examples/fl_post/fl/test_init.sh new file mode 100644 index 000000000..3dee25f0a --- /dev/null +++ b/examples/fl_post/fl/test_init.sh @@ -0,0 +1 @@ +medperf --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train_initial_model diff --git a/flca/Dockerfile.dev b/flca/Dockerfile.dev new file mode 100644 index 000000000..45e3b1340 --- /dev/null +++ b/flca/Dockerfile.dev @@ -0,0 +1,22 @@ +FROM python:3.11.9-alpine + +ENV USE_PROXY=1 + +# update openssl to fix https://avd.aquasec.com/nvd/cve-2024-2511 +RUN apk update && apk add openssl=3.1.4-r6 tar curl && if [[ -n "${USE_PROXY}" ]]; then apk add nginx; fi + +ARG VERSION=0.26.1 +RUN curl -LO https://dl.smallstep.com/gh-release/cli/gh-release-header/v${VERSION}/step_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step_linux_${VERSION}_amd64.tar.gz \ + && cp step_${VERSION}/bin/step /usr/bin +RUN curl -LO https://dl.smallstep.com/gh-release/certificates/gh-release-header/v${VERSION}/step-ca_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step-ca_linux_${VERSION}_amd64.tar.gz \ + && cp step-ca /usr/bin + + +COPY ./dev_utils.py /utils.py + +COPY ./setup.py /setup.py +COPY ./entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["/bin/sh", "/entrypoint.sh"] diff --git a/flca/Dockerfile.prod b/flca/Dockerfile.prod new file mode 100644 index 000000000..6af8a8904 --- /dev/null +++ b/flca/Dockerfile.prod @@ -0,0 +1,22 @@ +FROM python:3.11.9-alpine + +ENV USE_PROXY=1 + +# update openssl to fix https://avd.aquasec.com/nvd/cve-2024-2511 +RUN apk update && apk add openssl=3.1.4-r6 tar curl && if [[ -n "${USE_PROXY}" 
]]; then apk add nginx; fi + +ARG VERSION=0.26.1 +RUN curl -LO https://dl.smallstep.com/gh-release/cli/gh-release-header/v${VERSION}/step_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step_linux_${VERSION}_amd64.tar.gz \ + && cp step_${VERSION}/bin/step /usr/bin +RUN curl -LO https://dl.smallstep.com/gh-release/certificates/gh-release-header/v${VERSION}/step-ca_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step-ca_linux_${VERSION}_amd64.tar.gz \ + && cp step-ca /usr/bin + +RUN pip install google-cloud-secret-manager==2.20.0 +COPY ./utils.py /utils.py + +COPY ./setup.py /setup.py +COPY ./entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["/bin/sh", "/entrypoint.sh"] \ No newline at end of file diff --git a/flca/README.md b/flca/README.md new file mode 100644 index 000000000..f5916a8f6 --- /dev/null +++ b/flca/README.md @@ -0,0 +1,80 @@ +# Deploying Step + +## For Production + +### Configuration and secrets + +#### `ca_config` + +An example of this file can be found in `dev_assets/ca.json` (or [here](https://smallstep.com/docs/step-ca/configuration/#basic-configuration-options)). This contains the ca configuration, and it will be modified at runtime as follows: + + * `root`: will point to the path of the root ca cert after it gets downloaded and stored. + * `crt`: will point to the path of the intermediate ca cert after it gets downloaded and stored. + * `key`: will point to the path of the intermediate ca key after it gets downloaded and stored. + * `db`: will contain the database configuration (taken from a secret variable). + * `authority.provisioners.0.options.x509.templateFile`: will point to the path of the ACME provisioner cert template after it gets downloaded and stored. + * `authority.provisioners.1.options.x509.templateFile`: will point to the path of the OIDC provisioner cert template after it gets downloaded and stored. + +#### Other configuration files + + * `root_ca_crt`: The root ca certificate. + * `intermediate_ca_crt`: The intermediate ca certificate. + * `client_x509_template`: The OIDC provisioner cert template. + * `server_x509_template`: The ACME provisioner cert template. + * `proxy_config`: If a proxy needs to be used, this contains an Nginx server configuration. + +#### Secrets + + * `intermediate_ca_key`: The intermediate ca key. + * `intermediate_ca_password`: The password used to encrypt the intermediate ca key. + * `db_config`: Database connection configuration. + +#### Main settings file + +All secrets and configurations are stored separately in GCP's secret manager. There is a main settings file `settings.json`, also stored in the secret manager, which is a JSON file containing references to the other secrets/configurations. + +### Deployment + + * Build + + ```sh + docker build -t tmptag -f Dockerfile.prod . + ``` + + * Tag + + ```sh + TAG=$(docker image ls | grep tmptag | tr -s " " | awk '{$1=$1;print}' | cut -d " " -f 3) + docker tag tmptag us-west1-docker.pkg.dev/medperf-330914/medperf-repo/medperf-ca:$TAG + ``` + + * Push + + ```sh + docker push us-west1-docker.pkg.dev/medperf-330914/medperf-repo/medperf-ca:$TAG + ``` + + * Setup secrets and configurations + * Edit `cloudbuild.yaml` as needed. You may change: + * the service account that will bind to the deployed instance. + * the port + * the service name if planning to deploy a new service, not a new revision of the existing service. + * the SQL instance + * ...
+ * Run `gcloud builds submit --config=cloudbuild.yaml --substitutions=SHORT_SHA=$TAG` + +## For Development + +### Configuration and secrets + +The folder `dev_assets` contains the configurations and "secrets" described above, but for development. + +### Deployment + +Build using `Dockerfile.dev` (tag it, say, with `local/devca:0.0.0`), then run: + +```sh +docker run --volume ./dev_assets:/assets -p <IP>:443:443 local/devca:0.0.0 +``` + +Set `<IP>` as you wish (`0.0.0.0`, `127.0.0.1`, `$(hostname -I | cut -d " " -f 1)`, ...) diff --git a/flca/cloudbuild.yaml b/flca/cloudbuild.yaml new file mode 100644 index 000000000..158cfc889 --- /dev/null +++ b/flca/cloudbuild.yaml @@ -0,0 +1,45 @@ +#The script is invoked manually with all settings provided in the secret +#It assumes that the DB is created before the script runs +#In order to deploy a service, pass the sha-id of the already built image +#Command: gcloud builds submit --config=cloudbuild.yaml --substitutions=SHORT_SHA= +steps: + - id: "deploy cloud run" + name: "gcr.io/cloud-builders/gcloud" + args: + [ + "run", + "deploy", + "${_CLOUD_RUN_SERVICE_NAME}", + "--platform", + "managed", + "--region", + "${_REGION}", + "--image", + "${_REGION}-${_ARTIFACT_REGISTRY_DOMAIN}/${PROJECT_ID}/${_REPO_NAME}/${_IMAGE_NAME}:${SHORT_SHA}", + "--add-cloudsql-instances", + "${PROJECT_ID}:${_REGION}:${_SQL_INSTANCE_NAME}", + "--set-env-vars", + "SETTINGS_SECRETS_NAME=${_SECRET_SETTINGS_NAME}", + "--allow-unauthenticated", + "--min-instances", + "${_CLOUD_RUN_MIN_INSTANCES}", + "--port", + "${_PORT}", + "--service-account", + "${_SERVICE_ACCOUNT}" + ] + +substitutions: + _REGION: us-west1 + _ARTIFACT_REGISTRY_DOMAIN: docker.pkg.dev + _REPO_NAME: medperf-repo + _IMAGE_NAME: medperf-ca + _CLOUD_RUN_SERVICE_NAME: medperf-ca + _CLOUD_RUN_MIN_INSTANCES: "1" + _SECRET_SETTINGS_NAME: medperf-ca-settings + _SQL_INSTANCE_NAME: medperf-dev + _PORT: "443" + _SERVICE_ACCOUNT: "medperf-ca@medperf-330914.iam.gserviceaccount.com" + +options: + dynamic_substitutions: true diff --git a/flca/dev_assets/ca.json b/flca/dev_assets/ca.json new file mode 100644 index 000000000..8e6e191d3 --- /dev/null +++ b/flca/dev_assets/ca.json @@ -0,0 +1,57 @@ +{ + "root": "/stephome/certs/root_ca.crt", + "federatedRoots": null, + "crt": "/stephome/certs/intermediate_ca.crt", + "key": "/stephome/secrets/intermediate_ca_key", + "dnsNames": [ + "127.0.0.1" + ], + "address": "127.0.0.1:8000", + "logger": { + "format": "text" + }, + "db": "", + "authority": { + "provisioners": [ + { + "type": "ACME", + "name": "acme", + "options": { + "x509": { + "templateFile": "/stephome/templates/certs/x509/server.tpl" + } + } + }, + { + "type": "OIDC", + "name": "auth0", + "clientID": "kQoZ38ESRfUuMUUBlQRv2gWwOwGAMOqd", + "clientSecret": "", + "configurationEndpoint": "https://dev-5xl8y6uuc2hig2ly.us.auth0.com/.well-known/openid-configuration", + "options": { + "x509": { + "templateFile": "/stephome/templates/certs/x509/client.tpl" + }, + "ssh": {} + } + } + ], + "claims": { + "minTLSCertDuration": "8766h", + "maxTLSCertDuration": "8766h", + "defaultTLSCertDuration": "8766h", + "disableRenewal": true + }, + "template": {}, + "backdate": "1m0s" + }, + "tls": { + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" + ], + "minVersion": 1.2, + "maxVersion": 1.3, + "renegotiation": false + } +} \ No newline at end of file diff --git a/flca/dev_assets/client.tpl b/flca/dev_assets/client.tpl new file mode 100644 index 000000000..e4f3523f5 --- /dev/null +++
b/flca/dev_assets/client.tpl @@ -0,0 +1,10 @@ +{ + "subject": {{ toJson .Token.email }}, + "sans": {{ toJson .SANs }}, +{{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + "keyUsage": ["dataEncipherment", "digitalSignature"], +{{- else }} + {{ fail "Key type must be RSA. Try again with --kty=RSA" }} +{{- end }} + "extKeyUsage": ["serverAuth", "clientAuth"] +} \ No newline at end of file diff --git a/flca/dev_assets/db_config.json b/flca/dev_assets/db_config.json new file mode 100644 index 000000000..5b1cabc9c --- /dev/null +++ b/flca/dev_assets/db_config.json @@ -0,0 +1,4 @@ +{ + "type": "badgerv2", + "dataSource": "/db" +} \ No newline at end of file diff --git a/flca/dev_assets/intermediate_ca.crt b/flca/dev_assets/intermediate_ca.crt new file mode 100644 index 000000000..2619b1f26 --- /dev/null +++ b/flca/dev_assets/intermediate_ca.crt @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIElTCCAsmgAwIBAgIRAM0SvcPdc4E0j1asWd+pCoQwQQYJKoZIhvcNAQEKMDSg +DzANBglghkgBZQMEAgEFAKEcMBoGCSqGSIb3DQEBCDANBglghkgBZQMEAgEFAKID +AgEgMBoxGDAWBgNVBAMTD01lZFBlcmYgUm9vdCBDQTAeFw0yNDA1MDkxMjU2NDZa +Fw0zNDA1MTAwMDU2NDFaMCIxIDAeBgNVBAMTF01lZFBlcmYgSW50ZXJtZWRpYXRl +IENBMIIBojANBgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEAyz2XEnXNheKF5ul0 +TgvMcvfAEmQh1IjrtNsisk1Jcep82bInK6GiKTvr4f1JzqelqEY1PrZnjcKtG+Q7 +AttYtAZi+mH5FnxDZAEvPSMd4Feo7y0QCvjsfe6jwUX6XToh+2ET89863xCd9JLi +iy9FvjbZWLx3trN1k/aOBlotKUXaLOHRaOO4GTpzZymN+HElsYSxXhxiWVAuQQF3 +GEEXAi1rlekNDt46cqI230M4rI9FWuRZTOaVsHm7OyyOXJt7inbKvWsgWF9+YaZS +lfhuuj6RKDpK32DOekzAR39mjdJf+EsXnIw0Jx3Mqcq2n6xjh9/72Z88CpVFrcEm +TeSnohAEUM6f1a4oTxYCl+FOV8RhplCoC/NAxSGbfmZ0y8WYNeGwSQyAFFpBX2rO +zsXIWGbclRLkzklbfTUf70Fi4hjOzHEPrGK8J5bIjiOk/l8dc3vQRu30OmUEYWRY +g9FleBS67roTijEtWN9V63MLMpouBJugTN/xO48VHHYbolkZAgMBAAGjZjBkMA4G +A1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEAMB0GA1UdDgQWBBSyaUMC +TOjmPZn7ZTDXG6Qo8qKi4TAfBgNVHSMEGDAWgBT0pIgmVPfaZgIy7Cx1yxRc7SRb +UzBBBgkqhkiG9w0BAQowNKAPMA0GCWCGSAFlAwQCAQUAoRwwGgYJKoZIhvcNAQEI +MA0GCWCGSAFlAwQCAQUAogMCASADggGBAIXTa3szlsPLh7cv5X/ZgFzSF/6dtRmC +1UlQz8YvQUU5qzI53fvvVCph0+I02YHjA8npospXB0v9FTwiU2TrdYR9Kld/JMzI +zQrkyZC1ePFiZrc7fLzTp8kuo78rUqgJY6TgGxBwH0GfwonJo9/wr0Xxnzex3h+P +3DobGjlcVkEpvgERL0JLvmWdi/Vh3saiaD6a+Lvy6c2dt5YpF6RHK8v0DEpzyfM4 +xDdkYA1z4i2sjZntshxpm21Q7K9I80jCv+MuqUPKcyVmlGlVfaxU0HynjPIY/53Y +cjIA8QVnW8Wf4iLn/Tg4To2YLXZx5R08iuyU/r3d5QWlqs24UZgJvqRxnkFnQE9n +wjXB/MtSG76VRFHNrcVAajNY4hfE4N4ZQN1IVoXA32a2arpMaAuZiLUmH8usJ1B0 +AvHP+tgKGJbxomDINqcA7kgf5SFMN3ocAhvBwvk3+NMMZqvb8KgiViavIe+A3Of0 +UaA1RyJhmgymzVjnDRSvVT5zddkpe6k38A== +-----END CERTIFICATE----- diff --git a/flca/dev_assets/intermediate_ca_key b/flca/dev_assets/intermediate_ca_key new file mode 100644 index 000000000..e8e1ef8ba --- /dev/null +++ b/flca/dev_assets/intermediate_ca_key @@ -0,0 +1,42 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,1417c765eac64d4dcc4606c745aa811e + +8yCvhJYJTWTSu72X7sbunRTfb2VwzyLtIokzF2jIIGnlXGcN5YtNTqHdoOWha94B +FPlyLuzvGlzSjpEtYIcBYNReY0l75KWIWJ0AP4OufHobyzJ5FttTHQPUh72v6iPt +bWWS68Y5tsfJLIlB0BBrkfIHH7eh8kEXN75MZnXAPc4d8RccN6SYoubCKWqNFvYb +9QloWLGVP58sd5ILSf5zE50KYO3ZMoDgjK3W0hTD5ocBtRYzD3m5toyFdA5Sxufc +Ea7JI2NJ8Nch1jLwW0MbtDC77yu1RfVGuJAfLZsQwyXPVT4QxOZQoRLWMeUuZSe9 ++EfNsjrwmNHu2h+z9BXVYq6DX5lpE0FBnzzIsOXCkmSsyDa0/7q3yAD63dGS9NF1 +hMypj840NBoF1X6PIdG0r49HhsfQ9witQh1H77ZM0HzSBN1jo9dmljgLocU+cAv5 +sPt0H/+BQzwOJG8U1SYfW0Sa7airYfY3mdZgqX/Ghzp2uZavmD6UYQgvHFaSX1j3 +n4EJUZpV489FyWRcd+0e/czzZ+QJcRsGUeYOvLMGKXZSKpG68/lMjZ8hHQ3t+DuN +3OqotHK6B5HEayMcQn6Hglscfq72UbiaqkXIwNBugSnwwLtO+ObHEyC1DSldgXpt 
+GQgwh9YHHbJioTy4ZQxoxjCkRnRWk8VSgEs1bH8MxVuVUsiJNrkY82AwwQ+Z5bpx +cC0niMIRuqKG2Ja0IzUnqhkAjYprYn55dIa05iCMtkNCT5DuQbRa8GlGqsH/pzN1 +jQuB5TTigcpnj8b97GXSDBpPej7pZ48vGZ8+CyIo72hd8nMYeZjcNp2rzM87Aa5G +15z36341apE/4yRya5nRMGXObdBTHiZfQzCrVZ2pIYqhGLew3UU98a9LJwR28+KO +cH5V67P9hOhHwj1/O6WKhDJBeyCrnEjDjXfY4/CzKzL72x8iy6ujGG71eu9Ton6J +mCEpXZlamIchBnVM1o1VM7C3PaMJnIwD8JsUt/G3siBKG+Xa9qehe2agkMvOKDVI +zpC7IjpzXjPQ9TpvCxQ+A7D6yNL4p3NjdVW6vlYKa58LESuXhxvr5QsfXUYk70B+ +95J/vpgWyoFadvWIYVAvjcrYRPlgUjJDf1tz/kxqkiWSbDnCMdPG7NeX+RTFY4IW +ifzvrSeTo2k4ceYKFMvWB7NmKAgPISvMDggM1irdrEuoYmOFofCnVRdxHOpXB3w0 +2MazMFXyXP2AEOkTU7FPyO5dRVrtbXIAggCfUpIE8CdseO6YO9k8LAKfehVHWIqX +H+SVUYhH/Ij/B5u7/YLVDtYfup7vzF637nOtI3LODo/XWFBZmFt/mpeTnF1v5UTU +irDi2cma4fcKs8FDIAgq8SYX/jyz9+7tD1yVhP8qXJqESmi04kHBs7D2O+MbiPg5 +gkHP0B8ZVwClieG/ovAJloZFUVcr+PgN+eeUK1VhBoN9GQotJ6oxO0YBaExbnGne +14TJArS9xBhBhmNh79J9QTaoSDro995uS4UWjEtm6vnwSU7X2U44pcY4O+cxlJjK +rHIjmAOa988F7WSGPg3PniFZ7LTU5qpcJ0blebrciYGhgk8YN7x58oncjkK7vIVV +OO4WJbLfkAbPX4UePiIp/3oMNZD+4ndChPcd0T8xlHaKkrB8jtUQk6Cfgeg5IWlZ +AD1FNyvrLDdigef/rKJ1VOjQjG4bPcdkSB7zo1MlWzHGJOQl0+U06m0whdAhhyIb +cC1MbeszAe+v8+9F2VaNEiRzQzEQAcTcdLOQoMB0q6WVE/G7QN88GMVV4KfqrK64 +jyH9tK2u8itmoxE5tSVwCU5Eys12li15/l1bg/LormlfO2r/nZX7L0ifjWWcaRAn +YFEhzVErS1/7671lINKN+tlJS7L8pHvfT9nhtFfZfUNaYZrhz/1RhijEyCKmrSZm +JDYqYl05lSqeLjHVDY5AAtGQzgrQRKDAD87AkbR+6365Pw8xnhzkxKZndeA7Zpxy +21Q4gsUK5EdIM0RYNDyV0c12pY9S5mN5V8DDTGH+IkdghiI1UX0QjmLDjAFrHrQ3 +9VKY+t+jZtk4AkkLxovDZMgwCPulhHvHKcXG6+v+vUJ2/W1B71JU1f+JzPZdBulT +yj1XiBPKo5FkIDShRs+iAjmbOlIoIaWhJB0G4VMFy0mhtIvaiq0GoDWitUT6VaUw +P85KzananB/5MCd9wzxb4quOxD0tbapaaOXfaLwFmM06BH+WFK/TtVitJiuR+W4z +MrFpHlbGQYQOtjhFC94yuC1oGviV1vqOpswMDntGCzB7dofossqUnxh7SUxFtDIP +PWJgcFW2qmBDUu8+n0SmfaGWqNTl4FpsvdrMwpdL5X+OAq80KXcUqWtGvu7PfRJU +-----END RSA PRIVATE KEY----- diff --git a/flca/dev_assets/pwd.txt b/flca/dev_assets/pwd.txt new file mode 100644 index 000000000..fde6f26b8 --- /dev/null +++ b/flca/dev_assets/pwd.txt @@ -0,0 +1 @@ +password! 
\ No newline at end of file diff --git a/flca/dev_assets/reverse_proxy.conf b/flca/dev_assets/reverse_proxy.conf new file mode 100644 index 000000000..a0065bf89 --- /dev/null +++ b/flca/dev_assets/reverse_proxy.conf @@ -0,0 +1,10 @@ +server { + listen 443; + location / { + proxy_pass https://127.0.0.1:8000; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } +} \ No newline at end of file diff --git a/flca/dev_assets/root_ca.crt b/flca/dev_assets/root_ca.crt new file mode 100644 index 000000000..790b5951c --- /dev/null +++ b/flca/dev_assets/root_ca.crt @@ -0,0 +1,26 @@ +-----BEGIN CERTIFICATE----- +MIIEazCCAp+gAwIBAgIQbHXgymGrqb/G85lgg0qzAzBBBgkqhkiG9w0BAQowNKAP +MA0GCWCGSAFlAwQCAQUAoRwwGgYJKoZIhvcNAQEIMA0GCWCGSAFlAwQCAQUAogMC +ASAwGjEYMBYGA1UEAxMPTWVkUGVyZiBSb290IENBMB4XDTI0MDUwOTEyNTYzNFoX +DTQ0MDUwOTEyNTYzM1owGjEYMBYGA1UEAxMPTWVkUGVyZiBSb290IENBMIIBojAN +BgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEAyPGt+k78Us5F61TL9AIt5rWVUUC+ +lGSZiZfDQktz43FfMKH6Kei7R4/9sTZoqKTa3XKyshMX23odrCyvKtpeBNnKQToo +yqDjLWx4wl4rcZatKIKSWou2Uhk77+ONa3T37ckAzgsHzECbRJtd/PsKk12o5PbI +0jgrcVZnXN897H78FDc5TxPQgdCN7v59URGJQ4e1H+dJlyaMF0bbN236PwQrlMA3 +kcnRZ1qsZ6R2eocFII7GCDmsdFZ3X0peg4hUz0Lf5pDYff6hf1R6VyfbqR+xJOdx +BHHO1Ak5UAxv1EYUflOZp11snqn3ZQCChRfDkoANWqdU0LIqNHlDn58YXPiXwnli +gItjyKxlphXKlFUi6juhWvW4nrYf378g3cO7echkn7NaHXudIxQp6xO8iETSUOEf +pGGuV72kapiZA9+GTFwYAK8MmuiHYxelZlqLl4fpUnxvZSrtr8m2n87a6AMRPWKS +yIxar684hQI6TRq9hJTs+HneeZT2CH7k9FqtAgMBAAGjRTBDMA4GA1UdDwEB/wQE +AwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBT0pIgmVPfaZgIy7Cx1 +yxRc7SRbUzBBBgkqhkiG9w0BAQowNKAPMA0GCWCGSAFlAwQCAQUAoRwwGgYJKoZI +hvcNAQEIMA0GCWCGSAFlAwQCAQUAogMCASADggGBAK4poY8+QbdDkahXHFWc4CeL +dffenE0FKUCDvR4+vQLoFjUKVcpdnkIrkouWSys2VmZusgvUofEKMk2vKVB/kzDF +Zg0GkIeBAZr0n10LImvwed1CzmA99K0J42W3Z8ksr0Qsr+1qmx9SyVv1xsMmIw1g +0CkpjkkpvNb9Z+nKbn5q3WBEcThM8llsi9krOaj9wzDVOGx1D7TcrB7IOTmCvkrE +TthFqz5EuUOHW0KVyQsDBd/ktxx5zkkHeAo9RrBifG4drmKF4ZnW9JcSdD74H7+2 +K1S9U3BkCdwSDfRpuuE4v1RP2ClrfdEF4//PjCnlw+7vEkPrOugEJZO5/IZzRApb +By5R8fgK6ChDUAWmx67CUb5PGG9ugGumVKpnk4Hwa+hixn5f19MeXTLEXW2Wmkp4 +SD1vsz9zN9HTk4dMpPJCO0oIWvbzEZcpCJt4kiBshzZK5GgP8PxGEqjxC6Vr/sUv +xgP8tlytyAGjZMcLPsXrMZMtlhxCEt7y9mMZ7NGLCA== +-----END CERTIFICATE----- diff --git a/flca/dev_assets/server.tpl b/flca/dev_assets/server.tpl new file mode 100644 index 000000000..78cc4d606 --- /dev/null +++ b/flca/dev_assets/server.tpl @@ -0,0 +1,10 @@ +{ + "subject": {{ toJson .Subject }}, + "sans": {{ toJson .SANs }}, +{{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + "keyUsage": ["dataEncipherment", "digitalSignature"], +{{- else }} + {{ fail "Key type must be RSA. 
Try again with --kty=RSA" }} +{{- end }} + "extKeyUsage": ["serverAuth", "clientAuth"] +} \ No newline at end of file diff --git a/flca/dev_assets/settings.json b/flca/dev_assets/settings.json new file mode 100644 index 000000000..e2aa5202f --- /dev/null +++ b/flca/dev_assets/settings.json @@ -0,0 +1,11 @@ +{ +"ca_config": "/assets/ca.json", +"intermediate_ca_key": "/assets/intermediate_ca_key", +"intermediate_ca_password": "/assets/pwd.txt", +"root_ca_crt": "/assets/root_ca.crt", +"intermediate_ca_crt": "/assets/intermediate_ca.crt", +"client_x509_template": "/assets/client.tpl", +"server_x509_template": "/assets/server.tpl", +"proxy_config": "/assets/reverse_proxy.conf", +"db_config": "/assets/db_config.json" +} \ No newline at end of file diff --git a/flca/dev_utils.py b/flca/dev_utils.py new file mode 100644 index 000000000..05378d0dd --- /dev/null +++ b/flca/dev_utils.py @@ -0,0 +1,23 @@ +import json +import os + + +def safe_store(content: str, path: str): + with open(path, "w") as f: + pass + os.chmod(path, 0o600) + with open(path, "a") as f: + f.write(content) + + +def get_all_secrets(): + # load settings + with open("/assets/settings.json") as f: + settings = json.load(f) + + # get secrets + secrets = {} + for key in settings.keys(): + with open(settings[key]) as f: + secrets[key] = f.read() + return secrets diff --git a/flca/entrypoint.sh b/flca/entrypoint.sh new file mode 100644 index 000000000..35325b77c --- /dev/null +++ b/flca/entrypoint.sh @@ -0,0 +1,13 @@ +export STEPPATH=$(step path) +python /setup.py +step-ca --password-file=$STEPPATH/secrets/pwd.txt $STEPPATH/config/ca.json & + +if [[ -n "$USE_PROXY" ]]; then + STATUS="1" + while [ "$STATUS" -ne "0" ]; do + sleep 1 + step ca health --ca-url 127.0.0.1:8000 + STATUS="$?" + done + nginx -g "daemon off;" +fi diff --git a/flca/manual_setup/README.md b/flca/manual_setup/README.md new file mode 100644 index 000000000..d4a37c4e5 --- /dev/null +++ b/flca/manual_setup/README.md @@ -0,0 +1 @@ +This contains code used to generate the root and intermediate ca certs and keys. You need to install `step` to use it (look at the dockerfiles). 
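As a point of reference (not part of this diff), here is a minimal sketch of how the output of `create_keys.sh` below could be sanity-checked and staged into `dev_assets/` so that the paths referenced by `dev_assets/settings.json` resolve. The `step certificate inspect`/`verify` invocations and the copy-into-`dev_assets` workflow are assumptions for illustration; the destination filenames follow the dev settings file.

```sh
# Assumed workflow, run from flca/manual_setup after executing create_keys.sh
step certificate inspect ./root_ca.crt --short                        # check subject and validity of the root
step certificate verify ./intermediate_ca.crt --roots ./root_ca.crt   # intermediate must chain back to the root

# Stage the dev CA material where dev_assets/settings.json expects it
cp ./root_ca.crt ../dev_assets/root_ca.crt
cp ./intermediate_ca.crt ../dev_assets/intermediate_ca.crt
cp ./intermediate_ca.key ../dev_assets/intermediate_ca_key
# If the key was encrypted at creation time, keep its password in ../dev_assets/pwd.txt
```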
diff --git a/flca/manual_setup/create_keys.sh b/flca/manual_setup/create_keys.sh new file mode 100644 index 000000000..5dad840bd --- /dev/null +++ b/flca/manual_setup/create_keys.sh @@ -0,0 +1,17 @@ +step certificate create "MedPerf Root CA" \ + ./root_ca.crt \ + ./root_ca.key \ + --template ./rsa_root_ca.tpl \ + --kty RSA \ + --not-after 175320h \ + --size 3072 + +step certificate create "MedPerf Intermediate CA" \ + ./intermediate_ca.crt \ + ./intermediate_ca.key \ + --ca ./root_ca.crt \ + --ca-key ./root_ca.key \ + --template ./rsa_intermediate_ca.tpl \ + --kty RSA \ + --not-after 87660h \ + --size 3072 diff --git a/flca/manual_setup/rsa_intermediate_ca.tpl b/flca/manual_setup/rsa_intermediate_ca.tpl new file mode 100644 index 000000000..3f5606989 --- /dev/null +++ b/flca/manual_setup/rsa_intermediate_ca.tpl @@ -0,0 +1,12 @@ +{ + "subject": {{ toJson .Subject }}, + "issuer": {{ toJson .Subject }}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": { + "isCA": true, + "maxPathLen": 0 + } + {{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + , "signatureAlgorithm": "SHA256-RSAPSS" + {{- end }} +} \ No newline at end of file diff --git a/flca/manual_setup/rsa_root_ca.tpl b/flca/manual_setup/rsa_root_ca.tpl new file mode 100644 index 000000000..150812984 --- /dev/null +++ b/flca/manual_setup/rsa_root_ca.tpl @@ -0,0 +1,12 @@ +{ + "subject": {{ toJson .Subject }}, + "issuer": {{ toJson .Subject }}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": { + "isCA": true, + "maxPathLen": 1 + } + {{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + , "signatureAlgorithm": "SHA256-RSAPSS" + {{- end }} +} \ No newline at end of file diff --git a/flca/setup.py b/flca/setup.py new file mode 100644 index 000000000..06b9acd3f --- /dev/null +++ b/flca/setup.py @@ -0,0 +1,96 @@ +import os +import json +from utils import get_all_secrets, safe_store + + +def validate(secrets: dict): + """main settings are expected to be a json file that contains secrets IDs of other objects""" + expected_keys = set( + [ + "ca_config", + "intermediate_ca_key", + "intermediate_ca_password", + "root_ca_crt", + "intermediate_ca_crt", + "client_x509_template", + "server_x509_template", + "db_config", + ] + ) + if os.environ.get("USE_PROXY", None): + expected_keys.add("proxy_config") + + if expected_keys != set(secrets.keys()): + msg = "Expected keys: " + ", ".join(expected_keys) + msg += "\nFound keys: " + ", ".join(set(secrets.keys())) + raise ValueError(msg) + + +def setup(): + step_path = os.environ.get("STEPPATH", None) + if step_path is None: + raise Exception("STEPPATH var is not set") + + secrets = get_all_secrets() + validate(secrets) + + # Create folders + secrets_folder = os.path.join(step_path, "secrets") + certs_folder = os.path.join(step_path, "certs") + config_folder = os.path.join(step_path, "config") + templates_folder = os.path.join(step_path, "templates", "certs", "x509") + os.makedirs(secrets_folder, mode=0o600) + os.makedirs(certs_folder, mode=0o600) + os.makedirs(config_folder, mode=0o600) + os.makedirs(templates_folder, mode=0o600) + + # store key and its password + intermediate_ca_key_path = os.path.join(secrets_folder, "intermediate_ca_key") + safe_store(secrets["intermediate_ca_key"], intermediate_ca_key_path) + + password_path = os.path.join(secrets_folder, "pwd.txt") + safe_store(secrets["intermediate_ca_password"], password_path) + + # store root and intermediate certs + root_ca_crt_path = os.path.join(certs_folder, "root_ca.crt") + safe_store(secrets["root_ca_crt"], 
root_ca_crt_path) + + intermediate_ca_crt_path = os.path.join(certs_folder, "intermediate_ca.crt") + safe_store(secrets["intermediate_ca_crt"], intermediate_ca_crt_path) + + # store signing templates + client_tpl_path = os.path.join(templates_folder, "client.tpl") + safe_store(secrets["client_x509_template"], client_tpl_path) + + server_tpl_path = os.path.join(templates_folder, "server.tpl") + safe_store(secrets["server_x509_template"], server_tpl_path) + + # Get config + config = json.loads(secrets["ca_config"]) + + # Override config with runtime paths + config["root"] = root_ca_crt_path + config["crt"] = intermediate_ca_crt_path + config["key"] = intermediate_ca_key_path + # assuming server provisioner is the first one, and client is second + config["authority"]["provisioners"][0]["options"]["x509"][ + "templateFile" + ] = server_tpl_path + config["authority"]["provisioners"][1]["options"]["x509"][ + "templateFile" + ] = client_tpl_path + + # Override db config + config["db"] = json.loads(secrets["db_config"]) + + # write config + ca_config_path = os.path.join(config_folder, "ca.json") + safe_store(json.dumps(config), ca_config_path) + + # setup proxy + if os.environ.get("USE_PROXY", None): + safe_store(secrets["proxy_config"], "/etc/nginx/http.d/reverse-proxy.conf") + + +if __name__ == "__main__": + setup() diff --git a/flca/utils.py b/flca/utils.py new file mode 100644 index 000000000..f316c8baf --- /dev/null +++ b/flca/utils.py @@ -0,0 +1,49 @@ +import google.auth +from google.cloud import secretmanager +import json +import os + + +def get_secret(secret_name: str): + """Code copied and modified from medperf/server/medperf/settings.py""" + + try: + _, os.environ["GOOGLE_CLOUD_PROJECT"] = google.auth.default() + except google.auth.exceptions.DefaultCredentialsError: + raise Exception( + "No local .env or GOOGLE_CLOUD_PROJECT detected. No secrets found." 
+ ) + + # Pull secrets from Secret Manager + print("Loading env from GCP secrets manager") + project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") + + client = secretmanager.SecretManagerServiceClient() + settings_version = os.environ.get("SETTINGS_SECRETS_VERSION", "latest") + name = f"projects/{project_id}/secrets/{secret_name}/versions/{settings_version}" + payload = client.access_secret_version(name=name).payload.data.decode("UTF-8") + return payload + + +def safe_store(content: str, path: str): + with open(path, "w") as f: + pass + os.chmod(path, 0o600) + with open(path, "a") as f: + f.write(content) + + +def get_all_secrets(): + settings_name = os.environ.get("SETTINGS_SECRETS_NAME", None) + if settings_name is None: + raise Exception("SETTINGS_SECRETS_NAME var is not set") + + # load settings + settings = get_secret(settings_name) + settings_dict: dict = json.loads(settings) + + # get secrets + secrets = {} + for key in settings_dict.keys(): + secrets[key] = get_secret(settings_dict[key]) + return secrets diff --git a/mock_tokens/generate_tokens.py b/mock_tokens/generate_tokens.py index b202c9eeb..c68e74e2c 100644 --- a/mock_tokens/generate_tokens.py +++ b/mock_tokens/generate_tokens.py @@ -23,7 +23,7 @@ def token_payload(user): } -users = ["testadmin", "testbo", "testmo", "testdo"] +users = ["testadmin", "testbo", "testmo", "testdo", "testdo2", "testao", "testfladmin"] tokens = {} # Use headers when verifying tokens using json web keys diff --git a/mock_tokens/tokens.json b/mock_tokens/tokens.json index 6c6bcfc12..022681ff2 100644 --- a/mock_tokens/tokens.json +++ b/mock_tokens/tokens.json @@ -1,6 +1,9 @@ { - "testadmin@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.iZwlrNHjT90aZt_puWQnNke-7IrtQQ5FXsxpGfrYQjRGGXG4mgAvqI-9o-D4MWw8zdO0pDNddbQI44aoXDa_oOUpo23qhqjo-AahIKKUGu4W166cV6G8lseza7xr7WtZqEn_WA2qJR-IcqZvu80Lt6nURR-7tl80cLK4NdD5TmOvTOZdn4psgQg1uWrfWCLcQvjvfEtGPxHij1zu2usv5FuyDytp49xjFbH90bnepkIV0Jr_BfUZEm75sRf1wfj8c-t3IhqdWySfR0gSC4UW9ieaG_h7_kxRI_J3qfUwBklbtCMkOnApA4FaRUnv48fRBWCGxtU_1AVHbUwwPldMfUd8cDf_76Ipi31nIX5PVw7g7O00L23-CyjGf23U2j4Srz1xBHG_u3HAoT7XXOPpjaLGI0y021e9x7i1GWHMzqzcGcNUlJj8GMfocTJOLR1y4UNYvvWFuhaeqOHpVclcJ22Mo9JjsFLfy5D5TPetk2vBD0bExCgAOmAmhdnSEY96OxItjWYfSlZuBen29JD9NaUCwK4knQm2NODnKeTIS034EQsWGqXT-84VUdR_pJr_-seNzrmLxD9pLNfW3XARgscE-7Rfg29cdADj0RU_KNblkyNjwJzn1XxUI54IRlj3oYTQ8R4VzFag2NbJgqDpg6EQYg8Ii2X1v-8QN3tvuQQ", - "testbo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.BPDuCCa5akAzJmM7qSuUB7A0fmSkGGOKl4bCdRkyeYBCYuPofLHXCeEuILw9iZxKIKGpxiOal-6JLWlo6cteW4-ebcby3u7zV2rc0pWWvsezsFuYQt8FojAi6Knv5R0BbgSLtVxc_BGR7vK0apVEc9VU4ootfDitdHDHFGo2QO90RIIIg1toWIK5NkK39WQLvvUnEXhtvrhejeqFPpJj7SgrggytoW1ZZqGmtFDKhJii1cKdW5anNOsUdD5L386lgh5K5n_nxim63MrZI8wdJvLW16_NcvVRYOrgfEP7jp1kyb5Vmv_NQaS9CsnSMewv-JA4lP9LC3bs1YixOEcHYP6T9z1g6hhV96RkpCjIZuo5QbBYKVsONscePZDLTdlj2NrfMuyjt7NbrJWJpmaOCmqKvnQDIo1gdDd940_kBgcLgrjtTn6LfndXsAWM6U6_x78uKL73XoJwQwPXVwF2_hyo6vEHufx9rfo7WyPD5vTHb_1FrZQE-FZ_0DmOpe4tVaDy9nkfcZzu_jlYMnFMTJSn8te3hs9JWyfDGTwLY9r0WW5htRAS8dHSOGRzXpMCD6gHvcbYEpcEewGGJiaWombVUOCJg-x1Ax8zXx2-k4vFol0_7jYPw0EsduqTtjXChdw8NGDVUKLjjLlAZ_oUAYRCaYOdBpBNFFKKIfJ6oGw", - "testmo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.ZDTNB9UAMnTA67QQXBBqPBjG96QJbFfPweL0AEioSMu-VzZTQlSdmgmXR36Li8opkBwxaVpWKCyP95eCgEAZ6Fdr_kE259kHJxHTD8UrejN8vj0ramqmXfd2xIxEN7q8YGzt4USGXFVII_iCehPandtuXfsqvQrcuu-dlfbsb85Azphi_6SWM3-2U5Vit7bDOeP218XHBy78uhW9mLZt6d838Hk97U9tr2vlhrO7bCOhy5SHSP4svhkF-wIglxarIxPu7RUsYshJp21aY976tJ39_RDS9ResIytYZGrAUacKGU5OJihyaaR_WNoppXGPxJsfbGyqylj_XC9_jgsnZCyfztahzOeWyCjjQeosSiB6dc7cmzlgeDcBtXAvkEtYg6B8SFR-c-NM6Wy0Wd4L4UQW175ySCBRwfuGUlKTszwG7LyRcxolui7ESc8E1PynczfOXxdXut1mxtrXrfKX90jgIV_wVR9LGYF0IVteJS-kudpabyDG65-LAli15ZwYbNYDaFLSit6j4W_sVeN9zZPA6cm37hfzI6pqf1J42R6_hmL1lmdv2Aq_6kumDShxVDugpjpnCt6vaZFphqJR5F_MIW587brzrbVVVNuaD1T1Gf0WJgX9-FaFsvRwXMhEdTfyXDPcnIdYH83T-tNNFyZCJek64-e--WSnSrRrQYs", - "testdo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.A-6hIS7Ua0nXr89Percg0YVjYmWFcpQSK66URRAjByTuLnvfOokfEyTULpdjB9uNOSEwIOoH55Y3o2lNG5_M9PVshScnPXR8YR0Ow7elU0GIfdcv3YDYu-YzH0iTbrVEc0X91J_SPp2vN5dSO0UVLHE9fQxxL1OzYeu2w9yHTu5OdEpJC7yPLPUu3dbKhknhwO4OGt_9hBD4b8X85f8js89d9YA9gl7XJydwNjdnNfv8ZUXwDP9b77h0RBeGGNkyYNNkRAex35g5D4Xd9evPsf7obHZkRTjiF-RMToCeJNnjv0BzHlsiz66qAONjW6LCcAwW7-ACS282YYP1yRhr4NXQinfBBo62auYa25sRTHv6uRV7IMrw01OovbVb4lbuKWTZQ0TUmW_UlQujv_EpzoXHcC__ZvMVRG9einjLdj2MfUHuvmxLeM2OYK75VzjP63YKP_67O9hAbAzyn1Z6-CQAgom8coGUV8Mdwz-L-VmvpYMdmCFohYRTqspluj-O-hrMNKexycdbbuflodkRhn8Pi35fUqb2UP5eQvhXusX7ob-YB7PHaka6xY8Wsndb_blEoTkSUmbxeSWt8QKz_n6AWBO6LCgOCG1LyAv2_DaL2M00sJ9oLCGDYnkDRoc_Aq45GEHxIS8cEQL2siFQbrH-KPT8PlKshNjMLtyNcO8" + "testadmin@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.A1xirQHtKErew3u9J1hVjMxiTeXbJ96ymQqnRE4xOsX8aBzaVPF7T7u6Vf0icVJT-EZ755ladQ4oaotd4iRMUMWyyOwLVSKUvgFWFLGW4GAalNV3YNVL0KjgHg_i10GXk9M9ruEgSuZlVKtY-R7oLcPpzzLwhyO5MQ0VFDrQ4ClXzw_Q6Cln4TP-oJeuIhLqRiTaqba4Yu8Vf3a6KkysJG2Ldo322669H8FiQv9OgWeOQOvvZQV44TJ9OMVQnSYbNgM7NVf5wSkSpho5rvOG2q6MQ7vWbUx7TnPyPUySE0f9M0ql4ycXbPTUIpyf4X44ynqSehtaodW6g-0cwjWJZMX7iHD4mpPaOS5_Z0nf7ARsMv903n_Ybi4GNiJqUfXazACD086Hxh7LMjxPhdLAaL76DkNH2WNp9Kb21-0VTWETcb6-bx2na0nwcVPBYNyhfJGs6gZv9pXQSm-8v345-6FBHW1xA-X00qJnVqOMo_MPDza3lju3HNN1JbMBiRY02DNofMvxUN2AlrcdmZHoaDoOrMhM4IZadnNijenH7UKIn_KCQub7Ji5HYgynpbbr63Rvjvjp2RN_qIEVyoFj71qF56J8Ccio2FIdLjigWJChfBVcEcvt9wr1LgFi3fV2KMOUVTWoRg6kTAtep_iwigKxrQWnkeUuqCy6Dju_ftU", + "testbo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.BhAlgy16Pz3RYvJcTKTtpwyjoPWDTGK4ExmEJIud8kz7cTFHLf-XBcHRwaAY0RaGJv7MlKnpuBF_dYXAw5wdVQPN7MKKPJRyH4LpKHrvaV3kGGmoKhfUoCrvBSwNBi8a-Y9ywnZhYzG0G4aFx2pLqmJBbjdCmpeeqXIvFlHP7xniV9HnFdT1J5iP42W1JHYZQa4cofDOyo211YtmkzjzfBa8Y7_cSMzA_1_EF-tEZCBomoS_D_ghrNsTOOzSONs1OKAHuPmhqNsS1vuQEO5vAYq9GkAac8gb7aKJ91tGWJYMFwkiDtrNErPAUkPMuuBKhsM6mHkhAc85cHgop8FcH69XkBE-a4GvF9cGv4BQh5mf3-XaRvXf49ZLozRoN6WMlVTDYcO1S8lbM9qYou4k2BaK3vztKZLGbpDhbIr-LlYjFxBYXV8aGzW9aY1cfQ-hbASOS11woYz0I9vi33ODkvAdll6K9xntNJE-9V_hHA36tdGKnpM6JqHWtUeXXyGBSnXmDaW51tEUpOXkLESZ553n-fmZoyX8auyDB_s0dlhMS4IDGgplbiLtnB2sSmrr0hK3zS-WJ8Ht5n6XVRsMvaSXsToArfAOKD6RK9KsJ_BC0ibdv1KHKotzVO_Eq4hUWlnebDCxmw9M_iFOZXCtv0zZ8_T655zrAFRqvVNfczI", + "testmo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.pq8dKUIH9Yw03m8DUM17YLn6Kkj3O53B3zZEUlgrMAON0jG59erfb0-vzH6r-M3QMtE1N9MYJO6idUpJYVNcPwK0CY1MFinaA4QqrzHjdbDT0sXu9LuME9hXtzoiyXmi7q3yPESUM_4Y67YBoCbe44V4wUHSP73mQFdIPud0NqhM1vON9zuBa4pTu5GT3JS5DlN6WIZwj-xQtwrYZDCCweWvQbfkct1KxAq0VdyHTM5E4EcgkFtuTk6MGAW7rBEqg72jzk_H3gmP4z5owUl_NpCOM_fWQQSH5iLGn_QKuEbtt8Pi8MB5KBCv7kId2gitIiozzyqgVGAGYMPWAroFmQCc4Pd5pBZRp7fyhE8HFvkurED4RfB0pOhKFc-l9NjV4ESYMeTjBTgdkcCTu9GFKCAEcKONjtVIMYe7Hl0UhOEbNqnDwpOR1qYxYGZTX346Z8-QmG-id9sI5FhJHxOi0p0gv6oHI1UKEZTHmNPWzGLvGbZSJWdJIdESeuhicRwcqeBOBG5tvbqIW71q1tcvAyNco9--bQYY3lLfM57uOWd60e4gLdT-GSaIFEI6jfgS9AWgP9wfujIWyqpKUvjzkZVW-mn6OFtJ53UwRYdhLFs3xwJ2MZdqOtTtidmGks4UxjUDxhcs84TwqgyuBu-gtrAmlinKKuEtKfCeQSrOtrM", + "testdo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.h_MiBOCO1HpRa-t5899x4MeBaRHT56d3o-aLf7hiVjMGLaM_Qvr4HwYOTiCMdyanK-XQjZF2qpGamFXwBxcDUCqelMaEfGZkt29aTf6NFeXlsq62GF8Nukm8QU1eDTeKAuE1q8bmT7Kfz57njo-rVarZPIr4Lisy8QJmn8-zRC9pW9OQH6yrIXMgxcivPt6JiwavReevDCx4ZLyoy1ULaRLuXf_-MybcSd04Mj3zANgg_3P1LwlY1-m3kmAdx4YI20D0EuLwzu0S8F5in3rpPuhC8J1UcHG_-IMG7Y5J4g0rojZ3UwJKOBfYYzSRhP1cZqUZ7ZDUsq2rqB_03VtEflFVKtn_-xcLTy0zIbLUPQKfTIijJwtKPfh95o7vrVF4zj7tEOmJe-v0K-DN_I4xIJT0ajEoTs81tjEHHcfKNSq6mIAlROSM9lrDNLWKW_rn_0douB6kkKVUry_gFLTeGtTvsJuZWQyjw4B2WVwlhgjh3ECyRPvvhJwrZxLytBx5CzMEpC4bSl8MWvRJOh2oSOGP2xn4FtgjfAJVi0CVV_xZhTGzfhiQ8-PBMAGzgfoEkoQWPsPq-YIS_lHQw3ZTFmiehd5EtmxApHW5kvCv3bqtqCuYlWcWHAdpYmDnpETiYwCJZUTnoCsSsRGbE4TFGPhvH5J6-XFiEBbx4GKcJqU", + "testdo2@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvMkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGRvMiIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTcyNTMxNzYxNSwiZXhwIjoxMTcyNTMxNzYxNX0.q_RIdhezmevU-HxjNlNdexL43UISEHz5lyEXckpxyoK9oot8tNCHjNzWVxczJKQDLSTDYseszhouVqNvNPlOM_NRZq-bhXuAFWdpdL1ORqXCYn5RCjC4bLMdgtU_0kB0DVzWhUGYE5HD1aKaJPXkzJEyUsGafXUI5RCih5FeSzIhQjbkiVnI0oKrKKA9bbHJwa98cpw8OBsRHfnnVQPA_ZAaai12iq115HyhHUFPPgSPOKhiqu2bEMSPfYiOj31UB3UsDL2ZdU2Pxzb9UMSmuGmZUO1PkjHfozx5OTDHvcNQkQyDa9PbwprmF4SWCq67ma_OklepyUqi7dRpQXYJRl3cN1JzAlSh-eTmkhCr5SpIsg8_fMyZJ6hqwKDIpfGvUftovmwrumO5AKXoPHi7sdQnOEI31vloUx0ni0wgje7-3SBu3DcCndWZ-nyzEhex7vpEWXEkz7T8MDYxqbBh7ksLukypW8t0NDxZVXywpRaTEqn4G344UEM-L0StEaWA3G8Ed72pexWyAwsX3nF-ZvTnir6v1VvRk6H_v6dX028KgdUVlw87N3NLQ-CwWxcPvIbld2mv9djhfZ4-0cmywLs8bAy7ponsy3LsVY-gkmJUBMVV6865QohMZKKm4Ws_M__v3MQDgQzrQPCxEame7g7hwiblQar6pJYYaS_Doi4", + "testao@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.FcKfgh6hm_sOFRE_CT8YNyoH11WHx7ANLLUA4gHi731w-Di00cU3BrOmJvOKZXqJRx1cC9HSciUpn2sxEzdL0g47RrurKRXB6Ex6jS9D267eW8_myc64WgF40yFDwK9RvRVtjFtEAuGxBz9r8sqAEDvrfkXTIya3AWm9CwBhD9I2IueSJF0uoYJ58qp0s1C8caQpbuxDh9QLQmTSqyr1lLM6nf1B3SFRm18XWCkv3GU0wH4DG94crMgUBQF-QyUN7jIsAKbI_7OaqwgFVzHY4SJTx6B_8EF8PHHE_11IfoXTI_jcl_TV0SS_e00A6RT6E7YbBNrTbhgpO02qQbFS6QpaG4TCo8hzX85Rv6DqqRxGT0z3mQzAYmZMPKf5YRIgHoLb_4rM5zYOCm3sbE6NIvKEACmHkaXjDT4TbBLNWcNAKAvL-CZpKkvmH6hdszqYCGzZbw8kSEQTMLuuOe1F0IwLVNhUWl5psAhDHSHCJ_7rI6afPZEl9Q0Aex0q9_HAX4Pb7Dc2i4jIxgIM3Ojsm2ODlFaqV9hxeaLtZrRBWE65VZhvUMk1CWgLByX-uVxPZbZAWNdTntjjoRGmXc9XT6wd9DIMdiORta_TqK0fNskgLLmM7v2H8rgcaMfbzVasG8UgKDESsk7BPhKYBavA75UbEv0zMg8Qq79eRyIGoBY", + "testfladmin@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGZsYWRtaW5AZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6InRlc3RmbGFkbWluIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNzI1MzE3NjE1LCJleHAiOjExNzI1MzE3NjE1fQ.rFI0WUYJxtmOBIzxhlK5AUYArPl9yOcE2BmNwzAMApav5-NiGF1_L5WbatZbkbqTKDBVSvI8TrCEG191Cw8SCw-mKigRd4_C7K4HG70DDVZzStLbQUI5irChy4_a4HmA_SipUnR84jeeGNkRJCTHkeQ9WxOylKttX9ZTxbOHsCm2urMQyllEaDEe6V8M1J3JOuFtmVZRL05LCy9jJRPvTrz35o7j1mbdbjPFWe3R3SV5oXBnqtMkFjqOH93PaUgtAHGZ5TOD3sdeBxRyRNHMP7xf7LFZgih-6ai12O0Iq0wn3B5Q54-YEP5ExdmzjeCFtblQ9VzgRxG7isHxWiRytJr-vf9ScpWm9VLOhI71pCOFDg0pDLWt9L525hShv_wXJ0LwjWzU7z6gTUy2HYLGGWgh7XZTn_EqhLb7rx4DD3hLJ0KyeJ0w6UIK2Wjwr85HNQAt00HuaL2zTjyO5rF9GHdWW8SXYMPLFM3egwPJJ72dCEIWH8Hs5JjRftRREC9nGWQPWoebzb73RDtivvY3C8vjk34WjWuaaoKzeyY6PXzSRNMaUk3BVa6lxgHpri9ytQpm1LmTT-ksnndpCl5VPC0LoynJs0qdpSL7JaO73MWsgu1gt81W53leUfn-8EyhJT3x2i74HtGHyyoIJ8nzqRHrIlnwCgD5hlqlSpNFbw0" } \ No newline at end of file diff --git a/scripts/status_analysis.ipynb b/scripts/status_analysis.ipynb new file mode 100644 index 000000000..2206a3c15 --- /dev/null +++ b/scripts/status_analysis.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os\n", + "import yaml\n", + "\n", + "\n", + "def create_metrics_dataframe(per_round_status):\n", + " # concatenate all the 'metrics' lists\n", + " lists = [r['metrics'] for r in per_round_status]\n", + " metrics = [i for l in lists for i in l]\n", + " return pd.DataFrame(metrics)\n", + "\n", + "def load_status_yaml(training_id=1, server='api_medperf_org'):\n", + " with open(os.path.join(os.path.expanduser('~'), '.medperf/training', server, f\"{training_id}\", 'status.yaml'), 'r') as f:\n", + " return yaml.safe_load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "df = create_metrics_dataframe(load_status_yaml())" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
metric_namemetric_originmetric_valueroundtask_name
0val_evalkcma@mdanderson.org0.9021860aggregated_model_validation
1val_eval_C1kcma@mdanderson.org0.8436170aggregated_model_validation
2val_eval_C2kcma@mdanderson.org0.8897110aggregated_model_validation
3val_eval_C3kcma@mdanderson.org0.9184370aggregated_model_validation
4val_eval_C4kcma@mdanderson.org0.9569770aggregated_model_validation
\n", + "
" + ], + "text/plain": [ + " metric_name metric_origin metric_value round \\\n", + "0 val_eval kcma@mdanderson.org 0.902186 0 \n", + "1 val_eval_C1 kcma@mdanderson.org 0.843617 0 \n", + "2 val_eval_C2 kcma@mdanderson.org 0.889711 0 \n", + "3 val_eval_C3 kcma@mdanderson.org 0.918437 0 \n", + "4 val_eval_C4 kcma@mdanderson.org 0.956977 0 \n", + "\n", + " task_name \n", + "0 aggregated_model_validation \n", + "1 aggregated_model_validation \n", + "2 aggregated_model_validation \n", + "3 aggregated_model_validation \n", + "4 aggregated_model_validation " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flpost", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/server/.env.local.local-auth b/server/.env.local.local-auth index fa5ea9c33..db4449d2d 100644 --- a/server/.env.local.local-auth +++ b/server/.env.local.local-auth @@ -20,4 +20,12 @@ GS_BUCKET_NAME= AUTH_AUDIENCE=https://localhost-localdev/ AUTH_ISSUER=https://localhost:8000/ AUTH_JWK_URL= -AUTH_VERIFYING_KEY="-----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAtKO1SzU6N/sZTJmYNk0C\n/5XbK8eWfcKX2HxFl7fr0V++wrXXGsMs9A8hQEbVWtgYbWaOSkXN0ojmcUt1NFcb\nSPYLmOK/oUXVASEbuZAdIi+ByQ1EnIIAmYSKjRBDUQM8wc73Z9AvrjnhrvEHyrIN\nKyXeLnaCKj/r0s5sQA85SngnCWQbZsRQyHysfsQLwguG0SKFF9EfdNJiaoD8lLBo\nqvUQIYi8MXuVAB7O5EomJoZJe7KEeemsLhCnjTlKHcumjnAiRy5Y0rL6aFXgQkg0\nY4NWxMbsIWAplzh2qCs2jEd88mAUJnHkMzeOKhb1Q+tcmg6ZG6GmwT9fujsOjYrn\na/RTx83B1rRVRHHBFsEP4/ctVf2VdARz+RO+mIh5yZsPiqmRSKpHfbKgnkBpQlAj\nwVrzP9HYT11EXGFesLKRt6Oin0I5FkJ1Ji4w680XjeyZ4KInMY87OvQtltIyrZI9\nR9uY9EnpISGYch6kxbVw0GzdQdP/0mUnYlIeWwyvsXsWB/b3pZ9BiQuCMtlxoWlk\naRjWk9dWIZKFL2uhgeNeY5Wh3Qx9EFx8hnz9ohdaNBPB5BNO2qI61NedFrjYN9LF\nSfcGL7iATU1JQS4rDisnyjDikkTHL9B1u6sMrTsoaqi9Dl5b0gC8RnPVnJItasMN\n9HcW8Pfo2Ava4ler7oU47jUCAwEAAQ==\n-----END PUBLIC KEY-----" \ No newline at end of file +AUTH_VERIFYING_KEY="-----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAtKO1SzU6N/sZTJmYNk0C\n/5XbK8eWfcKX2HxFl7fr0V++wrXXGsMs9A8hQEbVWtgYbWaOSkXN0ojmcUt1NFcb\nSPYLmOK/oUXVASEbuZAdIi+ByQ1EnIIAmYSKjRBDUQM8wc73Z9AvrjnhrvEHyrIN\nKyXeLnaCKj/r0s5sQA85SngnCWQbZsRQyHysfsQLwguG0SKFF9EfdNJiaoD8lLBo\nqvUQIYi8MXuVAB7O5EomJoZJe7KEeemsLhCnjTlKHcumjnAiRy5Y0rL6aFXgQkg0\nY4NWxMbsIWAplzh2qCs2jEd88mAUJnHkMzeOKhb1Q+tcmg6ZG6GmwT9fujsOjYrn\na/RTx83B1rRVRHHBFsEP4/ctVf2VdARz+RO+mIh5yZsPiqmRSKpHfbKgnkBpQlAj\nwVrzP9HYT11EXGFesLKRt6Oin0I5FkJ1Ji4w680XjeyZ4KInMY87OvQtltIyrZI9\nR9uY9EnpISGYch6kxbVw0GzdQdP/0mUnYlIeWwyvsXsWB/b3pZ9BiQuCMtlxoWlk\naRjWk9dWIZKFL2uhgeNeY5Wh3Qx9EFx8hnz9ohdaNBPB5BNO2qI61NedFrjYN9LF\nSfcGL7iATU1JQS4rDisnyjDikkTHL9B1u6sMrTsoaqi9Dl5b0gC8RnPVnJItasMN\n9HcW8Pfo2Ava4ler7oU47jUCAwEAAQ==\n-----END PUBLIC KEY-----" + +#CA configuration +CA_NAME="MedPerf CA" +CA_CONFIG={"address":"https://127.0.0.1","port":443,"fingerprint":"fingerprint","client_provisioner":"auth0","server_provisioner":"acme"} +CA_MLCUBE_NAME="MedPerf CA" 
+CA_MLCUBE_URL="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/mock_cert/mlcube/mlcube.yaml" +CA_MLCUBE_HASH="d3d723fa6e14ea5f3ff1b215c4543295271bebf301d113c4953c5d54310b7dd1" +CA_MLCUBE_IMAGE_HASH="48a16a6b1b42aed79741abf5a799b309feac7f2b4ccb7a8ac89a0fccfc6dd691" \ No newline at end of file diff --git a/server/aggregator/__init__.py b/server/aggregator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/aggregator/admin.py b/server/aggregator/admin.py new file mode 100644 index 000000000..ef24fb8a0 --- /dev/null +++ b/server/aggregator/admin.py @@ -0,0 +1,7 @@ +from django.contrib import admin +from .models import Aggregator + + +@admin.register(Aggregator) +class AggregatorAdmin(admin.ModelAdmin): + list_display = [field.name for field in Aggregator._meta.fields] diff --git a/server/aggregator/apps.py b/server/aggregator/apps.py new file mode 100644 index 000000000..4fa2bf6ac --- /dev/null +++ b/server/aggregator/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class AggregatorConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'aggregator' diff --git a/server/aggregator/migrations/0001_initial.py b/server/aggregator/migrations/0001_initial.py new file mode 100644 index 000000000..f6d026115 --- /dev/null +++ b/server/aggregator/migrations/0001_initial.py @@ -0,0 +1,56 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="Aggregator", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20, unique=True)), + ("config", models.JSONField()), + ("is_valid", models.BooleanField(default=True)), + ("metadata", models.JSONField(blank=True, default=dict, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "aggregation_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="aggregators", + to="mlcube.mlcube", + ), + ), + ( + "owner", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/aggregator/migrations/__init__.py b/server/aggregator/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/aggregator/models.py b/server/aggregator/models.py new file mode 100644 index 000000000..8f29fcb15 --- /dev/null +++ b/server/aggregator/models.py @@ -0,0 +1,25 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class Aggregator(models.Model): + owner = models.ForeignKey(User, on_delete=models.PROTECT) + name = models.CharField(max_length=20, unique=True) + config = models.JSONField() + aggregation_mlcube = models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="aggregators", + ) + is_valid = models.BooleanField(default=True) + metadata = models.JSONField(default=dict, blank=True, null=True) + created_at = 
models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return str(self.config) + + class Meta: + ordering = ["created_at"] diff --git a/server/aggregator/serializers.py b/server/aggregator/serializers.py new file mode 100644 index 000000000..89eabfbbb --- /dev/null +++ b/server/aggregator/serializers.py @@ -0,0 +1,9 @@ +from rest_framework import serializers +from .models import Aggregator + + +class AggregatorSerializer(serializers.ModelSerializer): + class Meta: + model = Aggregator + fields = "__all__" + read_only_fields = ["owner"] diff --git a/server/aggregator/urls.py b/server/aggregator/urls.py new file mode 100644 index 000000000..8641718c7 --- /dev/null +++ b/server/aggregator/urls.py @@ -0,0 +1,12 @@ +from django.urls import path +from . import views +import aggregator_association.views as tviews + +app_name = "aggregator" + +urlpatterns = [ + path("", views.AggregatorList.as_view()), + path("/", views.AggregatorDetail.as_view()), + path("training/", tviews.ExperimentAggregatorList.as_view()), + path("/training//", tviews.AggregatorApproval.as_view()), +] diff --git a/server/aggregator/views.py b/server/aggregator/views.py new file mode 100644 index 000000000..940d87252 --- /dev/null +++ b/server/aggregator/views.py @@ -0,0 +1,52 @@ +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .models import Aggregator +from .serializers import AggregatorSerializer +from drf_spectacular.utils import extend_schema + + +class AggregatorList(GenericAPIView): + serializer_class = AggregatorSerializer + queryset = "" + + @extend_schema(operation_id="aggregators_retrieve_all") + def get(self, request, format=None): + """ + List all aggregators + """ + aggregators = Aggregator.objects.all() + aggregators = self.paginate_queryset(aggregators) + serializer = AggregatorSerializer(aggregators, many=True) + return self.get_paginated_response(serializer.data) + + def post(self, request, format=None): + """ + Create a new Aggregator + """ + serializer = AggregatorSerializer(data=request.data) + if serializer.is_valid(): + serializer.save(owner=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class AggregatorDetail(GenericAPIView): + serializer_class = AggregatorSerializer + queryset = "" + + def get_object(self, pk): + try: + return Aggregator.objects.get(pk=pk) + except Aggregator.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve an aggregator instance. 
+ """ + aggregator = self.get_object(pk) + serializer = AggregatorSerializer(aggregator) + return Response(serializer.data) diff --git a/server/aggregator_association/__init__.py b/server/aggregator_association/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/aggregator_association/admin.py b/server/aggregator_association/admin.py new file mode 100644 index 000000000..33bbd2b50 --- /dev/null +++ b/server/aggregator_association/admin.py @@ -0,0 +1,7 @@ +from django.contrib import admin +from .models import ExperimentAggregator + + +@admin.register(ExperimentAggregator) +class ExperimentAggregatorAdmin(admin.ModelAdmin): + list_display = [field.name for field in ExperimentAggregator._meta.fields] diff --git a/server/aggregator_association/apps.py b/server/aggregator_association/apps.py new file mode 100644 index 000000000..4df2c898d --- /dev/null +++ b/server/aggregator_association/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class AggregatorAssociationConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'aggregator_association' diff --git a/server/aggregator_association/migrations/0001_initial.py b/server/aggregator_association/migrations/0001_initial.py new file mode 100644 index 000000000..56b2e466a --- /dev/null +++ b/server/aggregator_association/migrations/0001_initial.py @@ -0,0 +1,65 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("aggregator", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ExperimentAggregator", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("metadata", models.JSONField(default=dict)), + ( + "approval_status", + models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "aggregator", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to="aggregator.aggregator", + ), + ), + ( + "initiated_by", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/aggregator_association/migrations/0002_initial.py b/server/aggregator_association/migrations/0002_initial.py new file mode 100644 index 000000000..ef70e0e1a --- /dev/null +++ b/server/aggregator_association/migrations/0002_initial.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("training", "0001_initial"), + ("aggregator_association", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="experimentaggregator", + name="training_exp", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="training.trainingexperiment", + ), + ), + ] diff --git a/server/aggregator_association/migrations/__init__.py 
b/server/aggregator_association/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/aggregator_association/models.py b/server/aggregator_association/models.py new file mode 100644 index 000000000..3642679c1 --- /dev/null +++ b/server/aggregator_association/models.py @@ -0,0 +1,27 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class ExperimentAggregator(models.Model): + MODEL_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ) + aggregator = models.ForeignKey("aggregator.Aggregator", on_delete=models.PROTECT) + training_exp = models.ForeignKey( + "training.TrainingExperiment", on_delete=models.CASCADE + ) + initiated_by = models.ForeignKey(User, on_delete=models.PROTECT) + metadata = models.JSONField(default=dict) + approval_status = models.CharField( + choices=MODEL_STATUS, max_length=100, default="PENDING" + ) + approved_at = models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["created_at"] diff --git a/server/aggregator_association/permissions.py b/server/aggregator_association/permissions.py new file mode 100644 index 000000000..a4723b0fe --- /dev/null +++ b/server/aggregator_association/permissions.py @@ -0,0 +1,54 @@ +from rest_framework.permissions import BasePermission +from training.models import TrainingExperiment +from aggregator.models import Aggregator + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsAggregatorOwner(BasePermission): + def get_object(self, pk): + try: + return Aggregator.objects.get(pk=pk) + except Aggregator.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("aggregator", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + aggregator = self.get_object(pk) + if not aggregator: + return False + if aggregator.owner.id == request.user.id: + return True + else: + return False + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("training_exp", None) + else: + pk = view.kwargs.get("tid", None) + if not pk: + return False + training_exp = self.get_object(pk) + if not training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False diff --git a/server/aggregator_association/serializers.py b/server/aggregator_association/serializers.py new file mode 100644 index 000000000..712d7e3b7 --- /dev/null +++ b/server/aggregator_association/serializers.py @@ -0,0 +1,120 @@ +from rest_framework import serializers +from django.utils import timezone +from training.models import TrainingExperiment + +from .models import ExperimentAggregator +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) + + +class ExperimentAggregatorListSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentAggregator + read_only_fields = ["initiated_by", "approved_at"] + fields = "__all__" + + def validate(self, data): + tid = self.context["request"].data.get("training_exp") + aggregator = 
self.context["request"].data.get("aggregator") + approval_status = self.context["request"].data.get("approval_status", "PENDING") + + training_exp = TrainingExperiment.objects.get(pk=tid) + + # training_exp approval status + training_exp_approval_status = training_exp.approval_status + if training_exp_approval_status != "APPROVED": + raise serializers.ValidationError( + "Association requests can be made only on an approved training experiment" + ) + + # training_exp event status + event = training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "The training experiment does not currently accept associations" + ) + + # An already approved aggregator + exp_aggregator = training_exp.aggregator + if exp_aggregator and exp_aggregator.id != aggregator: + raise serializers.ValidationError( + "The training experiment already has an aggregator" + ) + + # approval status + last_experiment_aggregator = ( + ExperimentAggregator.objects.filter( + training_exp__id=tid, aggregator__id=aggregator + ) + .order_by("-created_at") + .first() + ) + validate_approval_status_on_creation( + last_experiment_aggregator, approval_status + ) + + return data + + def create(self, validated_data): + approval_status = validated_data.get("approval_status", "PENDING") + if approval_status != "PENDING": + validated_data["approved_at"] = timezone.now() + else: + same_owner = ( + validated_data["aggregator"].owner.id + == validated_data["training_exp"].owner.id + ) + if same_owner: + validated_data["approval_status"] = "APPROVED" + validated_data["approved_at"] = timezone.now() + return ExperimentAggregator.objects.create(**validated_data) + + +class AggregatorApprovalSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentAggregator + read_only_fields = ["initiated_by", "approved_at"] + fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + ] + + def validate(self, data): + if not self.instance: + raise serializers.ValidationError("No aggregator association found") + # check if there is already an approved aggregator + exp_aggregator = self.instance.training_exp.aggregator + if exp_aggregator and exp_aggregator.id != self.instance.aggregator.id: + raise serializers.ValidationError( + "The training experiment already has an aggregator" + ) + return data + + def validate_approval_status(self, cur_approval_status): + last_approval_status = self.instance.approval_status + initiated_user = self.instance.initiated_by + current_user = self.context["request"].user + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) + + event = self.instance.training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "User cannot approve or reject an association when the experiment is ongoing" + ) + return cur_approval_status + + def update(self, instance, validated_data): + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + instance.save() + return instance diff --git a/server/aggregator_association/views.py b/server/aggregator_association/views.py new file mode 100644 index 000000000..0ab6dca75 --- /dev/null +++ b/server/aggregator_association/views.py @@ -0,0 +1,81 @@ +from .models import ExperimentAggregator +from django.http import Http404 +from 
rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .permissions import IsAdmin, IsAggregatorOwner, IsExpOwner +from .serializers import ( + ExperimentAggregatorListSerializer, + AggregatorApprovalSerializer, +) + + +class ExperimentAggregatorList(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] + serializer_class = ExperimentAggregatorListSerializer + queryset = "" + + def post(self, request, format=None): + """ + Associate a aggregator to a training_exp + """ + serializer = ExperimentAggregatorListSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save(initiated_by=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class AggregatorApproval(GenericAPIView): + serializer_class = AggregatorApprovalSerializer + queryset = "" + + def get_permissions(self): + self.permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] + if self.request.method == "DELETE": + self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + + def get_object(self, aggregator_id, training_exp_id): + try: + return ExperimentAggregator.objects.filter( + aggregator__id=aggregator_id, training_exp__id=training_exp_id + ) + except ExperimentAggregator.DoesNotExist: + raise Http404 + + def get(self, request, pk, tid, format=None): + """ + Retrieve approval status of training_exp aggregator associations + """ + training_expaggregator = ( + self.get_object(pk, tid).order_by("-created_at").first() + ) + serializer = AggregatorApprovalSerializer(training_expaggregator) + return Response(serializer.data) + + def put(self, request, pk, tid, format=None): + """ + Update approval status of the last training_exp aggregator association + """ + training_expaggregator = ( + self.get_object(pk, tid).order_by("-created_at").first() + ) + serializer = AggregatorApprovalSerializer( + training_expaggregator, data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk, tid, format=None): + """ + Delete a training_exp aggregator association + """ + training_expaggregator = self.get_object(pk, tid) + training_expaggregator.delete() + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/server/benchmark/serializers.py b/server/benchmark/serializers.py index 11f007d37..344635ade 100644 --- a/server/benchmark/serializers.py +++ b/server/benchmark/serializers.py @@ -57,11 +57,6 @@ def validate_approval_status(self, approval_status): raise serializers.ValidationError( "User can only approve or reject a benchmark" ) - if self.instance.state == "DEVELOPMENT": - raise serializers.ValidationError( - "User cannot approve or reject when benchmark is in development stage" - ) - if approval_status == "APPROVED": if self.instance.approval_status == "REJECTED": raise serializers.ValidationError( diff --git a/server/benchmarkdataset/serializers.py b/server/benchmarkdataset/serializers.py index 9cc120079..cbf4f6d00 100644 --- a/server/benchmarkdataset/serializers.py +++ b/server/benchmarkdataset/serializers.py @@ -4,6 +4,10 @@ from dataset.models import Dataset from .models import BenchmarkDataset +from utils.associations import ( + 
validate_approval_status_on_creation, + validate_approval_status_on_update, +) class BenchmarkDatasetListSerializer(serializers.ModelSerializer): @@ -12,61 +16,42 @@ class Meta: read_only_fields = ["initiated_by", "approved_at"] fields = "__all__" - def __validate_approval_status(self, last_benchmarkdataset, approval_status): - if not last_benchmarkdataset: - if approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject association request only if there are prior requests" - ) - else: - if approval_status == "PENDING": - if last_benchmarkdataset.approval_status != "REJECTED": - raise serializers.ValidationError( - "User can create a new request only if prior request is rejected" - ) - elif approval_status == "APPROVED": - raise serializers.ValidationError( - "User cannot create an approved association request" - ) - # approval_status == "REJECTED": - else: - if last_benchmarkdataset.approval_status != "APPROVED": - raise serializers.ValidationError( - "User can reject request only if prior request is approved" - ) - def validate(self, data): bid = self.context["request"].data.get("benchmark") dataset = self.context["request"].data.get("dataset") approval_status = self.context["request"].data.get("approval_status", "PENDING") + benchmark = Benchmark.objects.get(pk=bid) - benchmark_state = benchmark.state - if benchmark_state != "OPERATION": - raise serializers.ValidationError( - "Association requests can be made only on an operational benchmark" - ) + + # benchmark approval status benchmark_approval_status = benchmark.approval_status if benchmark_approval_status != "APPROVED": raise serializers.ValidationError( "Association requests can be made only on an approved benchmark" ) + + # dataset state dataset_obj = Dataset.objects.get(pk=dataset) dataset_state = dataset_obj.state if dataset_state != "OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational dataset" ) + + # dataset prep mlcube if dataset_obj.data_preparation_mlcube != benchmark.data_preparation_mlcube: raise serializers.ValidationError( "Dataset association request can be made only if the dataset" " was prepared with benchmark's data preparation MLCube" ) + + # approval status last_benchmarkdataset = ( BenchmarkDataset.objects.filter(benchmark__id=bid, dataset__id=dataset) .order_by("-created_at") .first() ) - self.__validate_approval_status(last_benchmarkdataset, approval_status) + validate_approval_status_on_creation(last_benchmarkdataset, approval_status) return data @@ -75,10 +60,11 @@ def create(self, validated_data): if approval_status != "PENDING": validated_data["approved_at"] = timezone.now() else: - if ( + same_owner = ( validated_data["dataset"].owner.id == validated_data["benchmark"].owner.id - ): + ) + if same_owner: validated_data["approval_status"] = "APPROVED" validated_data["approved_at"] = timezone.now() return BenchmarkDataset.objects.create(**validated_data) @@ -103,17 +89,11 @@ def validate(self, data): def validate_approval_status(self, cur_approval_status): last_approval_status = self.instance.approval_status - if last_approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject only a pending request" - ) initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if cur_approval_status == "APPROVED": - if current_user.id == initiated_user.id: - raise serializers.ValidationError( - "Same user cannot approve the association request" - ) + 
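# The shared helpers imported from utils.associations above are not shown in this
# hunk. Judging from the inline validation they replace here and in
# benchmarkmodel/serializers.py, server/utils/associations.py presumably looks
# roughly like the sketch below (the function names come from the imports; the
# exact bodies are an assumption):
from rest_framework import serializers


def validate_approval_status_on_creation(last_association, approval_status):
    # The first request for a pair must start out PENDING.
    if not last_association:
        if approval_status != "PENDING":
            raise serializers.ValidationError(
                "User can approve or reject association request only if there are prior requests"
            )
    elif approval_status == "PENDING":
        # A new request is only allowed after the previous one was rejected.
        if last_association.approval_status != "REJECTED":
            raise serializers.ValidationError(
                "User can create a new request only if prior request is rejected"
            )
    elif approval_status == "APPROVED":
        raise serializers.ValidationError(
            "User cannot create an approved association request"
        )
    else:  # approval_status == "REJECTED"
        if last_association.approval_status != "APPROVED":
            raise serializers.ValidationError(
                "User can reject request only if prior request is approved"
            )


def validate_approval_status_on_update(
    last_approval_status, cur_approval_status, initiated_user, current_user
):
    # Only pending requests can be approved or rejected, and the user who
    # initiated the request cannot be the one who approves it.
    if last_approval_status != "PENDING":
        raise serializers.ValidationError(
            "User can approve or reject only a pending request"
        )
    if cur_approval_status == "APPROVED" and current_user.id == initiated_user.id:
        raise serializers.ValidationError(
            "Same user cannot approve the association request"
        )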
validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) return cur_approval_status def update(self, instance, validated_data): diff --git a/server/benchmarkmodel/serializers.py b/server/benchmarkmodel/serializers.py index afa34acd4..9e13f612b 100644 --- a/server/benchmarkmodel/serializers.py +++ b/server/benchmarkmodel/serializers.py @@ -4,6 +4,10 @@ from mlcube.models import MlCube from .models import BenchmarkModel +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) class BenchmarkModelListSerializer(serializers.ModelSerializer): @@ -16,49 +20,32 @@ def validate(self, data): bid = self.context["request"].data.get("benchmark") mlcube = self.context["request"].data.get("model_mlcube") approval_status = self.context["request"].data.get("approval_status", "PENDING") + benchmark = Benchmark.objects.get(pk=bid) - benchmark_state = benchmark.state - if benchmark_state != "OPERATION": - raise serializers.ValidationError( - "Association requests can be made only on an operational benchmark" - ) + + # benchmark approval status benchmark_approval_status = benchmark.approval_status if benchmark_approval_status != "APPROVED": raise serializers.ValidationError( "Association requests can be made only on an approved benchmark" ) - mlcube_state = MlCube.objects.get(pk=mlcube).state + + # mlcube state + mlcube_obj = MlCube.objects.get(pk=mlcube) + mlcube_state = mlcube_obj.state if mlcube_state != "OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational model mlcube" ) + + # approval status last_benchmarkmodel = ( BenchmarkModel.objects.filter(benchmark__id=bid, model_mlcube__id=mlcube) .order_by("-created_at") .first() ) - if not last_benchmarkmodel: - if approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject association request only if there are prior requests" - ) - else: - if approval_status == "PENDING": - if last_benchmarkmodel.approval_status != "REJECTED": - raise serializers.ValidationError( - "User can create a new request only if prior request is rejected" - ) - # check valid results passed - elif approval_status == "APPROVED": - raise serializers.ValidationError( - "User cannot create an approved association request" - ) - # approval_status == "REJECTED": - else: - if last_benchmarkmodel.approval_status != "APPROVED": - raise serializers.ValidationError( - "User can reject request only if prior request is approved" - ) + validate_approval_status_on_creation(last_benchmarkmodel, approval_status) + return data def create(self, validated_data): @@ -66,10 +53,11 @@ def create(self, validated_data): if approval_status != "PENDING": validated_data["approved_at"] = timezone.now() else: - if ( + same_owner = ( validated_data["model_mlcube"].owner.id == validated_data["benchmark"].owner.id - ): + ) + if same_owner: validated_data["approval_status"] = "APPROVED" validated_data["approved_at"] = timezone.now() return BenchmarkModel.objects.create(**validated_data) @@ -95,17 +83,11 @@ def validate(self, data): def validate_approval_status(self, cur_approval_status): last_approval_status = self.instance.approval_status - if last_approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject only a pending request" - ) initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if cur_approval_status == "APPROVED": - if current_user.id == 
initiated_user.id: - raise serializers.ValidationError( - "Same user cannot approve the association request" - ) + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) return cur_approval_status def update(self, instance, validated_data): diff --git a/server/ca/__init__.py b/server/ca/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/ca/admin.py b/server/ca/admin.py new file mode 100644 index 000000000..04525f7d2 --- /dev/null +++ b/server/ca/admin.py @@ -0,0 +1,7 @@ +from django.contrib import admin +from .models import CA + + +@admin.register(CA) +class CAAdmin(admin.ModelAdmin): + list_display = [field.name for field in CA._meta.fields] diff --git a/server/ca/apps.py b/server/ca/apps.py new file mode 100644 index 000000000..5bfe9cfe8 --- /dev/null +++ b/server/ca/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class CaConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "ca" diff --git a/server/ca/migrations/0001_initial.py b/server/ca/migrations/0001_initial.py new file mode 100644 index 000000000..875545c00 --- /dev/null +++ b/server/ca/migrations/0001_initial.py @@ -0,0 +1,72 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="CA", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20, unique=True)), + ("config", models.JSONField()), + ("is_valid", models.BooleanField(default=True)), + ("metadata", models.JSONField(blank=True, default=dict, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "ca_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="ca", + to="mlcube.mlcube", + ), + ), + ( + "client_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="ca_client", + to="mlcube.mlcube", + ), + ), + ( + "owner", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "server_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="ca_server", + to="mlcube.mlcube", + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/ca/migrations/0002_createmedperfca.py b/server/ca/migrations/0002_createmedperfca.py new file mode 100644 index 000000000..f93eb513b --- /dev/null +++ b/server/ca/migrations/0002_createmedperfca.py @@ -0,0 +1,43 @@ +from django.contrib.auth import get_user_model +from django.db import migrations +from django.db.backends.postgresql.schema import DatabaseSchemaEditor +from django.db.migrations.state import StateApps +from django.conf import settings +from ca.models import CA +from mlcube.models import MlCube + +User = get_user_model() + + +def createmedperfca(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None: + """ + Dynamically create the configured main CA as part of a migration + """ + admin_user = User.objects.get(username=settings.SUPERUSER_USERNAME) + 
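    # The CA bootstrap below is driven entirely by environment-backed settings
    # (CA_NAME, CA_CONFIG, CA_MLCUBE_NAME, CA_MLCUBE_URL, CA_MLCUBE_HASH and
    # CA_MLCUBE_IMAGE_HASH; see the medperf/settings.py hunk further down).
    # A hypothetical .env entry might look like this (values are illustrative
    # only, not taken from the repository):
    #   CA_NAME=medperf_main_ca
    #   CA_CONFIG={"address": "https://ca.example.org", "port": 443}
    #   CA_MLCUBE_URL=https://example.org/mlcubes/ca/mlcube.yaml
    # Note that a single MLCube record is reused as the ca/client/server cube of this CA.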
ca_mlcube = MlCube.objects.create( + name=settings.CA_MLCUBE_NAME, + git_mlcube_url=settings.CA_MLCUBE_URL, + mlcube_hash=settings.CA_MLCUBE_HASH, + image_hash=settings.CA_MLCUBE_IMAGE_HASH, + owner=admin_user, + state="OPERATION", + ) + CA.objects.create( + name=settings.CA_NAME, + config=settings.CA_CONFIG, + ca_mlcube=ca_mlcube, + client_mlcube=ca_mlcube, + server_mlcube=ca_mlcube, + owner=admin_user, + ) + + +class Migration(migrations.Migration): + + initial = True + dependencies = [ + ("ca", "0001_initial"), + ("user", "0001_createsuperuser"), + ("mlcube", "0002_alter_mlcube_unique_together"), + ] + operations = [migrations.RunPython(createmedperfca)] diff --git a/server/ca/migrations/__init__.py b/server/ca/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/ca/models.py b/server/ca/models.py new file mode 100644 index 000000000..5b3d62d30 --- /dev/null +++ b/server/ca/models.py @@ -0,0 +1,29 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class CA(models.Model): + owner = models.ForeignKey(User, on_delete=models.PROTECT) + name = models.CharField(max_length=20, unique=True) + config = models.JSONField() + client_mlcube = models.ForeignKey( + "mlcube.MlCube", on_delete=models.PROTECT, related_name="ca_client" + ) + server_mlcube = models.ForeignKey( + "mlcube.MlCube", on_delete=models.PROTECT, related_name="ca_server" + ) + ca_mlcube = models.ForeignKey( + "mlcube.MlCube", on_delete=models.PROTECT, related_name="ca" + ) + is_valid = models.BooleanField(default=True) + metadata = models.JSONField(default=dict, blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return str(self.config) + + class Meta: + ordering = ["created_at"] diff --git a/server/ca/serializers.py b/server/ca/serializers.py new file mode 100644 index 000000000..d693058da --- /dev/null +++ b/server/ca/serializers.py @@ -0,0 +1,9 @@ +from rest_framework import serializers +from .models import CA + + +class CASerializer(serializers.ModelSerializer): + class Meta: + model = CA + fields = "__all__" + read_only_fields = ["owner"] diff --git a/server/ca/urls.py b/server/ca/urls.py new file mode 100644 index 000000000..45d7ce343 --- /dev/null +++ b/server/ca/urls.py @@ -0,0 +1,12 @@ +from django.urls import path +from . 
import views +import ca_association.views as tviews + app_name = "ca" + urlpatterns = [ + path("", views.CAList.as_view()), + path("<int:pk>/", views.CADetail.as_view()), + path("training/", tviews.ExperimentCAList.as_view()), + path("<int:pk>/training/<int:tid>/", tviews.CAApproval.as_view()), +] diff --git a/server/ca/views.py b/server/ca/views.py new file mode 100644 index 000000000..8a2bc6036 --- /dev/null +++ b/server/ca/views.py @@ -0,0 +1,52 @@ +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .models import CA +from .serializers import CASerializer +from drf_spectacular.utils import extend_schema + + +class CAList(GenericAPIView): + serializer_class = CASerializer + queryset = "" + + @extend_schema(operation_id="cas_retrieve_all") + def get(self, request, format=None): + """ + List all CAs + """ + cas = CA.objects.all() + cas = self.paginate_queryset(cas) + serializer = CASerializer(cas, many=True) + return self.get_paginated_response(serializer.data) + + def post(self, request, format=None): + """ + Create a new CA + """ + serializer = CASerializer(data=request.data) + if serializer.is_valid(): + serializer.save(owner=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class CADetail(GenericAPIView): + serializer_class = CASerializer + queryset = "" + + def get_object(self, pk): + try: + return CA.objects.get(pk=pk) + except CA.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve a CA instance. + """ + ca = self.get_object(pk) + serializer = CASerializer(ca) + return Response(serializer.data) diff --git a/server/ca_association/__init__.py b/server/ca_association/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/ca_association/admin.py b/server/ca_association/admin.py new file mode 100644 index 000000000..3317c359c --- /dev/null +++ b/server/ca_association/admin.py @@ -0,0 +1,7 @@ +from django.contrib import admin +from .models import ExperimentCA + + +@admin.register(ExperimentCA) +class ExperimentCAAdmin(admin.ModelAdmin): + list_display = [field.name for field in ExperimentCA._meta.fields] diff --git a/server/ca_association/apps.py b/server/ca_association/apps.py new file mode 100644 index 000000000..9fba9e4c1 --- /dev/null +++ b/server/ca_association/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class CAAssociationConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "ca_association" diff --git a/server/ca_association/migrations/0001_initial.py b/server/ca_association/migrations/0001_initial.py new file mode 100644 index 000000000..5f9d17362 --- /dev/null +++ b/server/ca_association/migrations/0001_initial.py @@ -0,0 +1,64 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("ca", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ExperimentCA", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("metadata", models.JSONField(default=dict)), + ( + "approval_status", + models.CharField( +
choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "ca", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, to="ca.ca" + ), + ), + ( + "initiated_by", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/ca_association/migrations/0002_initial.py b/server/ca_association/migrations/0002_initial.py new file mode 100644 index 000000000..848d72c90 --- /dev/null +++ b/server/ca_association/migrations/0002_initial.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("training", "0001_initial"), + ("ca_association", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="experimentca", + name="training_exp", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="training.trainingexperiment", + ), + ), + ] diff --git a/server/ca_association/migrations/__init__.py b/server/ca_association/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/ca_association/models.py b/server/ca_association/models.py new file mode 100644 index 000000000..f67e4a21e --- /dev/null +++ b/server/ca_association/models.py @@ -0,0 +1,27 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class ExperimentCA(models.Model): + MODEL_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ) + ca = models.ForeignKey("ca.CA", on_delete=models.PROTECT) + training_exp = models.ForeignKey( + "training.TrainingExperiment", on_delete=models.CASCADE + ) + initiated_by = models.ForeignKey(User, on_delete=models.PROTECT) + metadata = models.JSONField(default=dict) + approval_status = models.CharField( + choices=MODEL_STATUS, max_length=100, default="PENDING" + ) + approved_at = models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["created_at"] diff --git a/server/ca_association/permissions.py b/server/ca_association/permissions.py new file mode 100644 index 000000000..640422df1 --- /dev/null +++ b/server/ca_association/permissions.py @@ -0,0 +1,54 @@ +from rest_framework.permissions import BasePermission +from training.models import TrainingExperiment +from ca.models import CA + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsCAOwner(BasePermission): + def get_object(self, pk): + try: + return CA.objects.get(pk=pk) + except CA.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("ca", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + ca = self.get_object(pk) + if not ca: + return False + if ca.owner.id == request.user.id: + return True + else: + return False + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return 
TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("training_exp", None) + else: + pk = view.kwargs.get("tid", None) + if not pk: + return False + training_exp = self.get_object(pk) + if not training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False diff --git a/server/ca_association/serializers.py b/server/ca_association/serializers.py new file mode 100644 index 000000000..9e417515a --- /dev/null +++ b/server/ca_association/serializers.py @@ -0,0 +1,117 @@ +from rest_framework import serializers +from django.utils import timezone +from training.models import TrainingExperiment +from django.conf import settings + +from .models import ExperimentCA +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) + + +class ExperimentCAListSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentCA + read_only_fields = ["initiated_by", "approved_at"] + fields = "__all__" + + def validate(self, data): + tid = self.context["request"].data.get("training_exp") + ca = self.context["request"].data.get("ca") + approval_status = self.context["request"].data.get("approval_status", "PENDING") + + training_exp = TrainingExperiment.objects.get(pk=tid) + + # training_exp approval status + training_exp_approval_status = training_exp.approval_status + if training_exp_approval_status != "APPROVED": + raise serializers.ValidationError( + "Association requests can be made only on an approved training experiment" + ) + + # training_exp event status + event = training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "The training experiment does not currently accept associations" + ) + + # An already approved ca + exp_ca = training_exp.ca + if exp_ca and exp_ca.id != ca: + raise serializers.ValidationError( + "The training experiment already has an ca" + ) + + # approval status + last_experiment_ca = ( + ExperimentCA.objects.filter(training_exp__id=tid, ca__id=ca) + .order_by("-created_at") + .first() + ) + validate_approval_status_on_creation(last_experiment_ca, approval_status) + + return data + + def create(self, validated_data): + approval_status = validated_data.get("approval_status", "PENDING") + if approval_status != "PENDING": + validated_data["approved_at"] = timezone.now() + else: + same_owner = ( + validated_data["ca"].owner.id == validated_data["training_exp"].owner.id + ) + is_main_ca = validated_data["ca"].name == settings.CA_NAME + if same_owner or is_main_ca: + validated_data["approval_status"] = "APPROVED" + validated_data["approved_at"] = timezone.now() + return ExperimentCA.objects.create(**validated_data) + + +class CAApprovalSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentCA + read_only_fields = ["initiated_by", "approved_at"] + fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + ] + + def validate(self, data): + if not self.instance: + raise serializers.ValidationError("No ca association found") + # check if there is already an approved ca + exp_ca = self.instance.training_exp.ca + if exp_ca and exp_ca.id != self.instance.ca.id: + raise serializers.ValidationError( + "The training experiment already has an ca" + ) + return data + + def validate_approval_status(self, cur_approval_status): + last_approval_status = 
self.instance.approval_status + initiated_user = self.instance.initiated_by + current_user = self.context["request"].user + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) + + event = self.instance.training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "User cannot approve or reject an association when the experiment is ongoing" + ) + return cur_approval_status + + def update(self, instance, validated_data): + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + instance.save() + return instance diff --git a/server/ca_association/views.py b/server/ca_association/views.py new file mode 100644 index 000000000..4cf05bee0 --- /dev/null +++ b/server/ca_association/views.py @@ -0,0 +1,77 @@ +from .models import ExperimentCA +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .permissions import IsAdmin, IsCAOwner, IsExpOwner +from .serializers import ( + ExperimentCAListSerializer, + CAApprovalSerializer, +) + + +class ExperimentCAList(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner | IsCAOwner] + serializer_class = ExperimentCAListSerializer + queryset = "" + + def post(self, request, format=None): + """ + Associate a ca to a training_exp + """ + serializer = ExperimentCAListSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save(initiated_by=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class CAApproval(GenericAPIView): + serializer_class = CAApprovalSerializer + queryset = "" + + def get_permissions(self): + self.permission_classes = [IsAdmin | IsExpOwner | IsCAOwner] + if self.request.method == "DELETE": + self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + + def get_object(self, ca_id, training_exp_id): + try: + return ExperimentCA.objects.filter( + ca__id=ca_id, training_exp__id=training_exp_id + ) + except ExperimentCA.DoesNotExist: + raise Http404 + + def get(self, request, pk, tid, format=None): + """ + Retrieve approval status of training_exp ca associations + """ + training_expca = self.get_object(pk, tid).order_by("-created_at").first() + serializer = CAApprovalSerializer(training_expca) + return Response(serializer.data) + + def put(self, request, pk, tid, format=None): + """ + Update approval status of the last training_exp ca association + """ + training_expca = self.get_object(pk, tid).order_by("-created_at").first() + serializer = CAApprovalSerializer( + training_expca, data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk, tid, format=None): + """ + Delete a training_exp ca association + """ + training_expca = self.get_object(pk, tid) + training_expca.delete() + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/server/dataset/serializers.py b/server/dataset/serializers.py index aaee5aaab..49f6173dd 100644 --- 
a/server/dataset/serializers.py +++ b/server/dataset/serializers.py @@ -1,5 +1,6 @@ from rest_framework import serializers from .models import Dataset +from user.serializers import UserSerializer class DatasetFullSerializer(serializers.ModelSerializer): @@ -60,3 +61,14 @@ def validate(self, data): "User cannot update non editable fields in Operation mode" ) return data + + +class DatasetWithOwnerInfoSerializer(serializers.ModelSerializer): + """This is needed for training to get datasets and their owners + with one API call.""" + + owner = UserSerializer() + + class Meta: + model = Dataset + fields = ["id", "owner"] diff --git a/server/dataset/urls.py b/server/dataset/urls.py index 5aa23fd5a..7b020662e 100644 --- a/server/dataset/urls.py +++ b/server/dataset/urls.py @@ -1,6 +1,7 @@ from django.urls import path from . import views from benchmarkdataset import views as bviews +from traindataset_association import views as tviews app_name = "Dataset" @@ -10,5 +11,8 @@ path("benchmarks/", bviews.BenchmarkDatasetList.as_view()), path("<int:pk>/benchmarks/<int:bid>/", bviews.DatasetApproval.as_view()), # path("<int:pk>/benchmarks/", bviews.DatasetBenchmarksList.as_view()), - # NOTE: when activating this endpoint later, check permissions and write tests + # path("<int:pk>/training/", tviews.DatasetExperimentList.as_view()), + # NOTE: when activating those two endpoints later, check permissions and write tests + path("training/", tviews.ExperimentDatasetList.as_view()), + path("<int:pk>/training/<int:tid>/", tviews.DatasetApproval.as_view()), ] diff --git a/server/medperf/settings.py b/server/medperf/settings.py index a9c83f6fe..4b4710735 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -60,6 +60,13 @@ SUPERUSER_PASSWORD = env("SUPERUSER_PASSWORD") +CA_NAME = env("CA_NAME") +CA_CONFIG = env.json("CA_CONFIG") +CA_MLCUBE_NAME = env("CA_MLCUBE_NAME") +CA_MLCUBE_URL = env("CA_MLCUBE_URL") +CA_MLCUBE_HASH = env("CA_MLCUBE_HASH") +CA_MLCUBE_IMAGE_HASH = env("CA_MLCUBE_IMAGE_HASH") + ALLOWED_HOSTS = env.list("ALLOWED_HOSTS", default=[]) # TODO Change later to list of allowed domains @@ -91,6 +98,13 @@ "benchmarkmodel", "user", "result", + "training", + "aggregator", + "ca", + "traindataset_association", + "aggregator_association", + "ca_association", + "trainingevent", "rest_framework", "rest_framework.authtoken", "drf_spectacular", diff --git a/server/medperf/urls.py b/server/medperf/urls.py index be4e07dce..fb68b5d8f 100644 --- a/server/medperf/urls.py +++ b/server/medperf/urls.py @@ -36,5 +36,8 @@ path("results/", include("result.urls", namespace=API_VERSION), name="result"), path("users/", include("user.urls", namespace=API_VERSION), name="users"), path("me/", include("utils.urls", namespace=API_VERSION), name="me"), + path("training/", include("training.urls", namespace=API_VERSION), name="training"), + path("aggregators/", include("aggregator.urls", namespace=API_VERSION), name="aggregator"), + path("cas/", include("ca.urls", namespace=API_VERSION), name="ca") ])), ] diff --git a/server/traindataset_association/__init__.py b/server/traindataset_association/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/traindataset_association/admin.py b/server/traindataset_association/admin.py new file mode 100644 index 000000000..13119fb54 --- /dev/null +++ b/server/traindataset_association/admin.py @@ -0,0 +1,7 @@ +from django.contrib import admin +from .models import ExperimentDataset + + +@admin.register(ExperimentDataset) +class ExperimentDatasetAdmin(admin.ModelAdmin): + list_display = [field.name 
for field in ExperimentDataset._meta.fields] diff --git a/server/traindataset_association/apps.py b/server/traindataset_association/apps.py new file mode 100644 index 000000000..680686727 --- /dev/null +++ b/server/traindataset_association/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class TraindatasetAssociationConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'traindataset_association' diff --git a/server/traindataset_association/migrations/0001_initial.py b/server/traindataset_association/migrations/0001_initial.py new file mode 100644 index 000000000..7938a2561 --- /dev/null +++ b/server/traindataset_association/migrations/0001_initial.py @@ -0,0 +1,65 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("dataset", "0004_auto_20231211_1827"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ExperimentDataset", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("metadata", models.JSONField(default=dict)), + ( + "approval_status", + models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "dataset", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to="dataset.dataset", + ), + ), + ( + "initiated_by", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["modified_at"], + }, + ), + ] diff --git a/server/traindataset_association/migrations/0002_initial.py b/server/traindataset_association/migrations/0002_initial.py new file mode 100644 index 000000000..6ab74922e --- /dev/null +++ b/server/traindataset_association/migrations/0002_initial.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("training", "0001_initial"), + ("traindataset_association", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="experimentdataset", + name="training_exp", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="training.trainingexperiment", + ), + ), + ] diff --git a/server/traindataset_association/migrations/__init__.py b/server/traindataset_association/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/traindataset_association/models.py b/server/traindataset_association/models.py new file mode 100644 index 000000000..dc71107ca --- /dev/null +++ b/server/traindataset_association/models.py @@ -0,0 +1,27 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class ExperimentDataset(models.Model): + MODEL_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ) + dataset = models.ForeignKey("dataset.Dataset", on_delete=models.PROTECT) + training_exp 
= models.ForeignKey( + "training.TrainingExperiment", on_delete=models.CASCADE + ) + initiated_by = models.ForeignKey(User, on_delete=models.PROTECT) + metadata = models.JSONField(default=dict) + approval_status = models.CharField( + choices=MODEL_STATUS, max_length=100, default="PENDING" + ) + approved_at = models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["modified_at"] diff --git a/server/traindataset_association/permissions.py b/server/traindataset_association/permissions.py new file mode 100644 index 000000000..898122730 --- /dev/null +++ b/server/traindataset_association/permissions.py @@ -0,0 +1,54 @@ +from rest_framework.permissions import BasePermission +from training.models import TrainingExperiment +from dataset.models import Dataset + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsDatasetOwner(BasePermission): + def get_object(self, pk): + try: + return Dataset.objects.get(pk=pk) + except Dataset.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("dataset", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + dataset = self.get_object(pk) + if not dataset: + return False + if dataset.owner.id == request.user.id: + return True + else: + return False + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("training_exp", None) + else: + pk = view.kwargs.get("tid", None) + if not pk: + return False + training_experiment = self.get_object(pk) + if not training_experiment: + return False + if training_experiment.owner.id == request.user.id: + return True + else: + return False diff --git a/server/traindataset_association/serializers.py b/server/traindataset_association/serializers.py new file mode 100644 index 000000000..950b073bc --- /dev/null +++ b/server/traindataset_association/serializers.py @@ -0,0 +1,132 @@ +from rest_framework import serializers +from django.utils import timezone +from training.models import TrainingExperiment +from dataset.models import Dataset + +from .models import ExperimentDataset +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) + + +def is_approved_participant(training_exp, dataset): + # training_exp event status + event = training_exp.event + if not event or event.finished: + return + + # TODO: modify when we use dataset labels + # TODO: is there a cleaner way? 
We are making assumptions on the json field structure + participants_list = event.participants.values() + return dataset.owner.email in participants_list + + +class ExperimentDatasetListSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentDataset + read_only_fields = ["initiated_by", "approved_at"] + fields = "__all__" + + def validate(self, data): + tid = self.context["request"].data.get("training_exp") + dataset = self.context["request"].data.get("dataset") + approval_status = self.context["request"].data.get("approval_status", "PENDING") + + training_exp = TrainingExperiment.objects.get(pk=tid) + + # training_exp approval status + training_exp_approval_status = training_exp.approval_status + if training_exp_approval_status != "APPROVED": + raise serializers.ValidationError( + "Association requests can be made only on an approved training experiment" + ) + + # dataset state + dataset_obj = Dataset.objects.get(pk=dataset) + dataset_state = dataset_obj.state + if dataset_state != "OPERATION": + raise serializers.ValidationError( + "Association requests can be made only on an operational dataset" + ) + + # dataset prep mlcube + if dataset_obj.data_preparation_mlcube != training_exp.data_preparation_mlcube: + raise serializers.ValidationError( + "Dataset association request can be made only if the dataset" + " was prepared with the training experiment's data preparation MLCube" + ) + + # approval status + last_training_expdataset = ( + ExperimentDataset.objects.filter(training_exp__id=tid, dataset__id=dataset) + .order_by("-created_at") + .first() + ) + validate_approval_status_on_creation(last_training_expdataset, approval_status) + + return data + + def create(self, validated_data): + approval_status = validated_data.get("approval_status", "PENDING") + if approval_status != "PENDING": + validated_data["approved_at"] = timezone.now() + else: + same_owner = ( + validated_data["dataset"].owner.id + == validated_data["training_exp"].owner.id + ) + if same_owner or is_approved_participant( + validated_data["training_exp"], validated_data["dataset"] + ): + validated_data["approval_status"] = "APPROVED" + validated_data["approved_at"] = timezone.now() + return ExperimentDataset.objects.create(**validated_data) + + +class DatasetApprovalSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentDataset + read_only_fields = ["initiated_by", "approved_at"] + fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + ] + + def validate(self, data): + if not self.instance: + raise serializers.ValidationError("No dataset association found") + return data + + def validate_approval_status(self, cur_approval_status): + last_approval_status = self.instance.approval_status + initiated_user = self.instance.initiated_by + current_user = self.context["request"].user + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) + + event = self.instance.training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "User cannot approve or reject an association when the experiment is ongoing" + ) + return cur_approval_status + + def update(self, instance, validated_data): + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + instance.save() + return 
instance + + +class TrainingExperimentListofDatasetsSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentDataset + fields = ["dataset", "approval_status", "created_at"] diff --git a/server/traindataset_association/views.py b/server/traindataset_association/views.py new file mode 100644 index 000000000..729e298f9 --- /dev/null +++ b/server/traindataset_association/views.py @@ -0,0 +1,99 @@ +from .models import ExperimentDataset +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status +from drf_spectacular.utils import extend_schema + +from .permissions import IsAdmin, IsDatasetOwner, IsExpOwner +from .serializers import ( + ExperimentDatasetListSerializer, + DatasetApprovalSerializer, +) + + +class ExperimentDatasetList(GenericAPIView): + permission_classes = [IsAdmin | IsDatasetOwner] + serializer_class = ExperimentDatasetListSerializer + queryset = "" + + def post(self, request, format=None): + """ + Associate a dataset to a training_exp + """ + serializer = ExperimentDatasetListSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save(initiated_by=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class DatasetExperimentList(GenericAPIView): + serializer_class = ExperimentDatasetListSerializer + queryset = "" + + def get_object(self, pk): + try: + return ExperimentDataset.objects.filter(dataset__id=pk) + except ExperimentDataset.DoesNotExist: + raise Http404 + + @extend_schema(operation_id="datasets_experiments_retrieve_all") + def get(self, request, pk, format=None): + """ + Retrieve all experiments associated with a dataset + """ + training_expdataset = self.get_object(pk) + training_expdataset = self.paginate_queryset(training_expdataset) + serializer = ExperimentDatasetListSerializer(training_expdataset, many=True) + return self.get_paginated_response(serializer.data) + + +class DatasetApproval(GenericAPIView): + serializer_class = DatasetApprovalSerializer + queryset = "" + + def get_permissions(self): + self.permission_classes = [IsAdmin | IsExpOwner | IsDatasetOwner] + if self.request.method == "DELETE": + self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + + def get_object(self, dataset_id, training_exp_id): + try: + return ExperimentDataset.objects.filter( + dataset__id=dataset_id, training_exp__id=training_exp_id + ) + except ExperimentDataset.DoesNotExist: + raise Http404 + + def get(self, request, pk, tid, format=None): + """ + Retrieve approval status of training_exp dataset associations + """ + training_expdataset = self.get_object(pk, tid) + serializer = DatasetApprovalSerializer(training_expdataset, many=True) + return Response(serializer.data) + + def put(self, request, pk, tid, format=None): + """ + Update approval status of the last training_exp dataset association + """ + training_expdataset = self.get_object(pk, tid).order_by("-created_at").first() + serializer = DatasetApprovalSerializer( + training_expdataset, data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk, tid, format=None): + """ + Delete a training_exp dataset association + """ + 
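        # Note: get_object returns a queryset filtered by (dataset, training_exp),
        # so this deletes every association record for that pair rather than only
        # the latest one; DELETE is restricted to admins via get_permissions above.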
training_expdataset = self.get_object(pk, tid) + training_expdataset.delete() + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/server/training/__init__.py b/server/training/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/training/admin.py b/server/training/admin.py new file mode 100644 index 000000000..27dc93261 --- /dev/null +++ b/server/training/admin.py @@ -0,0 +1,10 @@ +from django.contrib import admin + +from .models import TrainingExperiment + + +class TrainingExperimentAdmin(admin.ModelAdmin): + list_display = [field.name for field in TrainingExperiment._meta.fields] + + +admin.site.register(TrainingExperiment, TrainingExperimentAdmin) diff --git a/server/training/apps.py b/server/training/apps.py new file mode 100644 index 000000000..8051e6caf --- /dev/null +++ b/server/training/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class TrainingConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'training' diff --git a/server/training/migrations/0001_initial.py b/server/training/migrations/0001_initial.py new file mode 100644 index 000000000..4f5b65dd5 --- /dev/null +++ b/server/training/migrations/0001_initial.py @@ -0,0 +1,97 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="TrainingExperiment", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20, unique=True)), + ("description", models.CharField(blank=True, max_length=100)), + ("docs_url", models.CharField(blank=True, max_length=100)), + ("demo_dataset_tarball_url", models.CharField(max_length=256)), + ("demo_dataset_tarball_hash", models.CharField(max_length=100)), + ("demo_dataset_generated_uid", models.CharField(max_length=128)), + ("metadata", models.JSONField(blank=True, default=dict, null=True)), + ( + "state", + models.CharField( + choices=[ + ("DEVELOPMENT", "DEVELOPMENT"), + ("OPERATION", "OPERATION"), + ], + default="DEVELOPMENT", + max_length=100, + ), + ), + ("is_valid", models.BooleanField(default=True)), + ( + "approval_status", + models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("plan", models.JSONField(blank=True, null=True)), + ( + "user_metadata", + models.JSONField(blank=True, default=dict, null=True), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "data_preparation_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="training_exp", + to="mlcube.mlcube", + ), + ), + ( + "fl_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="fl_mlcube", + to="mlcube.mlcube", + ), + ), + ( + "owner", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["modified_at"], + }, + ), + ] diff --git 
a/server/training/migrations/0002_trainingexperiment_fl_admin_mlcube.py b/server/training/migrations/0002_trainingexperiment_fl_admin_mlcube.py new file mode 100644 index 000000000..aebbbe3ce --- /dev/null +++ b/server/training/migrations/0002_trainingexperiment_fl_admin_mlcube.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.11 on 2024-07-28 22:15 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), + ("training", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="trainingexperiment", + name="fl_admin_mlcube", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="fl_admin_mlcube", + to="mlcube.mlcube", + ), + ), + ] diff --git a/server/training/migrations/__init__.py b/server/training/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/training/models.py b/server/training/models.py new file mode 100644 index 000000000..db1e92c72 --- /dev/null +++ b/server/training/models.py @@ -0,0 +1,78 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class TrainingExperiment(models.Model): + EXP_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ) + STATES = ( + ("DEVELOPMENT", "DEVELOPMENT"), + ("OPERATION", "OPERATION"), + ) + + name = models.CharField(max_length=20, unique=True) + description = models.CharField(max_length=100, blank=True) + docs_url = models.CharField(max_length=100, blank=True) + owner = models.ForeignKey(User, on_delete=models.PROTECT) + demo_dataset_tarball_url = models.CharField(max_length=256) + demo_dataset_tarball_hash = models.CharField(max_length=100) + demo_dataset_generated_uid = models.CharField(max_length=128) + data_preparation_mlcube = models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="training_exp", + ) + fl_mlcube = models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="fl_mlcube", + ) + fl_admin_mlcube = models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="fl_admin_mlcube", + blank=True, + null=True, + ) + + metadata = models.JSONField(default=dict, blank=True, null=True) + state = models.CharField(choices=STATES, max_length=100, default="DEVELOPMENT") + is_valid = models.BooleanField(default=True) + approval_status = models.CharField( + choices=EXP_STATUS, max_length=100, default="PENDING" + ) + plan = models.JSONField(blank=True, null=True) + + user_metadata = models.JSONField(default=dict, blank=True, null=True) + approved_at = models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return self.name + + @property + def event(self): + return self.events.all().order_by("created_at").last() + + @property + def aggregator(self): + aggregator_assoc = ( + self.experimentaggregator_set.all().order_by("created_at").last() + ) + if aggregator_assoc and aggregator_assoc.approval_status == "APPROVED": + return aggregator_assoc.aggregator + + @property + def ca(self): + ca_assoc = self.experimentca_set.all().order_by("created_at").last() + if ca_assoc and ca_assoc.approval_status == "APPROVED": + return ca_assoc.ca + + class Meta: + ordering = ["modified_at"] diff --git a/server/training/permissions.py 
b/server/training/permissions.py new file mode 100644 index 000000000..7576d489a --- /dev/null +++ b/server/training/permissions.py @@ -0,0 +1,82 @@ +from rest_framework.permissions import BasePermission +from .models import TrainingExperiment +from traindataset_association.models import ExperimentDataset +from django.db.models import OuterRef, Subquery + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + pk = view.kwargs.get("pk", None) + if not pk: + return False + training_exp = self.get_object(pk) + if not training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False + + +# TODO: check effciency / database costs +class IsAssociatedDatasetOwner(BasePermission): + def has_permission(self, request, view): + pk = view.kwargs.get("pk", None) + if not pk: + return False + + if not request.user.is_authenticated: + # This check is to prevent internal server error + # since user.dataset_set is used below + return False + + latest_datasets_assocs_status = ( + ExperimentDataset.objects.all() + .filter(training_exp__id=pk, dataset__id=OuterRef("id")) + .order_by("-created_at")[:1] + .values("approval_status") + ) + + user_associated_datasets = ( + request.user.dataset_set.all() + .annotate(assoc_status=Subquery(latest_datasets_assocs_status)) + .filter(assoc_status="APPROVED") + ) + + if user_associated_datasets.exists(): + return True + else: + return False + + +class IsAggregatorOwner(BasePermission): + def has_permission(self, request, view): + pk = view.kwargs.get("pk", None) + if not pk: + return False + + if not request.user.is_authenticated: + # This check is to prevent internal server error + # since user.dataset_set is used below + return False + + training_exp = TrainingExperiment.objects.get(pk=pk) + aggregator = training_exp.aggregator + if not aggregator: + return False + + if aggregator.owner.id == request.user.id: + return True + else: + return False diff --git a/server/training/serializers.py b/server/training/serializers.py new file mode 100644 index 000000000..b9df8a5c9 --- /dev/null +++ b/server/training/serializers.py @@ -0,0 +1,98 @@ +from rest_framework import serializers +from django.utils import timezone +from .models import TrainingExperiment + + +class WriteTrainingExperimentSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingExperiment + fields = "__all__" + read_only_fields = ["owner", "approved_at", "approval_status"] + + def validate(self, data): + owner = self.context["request"].user + pending_experiments = TrainingExperiment.objects.filter( + owner=owner, approval_status="PENDING" + ) + if len(pending_experiments) > 0: + raise serializers.ValidationError( + "User can own at most one pending experiment" + ) + + if "state" in data and data["state"] == "OPERATION": + dev_mlcubes = [ + data["data_preparation_mlcube"].state == "DEVELOPMENT", + data["fl_mlcube"].state == "DEVELOPMENT", + ] + if any(dev_mlcubes): + raise serializers.ValidationError( + "User cannot mark an experiment as operational" + " if its MLCubes are not operational" + ) + + return data + + +class ReadTrainingExperimentSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingExperiment + read_only_fields = ["owner", "approved_at"] + 
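        # approval_status is deliberately left writable here; validate_approval_status
        # below forbids moving an experiment back to PENDING and forbids approving an
        # experiment that was already REJECTED.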
fields = "__all__" + + def update(self, instance, validated_data): + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + validated_data.pop("approval_status", None) + for k, v in validated_data.items(): + setattr(instance, k, v) + instance.save() + return instance + + def validate_approval_status(self, approval_status): + if approval_status == "PENDING": + raise serializers.ValidationError( + "User can only approve or reject an experiment" + ) + if approval_status == "APPROVED": + if self.instance.approval_status == "REJECTED": + raise serializers.ValidationError( + "User can approve only a pending request" + ) + return approval_status + + def validate_state(self, state): + if state == "OPERATION" and self.instance.state != "OPERATION": + dev_mlcubes = [ + self.instance.data_preparation_mlcube.state == "DEVELOPMENT", + self.instance.fl_mlcube.state == "DEVELOPMENT", + ] + if any(dev_mlcubes): + raise serializers.ValidationError( + "User cannot mark an experiment as operational" + " if its MLCubes are not operational" + ) + return state + + def validate(self, data): + event = self.instance.event + if event and not event.finished: + raise serializers.ValidationError( + "User cannot update an experiment with ongoing event" + ) + if self.instance.state == "OPERATION": + editable_fields = [ + "is_valid", + "user_metadata", + "approval_status", + "demo_dataset_tarball_url", + ] + for k, v in data.items(): + if k not in editable_fields: + if v != getattr(self.instance, k): + raise serializers.ValidationError( + "User cannot update non editable fields in Operation mode" + ) + return data diff --git a/server/training/urls.py b/server/training/urls.py new file mode 100644 index 000000000..054da57fd --- /dev/null +++ b/server/training/urls.py @@ -0,0 +1,15 @@ +from django.urls import path, include +from . 
import views
+
+app_name = "training"
+
+urlpatterns = [
+    path("", views.TrainingExperimentList.as_view()),
+    path("<int:pk>/", views.TrainingExperimentDetail.as_view()),
+    path("<int:pk>/datasets/", views.TrainingDatasetList.as_view()),
+    path("<int:pk>/aggregator/", views.TrainingAggregator.as_view()),
+    path("<int:pk>/ca/", views.TrainingCA.as_view()),
+    path("<int:pk>/event/", views.GetTrainingEvent.as_view()),
+    path("<int:pk>/participants_info/", views.ParticipantsInfo.as_view()),
+    path("events/", include("trainingevent.urls", namespace=app_name), name="event"),
+]
diff --git a/server/training/views.py b/server/training/views.py
new file mode 100644
index 000000000..f28e86cd8
--- /dev/null
+++ b/server/training/views.py
@@ -0,0 +1,242 @@
+from aggregator.serializers import (
+    AggregatorSerializer,
+)
+from traindataset_association.serializers import (
+    TrainingExperimentListofDatasetsSerializer,
+)
+from ca.serializers import CASerializer
+from trainingevent.serializers import EventDetailSerializer
+from dataset.serializers import DatasetWithOwnerInfoSerializer
+from django.http import Http404
+from rest_framework.generics import GenericAPIView
+from rest_framework.response import Response
+from rest_framework import status
+from drf_spectacular.utils import extend_schema
+
+from django.db.models import OuterRef, Subquery
+from django.contrib.auth import get_user_model
+from dataset.models import Dataset
+from .models import TrainingExperiment
+from .serializers import (
+    WriteTrainingExperimentSerializer,
+    ReadTrainingExperimentSerializer,
+)
+from .permissions import (
+    IsAdmin,
+    IsExpOwner,
+    IsAssociatedDatasetOwner,
+    IsAggregatorOwner,
+)
+
+User = get_user_model()
+
+
+class TrainingExperimentList(GenericAPIView):
+    serializer_class = WriteTrainingExperimentSerializer
+    queryset = ""
+
+    @extend_schema(operation_id="training_retrieve_all")
+    def get(self, request, format=None):
+        """
+        List all training experiments
+        """
+        training_exps = TrainingExperiment.objects.all()
+        training_exps = self.paginate_queryset(training_exps)
+        serializer = WriteTrainingExperimentSerializer(training_exps, many=True)
+        return self.get_paginated_response(serializer.data)
+
+    def post(self, request, format=None):
+        """
+        Create a new TrainingExperiment
+        """
+        serializer = WriteTrainingExperimentSerializer(
+            data=request.data, context={"request": request}
+        )
+        if serializer.is_valid():
+            serializer.save(owner=request.user)
+            return Response(serializer.data, status=status.HTTP_201_CREATED)
+        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+
+class TrainingAggregator(GenericAPIView):
+    permission_classes = [
+        IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner
+    ]
+    serializer_class = AggregatorSerializer
+    queryset = ""
+
+    def get_object(self, pk):
+        try:
+            training_exp = TrainingExperiment.objects.get(pk=pk)
+        except TrainingExperiment.DoesNotExist:
+            raise Http404
+
+        aggregator = training_exp.aggregator
+        if not aggregator:
+            raise Http404
+        return aggregator
+
+    def get(self, request, pk, format=None):
+        """
+        Retrieve the aggregator associated with a training exp instance.
+ """ + aggregator = self.get_object(pk) + serializer = AggregatorSerializer(aggregator) + return Response(serializer.data) + + +class TrainingDatasetList(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner] + serializer_class = TrainingExperimentListofDatasetsSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve datasets associated with a training experiment instance. + """ + training_exp = self.get_object(pk) + datasets = training_exp.experimentdataset_set.all() + datasets = self.paginate_queryset(datasets) + serializer = TrainingExperimentListofDatasetsSerializer(datasets, many=True) + return self.get_paginated_response(serializer.data) + + +class TrainingCA(GenericAPIView): + # permission_classes = [ + # IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner + # ] + serializer_class = CASerializer + queryset = "" + + def get_object(self, pk): + try: + training_exp = TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + ca = training_exp.ca + if not ca: + raise Http404 + return ca + + def get(self, request, pk, format=None): + """ + Retrieve CA associated with a training experiment instance. + """ + ca = self.get_object(pk) + serializer = CASerializer(ca) + return Response(serializer.data) + + +class GetTrainingEvent(GenericAPIView): + permission_classes = [ + IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner + ] + serializer_class = EventDetailSerializer + queryset = "" + + def get_object(self, pk): + try: + training_exp = TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + event = training_exp.event + if not event: + raise Http404 + return event + + def get(self, request, pk, format=None): + """ + Retrieve latest event of a training experiment instance. + """ + event = self.get_object(pk) + serializer = EventDetailSerializer(event) + return Response(serializer.data) + + +class TrainingExperimentDetail(GenericAPIView): + serializer_class = ReadTrainingExperimentSerializer + queryset = "" + + def get_permissions(self): + if self.request.method == "PUT": + self.permission_classes = [IsAdmin | IsExpOwner] + if "approval_status" in self.request.data: + self.permission_classes = [IsAdmin] + elif self.request.method == "DELETE": + self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve a TrainingExperiment instance. + """ + training_exp = self.get_object(pk) + serializer = ReadTrainingExperimentSerializer(training_exp) + return Response(serializer.data) + + def put(self, request, pk, format=None): + """ + Update a TrainingExperiment instance. + """ + training_exp = self.get_object(pk) + serializer = ReadTrainingExperimentSerializer( + training_exp, data=request.data, partial=True + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk, format=None): + """ + Delete a training experiment instance. 
+ """ + training_exp = self.get_object(pk) + training_exp.delete() + return Response(status=status.HTTP_204_NO_CONTENT) + + +class ParticipantsInfo(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner] + serializer_class = TrainingExperimentListofDatasetsSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve datasets associated with a training experiment instance. + """ + training_exp = self.get_object(pk) + latest_datasets_assocs_status = ( + training_exp.experimentdataset_set.all() + .filter(dataset__id=OuterRef("id")) + .order_by("-created_at")[:1] + .values("approval_status") + ) + datasets_with_users = ( + Dataset.objects.all() + .annotate(assoc_status=Subquery(latest_datasets_assocs_status)) + .filter(assoc_status="APPROVED") + ) + datasets_with_users = self.paginate_queryset(datasets_with_users) + serializer = DatasetWithOwnerInfoSerializer(datasets_with_users, many=True) + return self.get_paginated_response(serializer.data) diff --git a/server/trainingevent/__init__.py b/server/trainingevent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/trainingevent/admin.py b/server/trainingevent/admin.py new file mode 100644 index 000000000..a6e5ca6af --- /dev/null +++ b/server/trainingevent/admin.py @@ -0,0 +1,7 @@ +from django.contrib import admin +from .models import TrainingEvent + + +@admin.register(TrainingEvent) +class TrainingEventAdmin(admin.ModelAdmin): + list_display = [field.name for field in TrainingEvent._meta.fields] diff --git a/server/trainingevent/apps.py b/server/trainingevent/apps.py new file mode 100644 index 000000000..9d1295a0f --- /dev/null +++ b/server/trainingevent/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class TrainingeventConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "trainingevent" diff --git a/server/trainingevent/migrations/0001_initial.py b/server/trainingevent/migrations/0001_initial.py new file mode 100644 index 000000000..3aaa8d673 --- /dev/null +++ b/server/trainingevent/migrations/0001_initial.py @@ -0,0 +1,58 @@ +# Generated by Django 4.2.11 on 2024-04-29 13:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("training", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="TrainingEvent", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20, unique=True)), + ("is_valid", models.BooleanField(default=True)), + ("finished", models.BooleanField(default=False)), + ("participants", models.JSONField()), + ("report", models.JSONField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ("finished_at", models.DateTimeField(blank=True, null=True)), + ( + "owner", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "training_exp", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="events", + to="training.trainingexperiment", + ), + ), + ], + options={ + 
"ordering": ["created_at"], + }, + ), + ] diff --git a/server/trainingevent/migrations/__init__.py b/server/trainingevent/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/trainingevent/models.py b/server/trainingevent/models.py new file mode 100644 index 000000000..6c7b6bc7a --- /dev/null +++ b/server/trainingevent/models.py @@ -0,0 +1,24 @@ +from django.db import models +from training.models import TrainingExperiment +from django.contrib.auth import get_user_model + +User = get_user_model() + + +# Create your models here. +class TrainingEvent(models.Model): + name = models.CharField(max_length=20, unique=True) + owner = models.ForeignKey(User, on_delete=models.PROTECT) + is_valid = models.BooleanField(default=True) + finished = models.BooleanField(default=False) + training_exp = models.ForeignKey( + TrainingExperiment, on_delete=models.PROTECT, related_name="events" + ) + participants = models.JSONField() + report = models.JSONField(blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + finished_at = models.DateTimeField(null=True, blank=True) + + class Meta: + ordering = ["created_at"] diff --git a/server/trainingevent/permissions.py b/server/trainingevent/permissions.py new file mode 100644 index 000000000..1bb0c4b0d --- /dev/null +++ b/server/trainingevent/permissions.py @@ -0,0 +1,67 @@ +from rest_framework.permissions import BasePermission +from .models import TrainingEvent, TrainingExperiment + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsExpOwner(BasePermission): + def get_object(self, tid): + try: + return TrainingExperiment.objects.get(pk=tid) + except TrainingExperiment.DoesNotExist: + return None + + def get_event_object(self, pk): + try: + return TrainingEvent.objects.get(pk=pk) + except TrainingEvent.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + tid = request.data.get("training_exp", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + event = self.get_event_object(pk) + if not event: + return False + tid = event.training_exp.id + + if not tid: + return False + training_exp = self.get_object(tid) + if not training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False + + +class IsAggregatorOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingEvent.objects.get(pk=pk) + except TrainingEvent.DoesNotExist: + return None + + def has_permission(self, request, view): + pk = view.kwargs.get("pk", None) + if not pk: + return False + event = self.get_object(pk) + if not event: + return False + aggregator = event.training_exp.aggregator + if not aggregator: + return False + + if aggregator.owner.id == request.user.id: + return True + else: + return False diff --git a/server/trainingevent/serializers.py b/server/trainingevent/serializers.py new file mode 100644 index 000000000..863f83688 --- /dev/null +++ b/server/trainingevent/serializers.py @@ -0,0 +1,61 @@ +from rest_framework import serializers +from .models import TrainingEvent +from django.utils import timezone + + +class EventSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingEvent + fields = "__all__" + read_only_fields = ["finished", "finished_at", "report", "owner"] + + def validate(self, data): + training_exp = data["training_exp"] + if 
training_exp.approval_status != "APPROVED": + raise serializers.ValidationError( + "User cannot create an event unless the experiment is approved" + ) + prev_event = training_exp.event + if prev_event and not training_exp.event.finished: + raise serializers.ValidationError( + "User cannot create a new event unless the previous event has finished" + ) + aggregator = training_exp.aggregator + if not aggregator: + raise serializers.ValidationError( + "User cannot create a new event if the experiment has no aggregator" + ) + plan = training_exp.plan + if plan is None: + raise serializers.ValidationError( + "User cannot create a new event if the experiment has no plan" + ) + + return data + + +class EventDetailSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingEvent + fields = "__all__" + read_only_fields = [ + "finished_at", + "training_exp", + "participants", + "finished", + "owner", + "name", + ] + + def validate(self, data): + if self.instance.finished: + raise serializers.ValidationError("User cannot edit a finished event") + return data + + def update(self, instance, validated_data): + if "report" in validated_data: + instance.report = validated_data["report"] + instance.finished = True + instance.finished_at = timezone.now() + instance.save() + return instance diff --git a/server/trainingevent/tests.py b/server/trainingevent/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/server/trainingevent/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/server/trainingevent/urls.py b/server/trainingevent/urls.py new file mode 100644 index 000000000..e23cd188c --- /dev/null +++ b/server/trainingevent/urls.py @@ -0,0 +1,9 @@ +from django.urls import path +from . 
import views
+
+app_name = "events"
+
+urlpatterns = [
+    path("", views.EventList.as_view()),
+    path("<int:pk>/", views.EventDetail.as_view()),
+]
diff --git a/server/trainingevent/views.py b/server/trainingevent/views.py
new file mode 100644
index 000000000..f73a9d28c
--- /dev/null
+++ b/server/trainingevent/views.py
@@ -0,0 +1,68 @@
+from .models import TrainingEvent
+from django.http import Http404
+from rest_framework.generics import GenericAPIView
+from rest_framework.response import Response
+from rest_framework import status
+
+from .permissions import IsAdmin, IsExpOwner, IsAggregatorOwner
+from .serializers import EventSerializer, EventDetailSerializer
+
+
+class EventList(GenericAPIView):
+    permission_classes = [IsAdmin | IsExpOwner]
+    serializer_class = EventSerializer
+    queryset = ""
+
+    def post(self, request, format=None):
+        """
+        Create an event for an experiment
+        """
+        serializer = EventSerializer(data=request.data)
+        if serializer.is_valid():
+            serializer.save(owner=request.user)
+            return Response(serializer.data, status=status.HTTP_201_CREATED)
+        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+    def get(self, request, format=None):
+        """
+        Get all events
+        """
+        events = TrainingEvent.objects.all()
+        events = self.paginate_queryset(events)
+        serializer = EventSerializer(events, many=True)
+        return self.get_paginated_response(serializer.data)
+
+
+class EventDetail(GenericAPIView):
+    serializer_class = EventDetailSerializer
+    queryset = ""
+
+    def get_permissions(self):
+        if self.request.method == "PUT":
+            self.permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner]
+        return super(self.__class__, self).get_permissions()
+
+    def get_object(self, pk):
+        try:
+            return TrainingEvent.objects.get(pk=pk)
+        except TrainingEvent.DoesNotExist:
+            raise Http404
+
+    def get(self, request, pk, format=None):
+        """
+        Retrieve an event
+        """
+        event = self.get_object(pk)
+        serializer = EventDetailSerializer(event)
+        return Response(serializer.data)
+
+    def put(self, request, pk, format=None):
+        """
+        Update an event
+        """
+        event = self.get_object(pk)
+        serializer = EventDetailSerializer(event, data=request.data)
+        if serializer.is_valid():
+            serializer.save()
+            return Response(serializer.data)
+        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
diff --git a/server/utils/associations.py b/server/utils/associations.py
new file mode 100644
index 000000000..54014a82f
--- /dev/null
+++ b/server/utils/associations.py
@@ -0,0 +1,39 @@
+from rest_framework import serializers
+
+
+def validate_approval_status_on_creation(last_association, approval_status):
+    if not last_association:
+        if approval_status != "PENDING":
+            raise serializers.ValidationError(
+                "User can approve or reject association request only if there are prior requests"
+            )
+    else:
+        if approval_status == "PENDING":
+            if last_association.approval_status != "REJECTED":
+                raise serializers.ValidationError(
+                    "User can create a new request only if prior request is rejected"
+                )
+        elif approval_status == "APPROVED":
+            raise serializers.ValidationError(
+                "User cannot create an approved association request"
+            )
+        # approval_status == "REJECTED":
+        else:
+            if last_association.approval_status != "APPROVED":
+                raise serializers.ValidationError(
+                    "User can reject request only if prior request is approved"
+                )
+
+
+def validate_approval_status_on_update(
+    last_approval_status, cur_approval_status, initiated_user, current_user
+):
+    if last_approval_status != "PENDING":
+        raise
serializers.ValidationError( + "User can approve or reject only a pending request" + ) + if cur_approval_status == "APPROVED": + if current_user.id == initiated_user.id: + raise serializers.ValidationError( + "Same user cannot approve the association request" + ) diff --git a/server/utils/urls.py b/server/utils/urls.py index 736505ae6..47f3c35c3 100644 --- a/server/utils/urls.py +++ b/server/utils/urls.py @@ -9,6 +9,18 @@ path("datasets/", views.DatasetList.as_view()), path("mlcubes/", views.MlCubeList.as_view()), path("results/", views.ModelResultList.as_view()), + path("training/", views.TrainingExperimentList.as_view()), + path("aggregators/", views.AggregatorList.as_view()), + path("training/events/", views.TrainingEventList.as_view()), + path("cas/", views.CAList.as_view()), path("datasets/associations/", views.DatasetAssociationList.as_view()), path("mlcubes/associations/", views.MlCubeAssociationList.as_view()), + path( + "datasets/training_associations/", + views.DatasetTrainingAssociationList.as_view(), + ), + path( + "aggregators/training_associations/", views.AggregatorAssociationList.as_view() + ), + path("cas/training_associations/", views.CAAssociationList.as_view()), ] diff --git a/server/utils/views.py b/server/utils/views.py index 95d12961b..d80303adc 100644 --- a/server/utils/views.py +++ b/server/utils/views.py @@ -19,6 +19,20 @@ from rest_framework.permissions import AllowAny from drf_spectacular.utils import extend_schema, inline_serializer from rest_framework import serializers +from training.models import TrainingExperiment +from training.serializers import ReadTrainingExperimentSerializer +from aggregator.models import Aggregator +from aggregator.serializers import AggregatorSerializer +from traindataset_association.models import ExperimentDataset +from traindataset_association.serializers import ExperimentDatasetListSerializer +from aggregator_association.models import ExperimentAggregator +from aggregator_association.serializers import ExperimentAggregatorListSerializer +from ca_association.models import ExperimentCA +from ca_association.serializers import ExperimentCAListSerializer +from trainingevent.serializers import EventDetailSerializer +from ca.serializers import CASerializer +from trainingevent.models import TrainingEvent +from ca.models import CA class User(GenericAPIView): @@ -54,6 +68,86 @@ def get(self, request, format=None): return self.get_paginated_response(serializer.data) +class TrainingExperimentList(GenericAPIView): + serializer_class = ReadTrainingExperimentSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingExperiment.objects.filter(owner__id=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all training_exps owned by the current user + """ + training_exps = self.get_object(request.user.id) + training_exps = self.paginate_queryset(training_exps) + serializer = ReadTrainingExperimentSerializer(training_exps, many=True) + return self.get_paginated_response(serializer.data) + + +class TrainingEventList(GenericAPIView): + serializer_class = EventDetailSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingEvent.objects.filter(owner__id=pk) + except TrainingEvent.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all events owned by the current user + """ + training_events = self.get_object(request.user.id) + training_events = self.paginate_queryset(training_events) + serializer = 
EventDetailSerializer(training_events, many=True) + return self.get_paginated_response(serializer.data) + + +class AggregatorList(GenericAPIView): + serializer_class = AggregatorSerializer + queryset = "" + + def get_object(self, pk): + try: + return Aggregator.objects.filter(owner__id=pk) + except Aggregator.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all aggregators owned by the current user + """ + aggregators = self.get_object(request.user.id) + aggregators = self.paginate_queryset(aggregators) + serializer = AggregatorSerializer(aggregators, many=True) + return self.get_paginated_response(serializer.data) + + +class CAList(GenericAPIView): + serializer_class = CASerializer + queryset = "" + + def get_object(self, pk): + try: + return CA.objects.filter(owner__id=pk) + except CA.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all CAs owned by the current user + """ + cas = self.get_object(request.user.id) + cas = self.paginate_queryset(cas) + serializer = CASerializer(cas, many=True) + return self.get_paginated_response(serializer.data) + + class MlCubeList(GenericAPIView): serializer_class = MlCubeSerializer queryset = "" @@ -158,6 +252,73 @@ def get(self, request, format=None): return self.get_paginated_response(serializer.data) +class DatasetTrainingAssociationList(GenericAPIView): + serializer_class = ExperimentDatasetListSerializer + queryset = "" + + def get_object(self, pk): + try: + # TODO: this retrieves everything (not just latest ones) + return ExperimentDataset.objects.filter( + Q(dataset__owner__id=pk) | Q(training_exp__owner__id=pk) + ) + except ExperimentDataset.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all training dataset associations involving an asset of mine + """ + experiment_datasets = self.get_object(request.user.id) + experiment_datasets = self.paginate_queryset(experiment_datasets) + serializer = ExperimentDatasetListSerializer(experiment_datasets, many=True) + return self.get_paginated_response(serializer.data) + + +class AggregatorAssociationList(GenericAPIView): + serializer_class = ExperimentAggregatorListSerializer + queryset = "" + + def get_object(self, pk): + try: + return ExperimentAggregator.objects.filter( + Q(aggregator__owner__id=pk) | Q(training_exp__owner__id=pk) + ) + except ExperimentAggregator.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all aggregator associations involving an asset of mine + """ + experiment_aggs = self.get_object(request.user.id) + experiment_aggs = self.paginate_queryset(experiment_aggs) + serializer = ExperimentAggregatorListSerializer(experiment_aggs, many=True) + return self.get_paginated_response(serializer.data) + + +class CAAssociationList(GenericAPIView): + serializer_class = ExperimentCAListSerializer + queryset = "" + + def get_object(self, pk): + try: + return ExperimentCA.objects.filter( + Q(ca__owner__id=pk) | Q(training_exp__owner__id=pk) + ) + except ExperimentCA.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all ca associations involving an asset of mine + """ + experiment_cas = self.get_object(request.user.id) + experiment_cas = self.paginate_queryset(experiment_cas) + serializer = ExperimentCAListSerializer(experiment_cas, many=True) + return self.get_paginated_response(serializer.data) + + class ServerAPIVersion(GenericAPIView): permission_classes = (AllowAny,) queryset = ""