Commit

Horovod demo (securefederatedai#899)
* initial commit

* changes

* name changes, added readme

* remove files, new dataloader inherit

* move files, flake changes

* flake fix

---------

Co-authored-by: Patrick Foley <[email protected]>
Signed-off-by: nammbash <[email protected]>
2 people authored and nammbash committed Feb 27, 2024
1 parent dc11377 commit 8ed2895
Showing 15 changed files with 987 additions and 0 deletions.
2 changes: 2 additions & 0 deletions openfl-workspace/torch_llm_horovod/.workspace
@@ -0,0 +1,2 @@
current_plan_name: default

40 changes: 40 additions & 0 deletions openfl-workspace/torch_llm_horovod/LLM_Horovod.MD
@@ -0,0 +1,40 @@
This README provides instructions for setting up and running the Horovod example using OpenFL.

## Prerequisites
Before running the Horovod example, ensure that the following prerequisites are met:

1. A Python environment is set up on all nodes.
2. That environment is sourced automatically on SSH login (for example, from the shell's startup file), so that non-interactive Horovod launches see it too.

## Setting up Horovod Dependencies
To set up the Horovod dependencies, follow these steps:

1. Run the `setup_env.sh` script located at `openfl-workspace/torch_llm_horovod/setup_env.sh` within your virtual environment (venv).
2. Create the aggregator and collaborator workspaces. See [Running the experiment](#running-the-experiment).
3. Ensure that the collaborator workspace is present on each node with the same file structure.
4. Make sure the dataset is available on each node.

## Setting up Passwordless SSH Login
Horovod requires passwordless SSH login. Follow the instructions provided at [this link](http://www.linuxproblem.org/art_9.html) to set it up.

## Environmental Variables
Set the following environment variables for Horovod (a sketch of how they might be consumed follows the list):

- `OPENFL_HOROVOD_DEMO_NP`: the number of processes to run (e.g., "4").
- `OPENFL_HOROVOD_DEMO_NICS`: the comma-separated network interface names common to all nodes (e.g., "eno1,lo").
- `OPENFL_HOROVOD_DEMO_LOCALHOSTIP`: the IP address of the local node (e.g., "ip1").
- `OPENFL_HOROVOD_DEMO_HOSTS`: the IP address and slot count of each node (e.g., "ip1:2,ip2:2").
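
For orientation, here is a minimal sketch of how these variables might be read and turned into a `horovodrun` command line. The command layout and the `src/InHorovodrun.py` entry-point name are assumptions for illustration, not a copy of the demo's launcher; check `horovodrun --help` for the exact flag names in your Horovod version.

```python
import os
import shlex
import subprocess

# Assumed-for-illustration: build a horovodrun invocation from the
# OPENFL_HOROVOD_DEMO_* variables documented above.
np_procs = os.environ["OPENFL_HOROVOD_DEMO_NP"]    # e.g. "4"
nics = os.environ["OPENFL_HOROVOD_DEMO_NICS"]      # e.g. "eno1,lo"
hosts = os.environ["OPENFL_HOROVOD_DEMO_HOSTS"]    # e.g. "ip1:2,ip2:2"

# -np and -H are standard horovodrun flags; the NIC flag name varies by
# Horovod version (e.g. --network-interface), so verify it before use.
cmd = (
    f"horovodrun -np {np_procs} -H {hosts} "
    f"--network-interface {nics} "
    "python src/InHorovodrun.py"  # assumed entry point, per this README
)
subprocess.run(shlex.split(cmd), check=True)
```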

## Customizing Data and Models
To use your own data and models, follow these steps:

1. Copy the `openfl/openfl-workspace/torch_llm_horovod` directory to `openfl/openfl-workspace/name_of_your_template`.
2. In the `src/InHorovodrun` file, make the following changes (see the sketch after this list):
   - Replace `GlueMrpcDataLoader` with your own dataloader.
   - Replace `LLMTrainer` with your own training/validation scripts.
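
As a rough sketch of what such a swap might look like, the skeleton below defines a stand-in dataloader exposing the members the trainer in this template relies on (`train_set`, `get_train_loader`, `get_valid_loader`). `MyDataLoader` and its dataset contents are hypothetical placeholders, not part of this template:

```python
from torch.utils.data import DataLoader, Dataset


class MyDataLoader:
    """Hypothetical stand-in for GlueMrpcDataLoader (illustrative only)."""

    def __init__(self, train_set: Dataset, valid_set: Dataset, batch_size: int = 32):
        self.train_set = train_set   # the trainer uses len(train_set) to size the LR schedule
        self.valid_set = valid_set
        self.batch_size = batch_size

    def get_train_loader(self) -> DataLoader:
        # Each batch should be a dict of tokenized model inputs, matching
        # what the trainer unpacks with self.model(**sample).
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def get_valid_loader(self) -> DataLoader:
        return DataLoader(self.valid_set, batch_size=self.batch_size)
```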

## Running the Experiment
To run the experiment, follow the instructions provided in the [OpenFL documentation](https://openfl.readthedocs.io/en/latest/running_the_federation.html#bare-metal-approach) using either the `torch_llm_horovod` template or your own template.

That's it! You're now ready to use the Horovod example with your own data and models. Enjoy!

5 changes: 5 additions & 0 deletions openfl-workspace/torch_llm_horovod/plan/cols.yaml
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:

9 changes: 9 additions & 0 deletions openfl-workspace/torch_llm_horovod/plan/data.yaml
@@ -0,0 +1,9 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# Each key under 'collaborators' corresponds to a specific collaborator name; the corresponding dictionary has data_name, data_path pairs.
# Note that in the MNIST case we do not store the data locally, and the data_path is used to pass an integer that helps the data object
# construct the shard of the MNIST dataset to be used for this collaborator.

# collaborator_name ,data_directory_path
one,1
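
For illustration, this is roughly how a dataloader could turn the integer `data_path` (the `1` in `one,1` above) into a dataset shard. The sketch assumes the Hugging Face `datasets` API; the template's actual shard logic lives in its dataloader classes:

```python
from datasets import load_dataset

# Sketch: treat data_path from data.yaml as a 1-based shard index.
collaborator_count = 2          # must match the plan's data_loader settings
data_path = "1"                 # the value after the comma in "one,1"
shard_index = int(data_path) - 1

train = load_dataset("glue", "mrpc", split="train")
shard = train.shard(num_shards=collaborator_count, index=shard_index)
```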
2 changes: 2 additions & 0 deletions openfl-workspace/torch_llm_horovod/plan/defaults
@@ -0,0 +1,2 @@
../../workspace/plan/defaults

45 changes: 45 additions & 0 deletions openfl-workspace/torch_llm_horovod/plan/plan.yaml
@@ -0,0 +1,45 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
  defaults : plan/defaults/aggregator.yaml
  template : openfl.component.Aggregator
  settings :
    init_state_path : save/torch_llm_init.pbuf
    best_state_path : save/torch_llm_best.pbuf
    last_state_path : save/torch_llm_last.pbuf
    rounds_to_train : 5
    log_metric_callback :
      template : src.glue_utils.write_metric


collaborator :
  defaults : plan/defaults/collaborator.yaml
  template : openfl.component.Collaborator
  settings :
    delta_updates : false
    opt_treatment : RESET

data_loader :
  defaults : plan/defaults/data_loader.yaml
  template : src.ptglue_inmemory.GlueMrpcFederatedDataLoader
  settings :
    collaborator_count : 2
    data_group_name : mnist
    batch_size : 256

task_runner :
  defaults : plan/defaults/task_runner.yaml
  template : src.pt_model.LLMTaskRunner

network :
  defaults : plan/defaults/network.yaml

assigner :
  defaults : plan/defaults/assigner.yaml

tasks :
  defaults : plan/defaults/tasks_torch.yaml

compression_pipeline :
  defaults : plan/defaults/compression_pipeline.yaml
14 changes: 14 additions & 0 deletions openfl-workspace/torch_llm_horovod/requirements.txt
@@ -0,0 +1,14 @@
torch
tensorboard
wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability
sentencepiece
accelerate
jupyter
huggingface_hub
peft
transformers[torch]
datasets
evaluate
seqeval
horovod
torchvision
8 changes: 8 additions & 0 deletions openfl-workspace/torch_llm_horovod/setup_env.sh
@@ -0,0 +1,8 @@

# Upgrade pip, install a pinned PyTorch stack, then install the workspace
# requirements; the HOROVOD_* exports make Horovod build its PyTorch
# extension without MPI support.
pip install -U pip --no-cache
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1
export HOROVOD_WITH_PYTORCH=1
export HOROVOD_WITHOUT_MPI=1
pip install -r openfl-workspace/torch_llm_horovod/requirements.txt --no-cache


192 changes: 192 additions & 0 deletions openfl-workspace/torch_llm_horovod/src/InHorovodLLMTrainer.py
@@ -0,0 +1,192 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

import os
import sys
from typing import Any, Mapping

import horovod.torch as hvd
import numpy as np
import torch
import torch as pt
import torch.nn as nn
import tqdm
from datasets import load_metric
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from openfl.utilities import Metric

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from src.model_utils import _init_model, _init_optimizer # noqa: E402


class LLMTrainer(nn.Module):
    def __init__(
        self,
        data_loader,
        base_model_name="roberta-base",
        device=None,
        metric=None,
        args=None,
        **kwargs,
    ):
        super().__init__()
        self.data_loader = data_loader
        self.base_model_name = base_model_name
        self.kwargs = kwargs
        self.device = device
        self.metric = load_metric("glue", "mrpc")
        self.model = _init_model(base_model_name, device)
        self.optimizer, self.lr_scheduler = _init_optimizer(
            self.model, len(self.data_loader.train_set)
        )

    def train(self):
        return self.model.train()

    def state_dict(self):
        return get_peft_model_state_dict(self.model)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        return set_peft_model_state_dict(self.model, state_dict)

    def load_state(self, kwargs):
        """Load checkpoint state on rank 0 and broadcast it to all workers."""
        print("loading data", os.getcwd())
        if hvd.rank() == 0:
            # Only rank 0 reads the checkpoint from disk; every other rank
            # receives the state through the broadcasts below.
            checkpoint = torch.load(kwargs["state_path"])
            self.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            kwargs.update(checkpoint["kwargs"])
            print("loaded")
        print("kwargs broadcast")
        kwargs = hvd.broadcast_object(kwargs, root_rank=0)
        print("optimizer broadcast")
        optim_state = hvd.broadcast_object(self.optimizer.state_dict(), root_rank=0)
        print("model broadcast")
        state_dict = hvd.broadcast_object(self.state_dict(), root_rank=0)
        print("scheduler broadcast")
        lr_scheduler_state_dict = hvd.broadcast_object(
            self.lr_scheduler.state_dict(), root_rank=0
        )
        if hvd.rank() > 0:
            self.load_state_dict(state_dict)
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
            self.optimizer.load_state_dict(optim_state)

    def train_batches(self, round_num, use_tqdm=False, epochs=1, **kwargs):
        """Train batches.

        Train the model for the requested number of epochs and, on rank 0,
        save the resulting state to ``kwargs["out_path"]``.

        Args:
            round_num: The current federation round
            use_tqdm (bool): Use tqdm to print a progress bar (Default=False)
            epochs: The number of epochs to train
            **kwargs: Carries ``state_path`` (checkpoint to restore) and
                ``out_path`` (where rank 0 saves the results)
        """
        self.load_state(kwargs)

        self.train()
        self.to(self.device)
        for epoch in range(epochs):
            loader = self.data_loader.get_train_loader()
            if use_tqdm:
                loader = tqdm.tqdm(loader, desc="train epoch")
            metric = self.train_epoch(loader)
        # Average the epoch loss across all Horovod workers.
        metric = hvd.allreduce(torch.from_numpy(metric))
        if hvd.rank() == 0:
            # problem_type is set on the config by the Hugging Face model
            # during its first forward pass with labels.
            if self.model.config.problem_type == "regression":
                loss_fct = MSELoss()
            elif self.model.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
            elif self.model.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
            torch.save(
                {
                    "output": metric,
                    "model_state_dict": self.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "loss_fct_name": loss_fct._get_name(),
                    "lr_scheduler": self.lr_scheduler.state_dict(),
                },
                kwargs["out_path"],
            )

    def validate(self, round_num, use_tqdm=False, **kwargs):
        """Validate.

        Run validation of the model on the local data.

        Args:
            round_num: The current federation round
            use_tqdm (bool): Use tqdm to print a progress bar (Default=False)
            **kwargs: Carries ``state_path`` and ``out_path`` as in
                ``train_batches``
        """
        self.load_state(kwargs)

        self.model.eval()
        self.model.to(self.device)
        total_samples = 0
        loader = self.data_loader.get_valid_loader()
        if use_tqdm:
            loader = tqdm.tqdm(loader, desc="validate")
        samples_run = 0
        with pt.no_grad():
            for sample in loader:
                samples = sample["input_ids"].shape[0]
                total_samples += samples
                output = self.model(**sample)
                # get the index of the max log-probability
                logits = output.logits
                predictions = torch.argmax(logits, dim=-1)
                self.metric.add_batch(
                    predictions=predictions, references=sample["labels"]
                )
                samples_run += samples
        val_score = np.asarray(self.metric.compute()["accuracy"])
        # Average the local accuracy across all Horovod workers.
        result = hvd.allreduce(torch.from_numpy(val_score))
        if hvd.rank() == 0:
            torch.save({"output": result}, kwargs["out_path"])
        hvd.join()

    def train_epoch(self, batch_generator) -> Metric:
        """Train single epoch.

        Override this function in order to use custom training.

        Args:
            batch_generator: Train dataset batch generator. Yields dicts of
                tokenized model inputs with `self.data_loader.batch_size`
                samples each.

        Returns:
            np.ndarray: The mean training loss over the epoch.
        """
        losses = []
        for sample in batch_generator:
            self.model.zero_grad()
            output = self.model(**sample)
            loss = output.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.lr_scheduler.step()
            losses.append(loss.detach().cpu().numpy())
        loss = np.mean(losses)
        return np.array(loss)
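
For reference, a short sketch of the checkpoint layout that `load_state` expects at `kwargs["state_path"]`, inferred from the keys read above; the file path shown is an assumed example, and the empty dicts stand in for real state:

```python
import torch

# Sketch: the fields load_state reads from the rank-0 checkpoint.
checkpoint = {
    "model_state_dict": {},      # PEFT adapter weights (get_peft_model_state_dict format)
    "optimizer_state_dict": {},  # torch.optim.Optimizer.state_dict()
    "lr_scheduler": {},          # LR-scheduler state_dict()
    "kwargs": {},                # extra kwargs merged into the training call
}
torch.save(checkpoint, "save/state_in.pt")  # assumed example path
```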
