diff --git a/.github/workflows/openfl-test.yml b/.github/workflows/openfl-test.yml index 30e83d5fb..00bf95506 100644 --- a/.github/workflows/openfl-test.yml +++ b/.github/workflows/openfl-test.yml @@ -8,6 +8,10 @@ on: branches: [master] pull_request: {} +env: + # A workaround for long FQDN names provided by GitHub actions. + FQDN: "localhost" + jobs: openfl-test: runs-on: ubuntu-latest diff --git a/GANDLF/optimizers/__init__.py b/GANDLF/optimizers/__init__.py index 97de43fa1..b59afb22f 100644 --- a/GANDLF/optimizers/__init__.py +++ b/GANDLF/optimizers/__init__.py @@ -15,6 +15,8 @@ from .wrap_monai import novograd_wrapper +from .ademamix import ademamix_wrapper + global_optimizer_dict = { "sgd": sgd, "asgd": asgd, @@ -29,6 +31,7 @@ "radam": radam, "novograd": novograd_wrapper, "nadam": nadam, + "ademamix": ademamix_wrapper, } diff --git a/GANDLF/optimizers/ademamix.py b/GANDLF/optimizers/ademamix.py new file mode 100644 index 000000000..63f68d9f9 --- /dev/null +++ b/GANDLF/optimizers/ademamix.py @@ -0,0 +1,204 @@ +import math +from typing import Callable, Iterable, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +class AdEMAMix(Optimizer): + r"""Adapted from https://github.com/frgfm/Holocron/blob/main/holocron/optim/ademamix.py + + Implements the AdEMAMix optimizer from `"The AdEMAMix Optimizer: Better, Faster, Older" `_. + + The estimation of momentums is described as follows, :math:`\forall t \geq 1`: + + .. math:: + m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\ + m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\ + s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) g_t^2 + + where :math:`g_t` is the gradient of :math:`\theta_t`, + :math:`\beta_1, \beta_2, \beta_3 \in [0, 1]^3` are the exponential average smoothing coefficients, + :math:`m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0`, :math:`\epsilon > 0`. + + Then we correct their biases using: + + .. math:: + \hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\ + \hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t} + + And finally the update step is performed using the following rule: + + .. math:: + \theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon} + + where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value), + :math:`\eta` is the learning rate, :math:`\alpha > 0` and :math:`\epsilon > 0`.
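+
+    For illustration only, a single AdEMAMix step for one parameter can be sketched in plain Python
+    (a minimal sketch mirroring ``_update_ademamix`` below; weight decay is omitted, ``m1``, ``m2`` and ``s`` start at zero,
+    ``step`` counts from 1, and all names are local to this sketch):
+
+    .. code-block:: python
+
+        m1 = beta1 * m1 + (1 - beta1) * grad      # fast EMA of the gradient
+        m2 = beta3 * m2 + (1 - beta3) * grad      # slow EMA of the gradient
+        s = beta2 * s + (1 - beta2) * grad**2     # EMA of the squared gradient
+        m1_hat = m1 / (1 - beta1**step)           # bias corrections
+        s_hat = s / (1 - beta2**step)
+        param = param - lr * (m1_hat + alpha * m2) / (s_hat**0.5 + eps)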
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): learning rate + betas (Tuple[float, float, float], optional): coefficients used for computing running averages of the gradient, its square, and the slow EMA (default: (0.9, 0.999, 0.9999)) + alpha (float, optional): mixing coefficient that weights the slow EMA :math:`m_{2,t}` in the update (default: 5.0) + eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + """ + + def __init__( + self, + params: Iterable[torch.nn.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999), + alpha: float = 5.0, + eps: float = 1e-8, + weight_decay: float = 0.0, + ) -> None: + assert lr >= 0.0, f"Invalid learning rate: {lr}" + assert eps >= 0.0, f"Invalid epsilon value: {eps}" + assert all( + 0.0 <= beta < 1.0 for beta in betas + ), f"Invalid beta parameters: {betas}" + defaults = { + "lr": lr, + "betas": betas, + "alpha": alpha, + "eps": eps, + "weight_decay": weight_decay, + } + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override] + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avgs_slow = [] + exp_avg_sqs = [] + state_steps = [] + + for p in group["params"]: + if p.grad is not None: + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + f"{self.__class__.__name__} does not support sparse gradients" + ) + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_slow"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avgs_slow.append(state["exp_avg_slow"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + # update the steps for each param group update + state["step"] += 1 + # record the step after step update + state_steps.append(state["step"]) + + beta1, beta2, beta3 = group["betas"] + _update_ademamix( + params_with_grad, + grads, + exp_avgs, + exp_avgs_slow, + exp_avg_sqs, + state_steps, + beta1, + beta2, + beta3, + group["alpha"], + group["lr"], + group["weight_decay"], + group["eps"], + ) + return loss + + +def _update_ademamix( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avgs_slow: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[int], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: + r"""Functional API that performs the AdEMAMix parameter update. + See :class:`AdEMAMix` for details.
+ """ + for i, param in enumerate(params): + grad = grads[i] + m1 = exp_avgs[i] + m2 = exp_avgs_slow[i] + nu = exp_avg_sqs[i] + step = state_steps[i] + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + m1.mul_(beta1).add_(grad, alpha=1 - beta1) + nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + m2.mul_(beta3).add_(grad, alpha=1 - beta3) + + denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr) + + +def ademamix_wrapper(parameters: dict) -> torch.optim.Optimizer: + """ + Creates an AdEMAMix optimizer using the input parameters. + + Args: + parameters (dict): A dictionary containing the input parameters for the optimizer. + + Returns: + torch.optim.Optimizer: An AdEMAMix optimizer. + """ + + return AdEMAMix( + params=parameters["model_parameters"], + lr=parameters.get("learning_rate", 1e-3), + betas=parameters["optimizer"].get("betas", (0.9, 0.999, 0.9999)), + alpha=parameters["optimizer"].get("alpha", 5.0), + eps=parameters["optimizer"].get("eps", 1e-8), + weight_decay=parameters["optimizer"].get("weight_decay", 0.0), + ) diff --git a/GANDLF/optimizers/wrap_monai.py b/GANDLF/optimizers/wrap_monai.py index 23745e4a5..221ba57bd 100644 --- a/GANDLF/optimizers/wrap_monai.py +++ b/GANDLF/optimizers/wrap_monai.py @@ -1,10 +1,11 @@ +import monai from monai.optimizers import Novograd -def novograd_wrapper(parameters): +def novograd_wrapper(parameters: dict) -> monai.optimizers.Novograd: return Novograd( parameters["model_parameters"], - lr=parameters.get("learning_rate"), + lr=parameters.get("learning_rate", 1e-3), betas=parameters["optimizer"].get("betas", (0.9, 0.999)), eps=parameters["optimizer"].get("eps", 1e-8), weight_decay=parameters["optimizer"].get("weight_decay", 3e-05), diff --git a/GANDLF/optimizers/wrap_torch.py b/GANDLF/optimizers/wrap_torch.py index 9852f7973..2f4650bdb 100644 --- a/GANDLF/optimizers/wrap_torch.py +++ b/GANDLF/optimizers/wrap_torch.py @@ -1,3 +1,4 @@ +import torch from torch.optim import ( SGD, ASGD, @@ -14,7 +15,7 @@ ) -def sgd(parameters): +def sgd(parameters: dict) -> torch.optim.SGD: """ Creates a Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -26,7 +27,7 @@ def sgd(parameters): """ # Create the optimizer using the input parameters - optimizer = SGD( + return SGD( parameters["model_parameters"], lr=parameters.get("learning_rate"), momentum=parameters["optimizer"].get("momentum", 0.99), @@ -35,10 +36,8 @@ nesterov=parameters["optimizer"].get("nesterov", True), ) - return optimizer - -def asgd(parameters): +def asgd(parameters: dict) -> torch.optim.ASGD: """ Creates an Averaged Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -60,7 +59,7 @@ ) -def adam(parameters, opt_type="normal"): +def adam(parameters: dict, opt_type: str = "normal") -> torch.optim.Adam: """ Creates an Adam or AdamW optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -91,7 +90,7 @@ ) -def adamw(parameters): +def adamw(parameters: dict) -> torch.optim.AdamW: """ Creates an AdamW optimizer from the PyTorch `torch.optim` module using the input parameters.
@@ -105,7 +104,7 @@ def adamw(parameters): return adam(parameters, opt_type="AdamW") -def adamax(parameters): +def adamax(parameters: dict) -> torch.optim.Adamax: """ Creates an Adamax optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -141,7 +140,7 @@ def adamax(parameters): # ) -def rprop(parameters): +def rprop(parameters: dict) -> torch.optim.Rprop: """ Creates a Resilient Backpropagation optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -161,7 +160,7 @@ def rprop(parameters): ) -def adadelta(parameters): +def adadelta(parameters: dict) -> torch.optim.Adadelta: """ Creates an Adadelta optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -182,7 +181,7 @@ def adadelta(parameters): ) -def adagrad(parameters): +def adagrad(parameters: dict) -> torch.optim.Adagrad: """ Creates an Adagrad optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -204,7 +203,7 @@ def adagrad(parameters): ) -def rmsprop(parameters): +def rmsprop(parameters: dict) -> torch.optim.RMSprop: """ Creates an RMSprop optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -227,7 +226,7 @@ def rmsprop(parameters): ) -def radam(parameters): +def radam(parameters: dict) -> torch.optim.RAdam: """ Creates a RAdam optimizer from the PyTorch `torch.optim` module using the input parameters. @@ -248,7 +247,7 @@ def radam(parameters): ) -def nadam(parameters): +def nadam(parameters: dict) -> torch.optim.NAdam: """ Creates a NAdam optimizer from the PyTorch `torch.optim` module using the input parameters. diff --git a/GANDLF/utils/data_splitter.py b/GANDLF/utils/data_splitter.py index f289f9534..939c8cfca 100644 --- a/GANDLF/utils/data_splitter.py +++ b/GANDLF/utils/data_splitter.py @@ -58,7 +58,7 @@ def split_data( # put 2 just so that the first for-loop does not fail testing_folds = 2 print( - "WARNING: Testing data is empty, which will result in scientifically incorrect results; use at your own risk." + "WARNING: Cross-validation is set to run on a train/validation scheme without testing data. For a more rigorous evaluation and if you wish to tune hyperparameters, make sure to use nested cross-validation." ) # get unique subject IDs diff --git a/GANDLF/utils/gandlf_logging.py b/GANDLF/utils/gandlf_logging.py index f08a3620f..576df868e 100644 --- a/GANDLF/utils/gandlf_logging.py +++ b/GANDLF/utils/gandlf_logging.py @@ -4,6 +4,7 @@ from importlib import resources import tempfile from GANDLF.utils import get_unique_timestamp +import sys def _create_tmp_log_file(): @@ -26,6 +27,13 @@ def _configure_logging_with_logfile(log_file, config_path): logging.config.dictConfig(config_dict) +def gandlf_excepthook(exctype, value, tb): + if issubclass(exctype, Exception): + logging.exception("Uncaught exception", exc_info=(exctype, value, tb)) + else: + sys.__excepthook__(exctype, value, tb) + + def logger_setup(log_file=None, config_path="logging_config.yaml") -> None: """ It sets up the logger. Reads from logging_config. 
@@ -42,6 +50,7 @@ def logger_setup(log_file=None, config_path="logging_config.yaml") -> None: log_tmp_file = _create_tmp_log_file() _create_log_file(log_tmp_file) _configure_logging_with_logfile(log_tmp_file, config_path) + sys.excepthook = gandlf_excepthook logging.info(f"The logs are saved in {log_tmp_file}") diff --git a/docs/customize.md b/docs/customize.md index f6b9df16e..9c33cd523 100644 --- a/docs/customize.md +++ b/docs/customize.md @@ -119,17 +119,17 @@ This file contains mid-level information regarding various parameters that can b - These are various parameters that control the overall training process. - `verbose`: generate verbose messages on console; generally used for debugging. -- `batch_size`: defines the batch size to be used for training. -- `in_memory`: this is to enable or disable lazy loading - setting to true reads all data once during data loading, resulting in improvements. -- `num_epochs`: defines the number of epochs to train for. -- `patience`: defines the number of epochs to wait for improvement before early stopping. -- `learning_rate`: defines the learning rate to be used for training. -- `scheduler`: defines the learning rate scheduler to be used for training, more details are [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/schedulers/__init__.py); can take the following sub-parameters: +- `batch_size`: batch size to be used for training. +- `in_memory`: this is to enable or disable lazy loading. If set to `True`, all data is loaded onto the RAM at once during the construction of the dataloader (either training/validation/testing), thus resulting in faster training. If set to `False`, data gets read into RAM on-the-go when needed (also called ["lazy loading"](https://en.wikipedia.org/wiki/Lazy_loading)), which slows down training but lessens the memory load. The latter is recommended if the user's RAM has limited capacity. +- `num_epochs`: number of epochs to train for. +- `patience`: number of epochs to wait for improvement in the validation loss before early stopping. +- `learning_rate`: learning rate to be used for training. +- `scheduler`: learning rate scheduler to be used for training, more details are [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/schedulers/__init__.py); can take the following sub-parameters: - `type`: `triangle`, `triangle_modified`, `exp`, `step`, `reduce-on-plateau`, `cosineannealing`, `triangular`, `triangular2`, `exp_range` - - `min_lr`: defines the minimum learning rate to be used for training. - - `max_lr`: defines the maximum learning rate to be used for training. -- `optimizer`: defines the optimizer to be used for training, more details are [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/optimizers/__init__.py). -- `nested_training`: defines the number of folds to use nested training, takes `testing` and `validation` as sub-parameters, with integer values defining the number of folds to use. + - `min_lr`: minimum learning rate to be used for training. + - `max_lr`: maximum learning rate to be used for training. +- `optimizer`: optimizer to be used for training, more details are [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/optimizers/__init__.py). +- `nested_training`: number of folds to use nested training, takes `testing` and `validation` as sub-parameters, with integer values defining the number of folds to use. 
- `memory_save_mode`: if enabled, resize/resample operations in `data_preprocessing` will save files to disk instead of directly getting read into memory as tensors - **Queue configuration**: this defines how the queue for the input to the model is to be designed **after** the [patching strategy](#patching-strategy) has been applied, and more details are [here](https://torchio.readthedocs.io/data/patch_training.html?#queue). This takes the following sub-parameters: - `q_max_length`: This determines the maximum number of patches that can be stored in the queue. Using a large number means that the queue needs to be filled less often, but more CPU memory is needed to store the patches. diff --git a/docs/faq.md b/docs/faq.md index b0a89cf9c..dfe0c2234 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -49,6 +49,10 @@ Please see https://mlcommons.github.io/GaNDLF/usage/#federating-your-model-using Please see https://mlcommons.github.io/GaNDLF/usage/#federating-your-model-evaluation-using-medperf. +### I was using GaNDLF version `0.0.19` or earlier, and I am facing issues after updating to `0.0.20` or later. What should I do? + +Please read the [migration guide](https://mlcommons.github.io/GaNDLF/migration_guide) to understand the changes that have been made to GaNDLF. If you have any questions, please feel free to [post a support request](https://github.com/mlcommons/GaNDLF/issues/new?assignees=&labels=&template=--questions-help-support.md&title=). + ### What if I have another question? -Please [post a support request](https://github.com/mlcommons/GaNDLF/issues/new?assignees=&labels=&template=--questions-help-support.md&title=). \ No newline at end of file +Please [post a support request](https://github.com/mlcommons/GaNDLF/issues/new?assignees=&labels=&template=--questions-help-support.md&title=). diff --git a/docs/index.md b/docs/index.md index 546726bfb..8180d2956 100644 --- a/docs/index.md +++ b/docs/index.md @@ -28,6 +28,7 @@ and **CSV inputs** that describe the training data. - [Usage](./usage.md) - [Customize the training and inference](./customize.md) - [Extending GaNDLF](./extending.md) +- [ITCR Connectivity](./itcr_connectivity.md) - [FAQ](./faq.md) - [Acknowledgements](./acknowledgements.md) diff --git a/docs/itcr_connectivity.md b/docs/itcr_connectivity.md new file mode 100644 index 000000000..314d16e59 --- /dev/null +++ b/docs/itcr_connectivity.md @@ -0,0 +1,47 @@ +# ITCR Connectivity + +This section includes a reference of all ongoing and existing connections between GaNDLF and other projects funded under the [Informatics Technology for Cancer Research (ITCR)](https://itcr.cancer.gov/) program. + +A connectivity map featuring all ITCR projects can be found [here](https://www.ndexbio.org/#/network/04c0a7e8-af92-11e7-94d3-0ac135e8bacf). + +- [Existing Connections](#existing-connections) - [DCMTK](#dcmtk) - [Synapse PACS](#synapse-pacs) - [FLAIM](#flaim) +- [Ongoing Development](#ongoing-development) - [XNAT](#xnat) - [OHIF](#ohif) - [Radiomics.io](#radiomics) - [RadXTools](#radxtools) - [TCIA](#tcia) + + +## Existing Connections + +### DCMTK +GaNDLF uses [DCMTK (DICOM ToolKit)](https://dicom.offis.de/dcmtk.php.en) (through [ITK](https://itk.org/)) for DICOM file handling. + +### Synapse PACS +[GaNDLF's Metrics Suite](https://docs.mlcommons.org/GaNDLF/usage/#generate-metrics) is used by [Synapse PACS](https://www.synapse.org/). + +### FLAIM +GaNDLF models can be ingested through FLAIM to facilitate interoperability and privacy preservation.
+ + ## Ongoing Development + + ### XNAT + Enable the use of GaNDLF's models on databases and population cohorts defined by [XNAT](https://xnat.org). + + ### OHIF + Online visualization using OHIF. + + ### Radiomics + Extraction of downstream features of automatically generated segmentation maps using [Radiomics.io](https://www.radiomics.io/). + + ### RadXTools + Integration with [RadXTools](https://radxtools.github.io/) for extended functionality. + + ### TCIA + Integration with TCIA's REST APIs to make downloading datasets easier. + + Contact [gandlf [at] mlcommons.org](mailto:gandlf@mlcommons.org) with any questions. \ No newline at end of file diff --git a/docs/migration_guide.md b/docs/migration_guide.md new file mode 100644 index 000000000..b435e30c0 --- /dev/null +++ b/docs/migration_guide.md @@ -0,0 +1,39 @@ +# Migration Guide + +The [0.0.20 release](https://github.com/mlcommons/GaNDLF/releases/tag/0.0.20) was the final release that supported the old way of using GaNDLF (i.e., `gandlf_run`). Instead, we now have a CLI that is more unified and based on modern CLI parsing (i.e., `gandlf run`). If you have been using version `0.0.20` or earlier, please follow this guide to move your experimental setup to the new CLI [[ref](https://github.com/mlcommons/GaNDLF/pull/845)]. + +## User-level Changes + +### Command Line Interfaces + +- The CLI commands have been moved to use [`click`](https://click.palletsprojects.com/en/8.1.x/) for parsing the command line arguments. This means that the commands are now more user-friendly and easier to remember, and come with added features like tab completion and type checks. +- All the commands that were previously available as `gandlf_${functionality}` are now available as `gandlf ${functionality}` (i.e., replace the `_` with ` `). +- The previous commands are still present, but they are deprecated and will be removed in a future release. + +### Configuration Files + +- The main change is the use of the [Version package](https://github.com/keleshev/version) for systematic semantic versioning [[ref](https://github.com/mlcommons/GaNDLF/pull/841)]. +- No change is needed if you are using a [stable version](https://docs.mlcommons.org/GaNDLF/setup/#install-from-package-managers). +- If you have installed GaNDLF [from source](https://docs.mlcommons.org/GaNDLF/setup/#install-from-sources) or using a [nightly build](https://docs.mlcommons.org/GaNDLF/setup/#install-from-package-managers), you will need to ensure that the `maximum` key under `version` in the configuration file contains the correct version number: + - Either **including** the `-dev` identifier of the current version (e.g., if the current version is `0.1.0-dev`, then the `maximum` key should be `0.1.0-dev`). + - Or **excluding** the `-dev` identifier of the current version, but increasing the version number by one on any level (e.g., if the current version is `0.1.0-dev`, then the `maximum` key should be `0.1.1`). + +### Use in HPC Environments + +- If you are using GaNDLF in an HPC environment, you will need to update the job submission scripts to use the new CLI commands. +- The previous API required one to call the interpreter and the specific command (e.g., `${venv_gandlf}/bin/python gandlf_run`), while the new API requires one to call the GaNDLF command directly (e.g., `${venv_gandlf}/bin/gandlf run` or `${venv_gandlf}/bin/gandlf_run`). +- The [Slurm experiments template](https://github.com/IUCompPath/gandlf_experiments_template_slurm) has been appropriately updated to reflect this change.
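+- As an illustrative example, a job-script line that previously read `${venv_gandlf}/bin/python gandlf_run <args>` would now read `${venv_gandlf}/bin/gandlf run <args>` (here `<args>` is only a placeholder for the options your script already passes).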
+ + ## Developer-level Changes + + ### Command Line Interfaces + + - CLI entrypoints are now defined in the `GANDLF.entrypoints` module, which contains argument parsing (using both the old and new API structures). + - CLI entrypoint logic is now defined in the `GANDLF.cli` module, which only contains how the specific functionality is executed from an algorithmic perspective. + - This is to ensure backwards API compatibility, and will **not** be removed. + + ### Configuration Files + + - GaNDLF's [`config_manager` module](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/config_manager.py) is now the primary way to manage configuration files. + - This is going to be updated to use [pydantic](https://docs.pydantic.dev/latest/) in the near future [[ref](https://github.com/mlcommons/GaNDLF/issues/758)]. diff --git a/docs/setup.md b/docs/setup.md index ff9ce222f..9f9cb5397 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -22,21 +22,21 @@ Alternatively, you can run GaNDLF via [Docker](https://www.docker.com/). This ne ### Install PyTorch -GaNDLF's primary computational foundation is built on PyTorch, and as such it supports all hardware types that PyTorch supports. Please install PyTorch for your hardware type before installing GaNDLF. See the [PyTorch installation instructions](https://pytorch.org/get-started/previous-versions/#v1131) for more details. An example installation using CUDA, ROCm, and CPU-only is shown below: +GaNDLF's primary computational foundation is built on PyTorch, and as such it supports all hardware types that PyTorch supports. Please install PyTorch for your hardware type before installing GaNDLF. See the [PyTorch installation instructions](https://pytorch.org/get-started/previous-versions/#v1131) for more details. + +First, create and activate your environment: ```bash (base) $> conda create -n venv_gandlf python=3.9 -y (base) $> conda activate venv_gandlf (venv_gandlf) $> ### subsequent commands go here -### PyTorch installation - https://pytorch.org/get-started/previous-versions/#v210 -## CUDA 12.1 -# (venv_gandlf) $> pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 -## CUDA 11.8 -# (venv_gandlf) $> pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 -## ROCm 6.0 -# (venv_gandlf) $> pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/rocm6.0 -## CPU-only -# (venv_gandlf) $> pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu +``` + +You may install PyTorch to be compatible with CUDA, ROCm, or CPU-only. An exhaustive list of PyTorch installations for the specific version compatible with GaNDLF can be found here: https://pytorch.org/get-started/previous-versions/#v231 +Use the command that matches your hardware; for example, for CUDA 12.1: +```bash +(venv_gandlf) $> pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 ``` ### Optional Dependencies @@ -53,33 +53,38 @@ The following dependencies are **optional**, and are only needed to access speci This option is recommended for most users, and allows for the quickest way to get started with GaNDLF.
```bash -# continue from previous shell (venv_gandlf) $> pip install gandlf # this will give you the latest stable release -## you can also use conda -# (venv_gandlf) $> conda install -c conda-forge gandlf -y +``` +You can also use conda +```bash +(venv_gandlf) $> conda install -c conda-forge gandlf -y ``` If you are interested in running the latest version of GaNDLF, you can install the nightly build by running the following command: ```bash -# continue from previous shell (venv_gandlf) $> pip install --pre gandlf -## you can also use conda -# (venv_gandlf) $> conda install -c conda-forge/label/gandlf_dev -c conda-forge gandlf -y ``` +You can also use conda +```bash +(venv_gandlf) $> conda install -c conda-forge/label/gandlf_dev -c conda-forge gandlf -y +``` ### Install from Sources Use this option if you want to [contribute to GaNDLF](https://github.com/mlcommons/GaNDLF/blob/master/CONTRIBUTING.md), or are interested to make other code-level changes for your own use. ```bash -# continue from previous shell (venv_gandlf) $> git clone https://github.com/mlcommons/GaNDLF.git (venv_gandlf) $> cd GaNDLF (venv_gandlf) $> pip install -e . ``` +Test your installation: +```bash +(venv_gandlf) $> gandlf verify-install +``` ## Docker Installation diff --git a/mkdocs.yml b/mkdocs.yml index 0ab930a83..f1ae63a67 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -6,8 +6,10 @@ nav: - Getting Started: getting_started.md - Installation: setup.md - Usage: usage.md + - Migration Guide: migration_guide.md - Customize Training and Inference: customize.md - Extending GaNDLF: extending.md + - ITCR Connectivity: itcr_connectivity.md - FAQ: faq.md - Acknowledgements: acknowledgements.md theme: