diff --git a/examples/FedSimSiam/.dockerignore b/examples/FedSimSiam/.dockerignore
new file mode 100644
index 000000000..8ba9024ad
--- /dev/null
+++ b/examples/FedSimSiam/.dockerignore
@@ -0,0 +1,4 @@
+data
+seed.npz
+*.tgz
+*.tar.gz
\ No newline at end of file
diff --git a/examples/FedSimSiam/.gitignore b/examples/FedSimSiam/.gitignore
new file mode 100644
index 000000000..047341d71
--- /dev/null
+++ b/examples/FedSimSiam/.gitignore
@@ -0,0 +1,6 @@
+data
+*.npz
+*.tgz
+*.tar.gz
+.fedsimsiam
+client.yaml
\ No newline at end of file
diff --git a/examples/FedSimSiam/README.rst b/examples/FedSimSiam/README.rst
new file mode 100644
index 000000000..54434c6dc
--- /dev/null
+++ b/examples/FedSimSiam/README.rst
@@ -0,0 +1,125 @@
+FEDn Project: FedSimSiam on CIFAR-10
+------------------------------------
+
+This is an example FEDn project that runs the federated self-supervised learning algorithm FedSimSiam on
+the CIFAR-10 dataset, a standard benchmark often used in federated learning. To run this example, you
+need GPU access.
+
+   **Note: We recommend all new users to start by following the Quickstart Tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html**
+
+Prerequisites
+-------------
+
+- `Python 3.8, 3.9, 3.10 or 3.11 <https://www.python.org/downloads>`__
+- `A FEDn Studio account <https://fedn.scaleoutsystems.com/signup>`__
+- Change the dependencies in the 'client/python_env.yaml' file to match your CUDA version.
+
+Creating the compute package and seed model
+-------------------------------------------
+
+Install fedn:
+
+.. code-block::
+
+   pip install fedn
+
+Clone this repository, then navigate into this directory:
+
+.. code-block::
+
+   git clone https://github.com/scaleoutsystems/fedn.git
+   cd fedn/examples/FedSimSiam
+
+Create the compute package:
+
+.. code-block::
+
+   fedn package create --path client
+
+This should create a file 'package.tgz' in the project folder.
+
+Next, generate a seed model (the first model in a global model trail):
+
+.. code-block::
+
+   fedn run build --path client
+
+This will create a seed model called 'seed.npz' in the root of the project. This step will take a few minutes, depending on hardware and internet connection (it builds a virtualenv).
+
+Using FEDn Studio
+-----------------
+
+Follow the instructions to register for FEDn Studio and start a project (https://fedn.readthedocs.io/en/stable/studio.html).
+
+In your Studio project:
+
+- Go to the 'Sessions' menu, click on 'New session', and upload the compute package (package.tgz) and seed model (seed.npz).
+- In the 'Clients' menu, click on 'Connect client' and download the client configuration file (client.yaml).
+- Save the client configuration file to the FedSimSiam example directory (fedn/examples/FedSimSiam).
+
+To connect a client, run the following command in your terminal:
+
+.. code-block::
+
+   fedn client start -in client.yaml --secure=True --force-ssl
+
+
+Running the example
+-------------------
+
+After everything is set up, go to 'Sessions' and click on 'New Session'. Click on 'Start run' and the example will execute. You can follow the training progress under 'Events' and 'Models'.
+Monitoring is done with a kNN classifier that is fitted on the feature embeddings of the training images, obtained from
+FedSimSiam's encoder, and evaluated on the feature embeddings of the test images. This process is repeated after each training round.
+
+This is a common way to track FedSimSiam's training progress, since FedSimSiam aims to minimize the distance between the embeddings of similar images.
+A high kNN accuracy implies that the feature embeddings of images within the same class are indeed close to each other in the
+embedding space, i.e., FedSimSiam has learned useful feature embeddings.
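+
+Under the hood this is the kNN monitor implemented in 'client/monitoring.py'. Condensed to its core (a sketch, not the full implementation):
+
+.. code-block::
+
+   # Embed the training ("memory") set once to build a feature bank, then
+   # classify each test embedding by a similarity-weighted vote among its
+   # k nearest neighbours in that bank.
+   feature = f.normalize(net(data.to(device)), dim=1)     # [B, D]
+   sim_matrix = torch.mm(feature, feature_bank)           # cosine similarities [B, N]
+   sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
+   sim_weight = (sim_weight / knn_t).exp()                # temperature-scaled vote weights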
+
+
+Running FEDn in local development mode
+--------------------------------------
+
+Follow the steps above to install FEDn and generate 'package.tgz' and 'seed.npz'.
+
+Start a pseudo-distributed FEDn network using docker compose:
+
+.. code-block::
+
+   docker compose \
+      -f ../../docker-compose.yaml \
+      -f docker-compose.override.yaml \
+      up
+
+This starts up local services for MongoDB, Minio, the API Server, one Combiner and two clients.
+You can verify the deployment using these URLs:
+
+- API Server: http://localhost:8092/get_controller_status
+- Minio: http://localhost:9000
+- Mongo Express: http://localhost:8081
+
+Upload the package and seed model to the FEDn controller using the APIClient:
+
+.. code-block::
+
+   from fedn import APIClient
+   client = APIClient(host="localhost", port=8092)
+   client.set_active_package("package.tgz", helper="numpyhelper")
+   client.set_active_model("seed.npz")
+
+You can now start a training session with 100 rounds using the API client:
+
+.. code-block::
+
+   client.start_session(rounds=100)
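+
+Each validation round records the kNN accuracy as a metric. You can also fetch
+these results programmatically (assuming the ``APIClient`` in your installed
+FEDn release exposes ``get_validations``; check the API reference for your
+version):
+
+.. code-block::
+
+   validations = client.get_validations()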
+
+Clean up
+--------
+
+You can clean up by running
+
+.. code-block::
+
+   docker-compose \
+      -f ../../docker-compose.yaml \
+      -f docker-compose.override.yaml \
+      down -v
diff --git a/examples/FedSimSiam/client/data.py b/examples/FedSimSiam/client/data.py
new file mode 100644
index 000000000..95b10e7db
--- /dev/null
+++ b/examples/FedSimSiam/client/data.py
@@ -0,0 +1,150 @@
+import os
+from math import floor
+
+import numpy as np
+import torch
+import torchvision
+from torchvision import transforms
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+abs_path = os.path.abspath(dir_path)
+
+
+def get_data(out_dir="data"):
+    """ Download CIFAR-10 if it is not already on disk. """
+    # Make dir if necessary
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    # Only download if not already downloaded
+    if not os.path.exists(f"{out_dir}/train"):
+        torchvision.datasets.CIFAR10(
+            root=f"{out_dir}/train", train=True, download=True)
+
+    if not os.path.exists(f"{out_dir}/test"):
+        torchvision.datasets.CIFAR10(
+            root=f"{out_dir}/test", train=False, download=True)
+
+
+def load_data(data_path, is_train=True):
+    """ Load data from disk.
+
+    :param data_path: Path to data file.
+    :type data_path: str
+    :param is_train: Whether to load training or test data.
+    :type is_train: bool
+    :return: Tuple of data and labels.
+    :rtype: tuple
+    """
+    if data_path is None:
+        data_path = os.environ.get(
+            "FEDN_DATA_PATH", abs_path+"/data/clients/1/cifar10.pt")
+
+    data = torch.load(data_path)
+
+    if is_train:
+        X = data["x_train"]
+        y = data["y_train"]
+    else:
+        X = data["x_test"]
+        y = data["y_test"]
+
+    return X, y
+
+
+def create_knn_monitoring_dataset(out_dir="data"):
+    """ Create the dataset that is used to monitor training progress via kNN accuracy. """
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
+
+    # Make dir
+    if not os.path.exists(f"{out_dir}/clients"):
+        os.mkdir(f"{out_dir}/clients")
+
+    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
+                                     std=[0.247, 0.243, 0.261])
+
+    memoryset = torchvision.datasets.CIFAR10(root="./data", train=True,
+                                             download=True, transform=transforms.Compose([transforms.ToTensor(), normalize]))
+    testset = torchvision.datasets.CIFAR10(root="./data", train=False,
+                                           download=True, transform=transforms.Compose([transforms.ToTensor(), normalize]))
+
+    # Save the monitoring datasets for all clients
+    for i in range(n_splits):
+        subdir = f"{out_dir}/clients/{str(i+1)}"
+        if not os.path.exists(subdir):
+            os.mkdir(subdir)
+        torch.save(memoryset, f"{subdir}/knn_memoryset.pt")
+        torch.save(testset, f"{subdir}/knn_testset.pt")
+
+
+def load_knn_monitoring_dataset(data_path, batch_size=16):
+    """ Load the kNN monitoring dataset. """
+    if data_path is None:
+        data_path = os.environ.get(
+            "FEDN_DATA_PATH", abs_path+"/data/clients/1/cifar10.pt")
+
+    data_directory = os.path.dirname(data_path)
+    memory_path = os.path.join(data_directory, "knn_memoryset.pt")
+    testset_path = os.path.join(data_directory, "knn_testset.pt")
+
+    memoryset = torch.load(memory_path)
+    testset = torch.load(testset_path)
+
+    memoryset_loader = torch.utils.data.DataLoader(
+        memoryset, batch_size=batch_size, shuffle=False)
+    testset_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
+                                                 shuffle=False)
+    return memoryset_loader, testset_loader
+
+
+def splitset(dataset, parts):
+    n = dataset.shape[0]
+    local_n = floor(n/parts)
+    result = []
+    for i in range(parts):
+        result.append(dataset[i*local_n: (i+1)*local_n])
+    return result
+
+
+def split(out_dir="data"):
+    """ Split CIFAR-10 into equally sized client partitions. """
+
+    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
+
+    # Make dir
+    if not os.path.exists(f"{out_dir}/clients"):
+        os.mkdir(f"{out_dir}/clients")
+
+    train_data = torchvision.datasets.CIFAR10(
+        root=f"{out_dir}/train", train=True)
+    test_data = torchvision.datasets.CIFAR10(
+        root=f"{out_dir}/test", train=False)
+
+    data = {
+        "x_train": splitset(train_data.data, n_splits),
+        "y_train": splitset(np.array(train_data.targets), n_splits),
+        "x_test": splitset(test_data.data, n_splits),
+        "y_test": splitset(np.array(test_data.targets), n_splits),
+    }
+
+    # Make splits
+    for i in range(n_splits):
+        subdir = f"{out_dir}/clients/{str(i+1)}"
+        if not os.path.exists(subdir):
+            os.mkdir(subdir)
+        torch.save({
+            "x_train": data["x_train"][i],
+            "y_train": data["y_train"][i],
+            "x_test": data["x_test"][i],
+            "y_test": data["y_test"][i],
+        },
+            f"{subdir}/cifar10.pt")
+
+
+if __name__ == "__main__":
+    # Prepare data if not already done
+    if not os.path.exists(abs_path+"/data/clients/1"):
+        get_data()
+        split()
+        create_knn_monitoring_dataset()
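+
+# Example local usage (a sketch; FEDn normally runs this module via the
+# `startup` entry point in fedn.yaml):
+#   python data.py
+# Downloads CIFAR-10, writes data/clients/<i>/cifar10.pt for each split, and
+# saves the shared kNN monitoring datasets next to each split.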
diff --git a/examples/FedSimSiam/client/fedn.yaml b/examples/FedSimSiam/client/fedn.yaml
new file mode 100644
index 000000000..b05504102
--- /dev/null
+++ b/examples/FedSimSiam/client/fedn.yaml
@@ -0,0 +1,10 @@
+python_env: python_env.yaml
+entry_points:
+  build:
+    command: python model.py
+  startup:
+    command: python data.py
+  train:
+    command: python train.py
+  validate:
+    command: python validate.py
\ No newline at end of file
diff --git a/examples/FedSimSiam/client/model.py b/examples/FedSimSiam/client/model.py
new file mode 100644
index 000000000..d50d15c2e
--- /dev/null
+++ b/examples/FedSimSiam/client/model.py
@@ -0,0 +1,144 @@
+import collections
+
+import torch
+import torch.nn.functional as f
+from torch import nn
+from torchvision.models import resnet18
+
+from fedn.utils.helpers.helpers import get_helper
+
+HELPER_MODULE = "numpyhelper"
+helper = get_helper(HELPER_MODULE)
+
+
+def D(p, z, version="simplified"):  # negative cosine similarity
+    if version == "original":
+        z = z.detach()  # stop gradient
+        p = f.normalize(p, dim=1)  # l2-normalize
+        z = f.normalize(z, dim=1)  # l2-normalize
+        return -(p*z).sum(dim=1).mean()
+
+    elif version == "simplified":  # same result, much faster
+        return - f.cosine_similarity(p, z.detach(), dim=-1).mean()
+    else:
+        raise ValueError(f"Unknown version: {version}")
+
+
+class ProjectionMLP(nn.Module):
+    """Projection MLP f"""
+
+    def __init__(self, in_features, h1_features, h2_features, out_features):
+        super(ProjectionMLP, self).__init__()
+        self.l1 = nn.Sequential(
+            nn.Linear(in_features, h1_features),
+            nn.BatchNorm1d(h1_features),
+            nn.ReLU(inplace=True)
+        )
+        self.l2 = nn.Sequential(
+            nn.Linear(h1_features, out_features),
+            nn.BatchNorm1d(out_features)
+        )
+
+    def forward(self, x):
+        x = self.l1(x)
+        x = self.l2(x)
+        return x
+
+
+class PredictionMLP(nn.Module):
+    """Prediction MLP h"""
+
+    def __init__(self, in_features, hidden_features, out_features):
+        super(PredictionMLP, self).__init__()
+        self.l1 = nn.Sequential(
+            nn.Linear(in_features, hidden_features),
+            nn.BatchNorm1d(hidden_features),
+            nn.ReLU(inplace=True)
+        )
+        self.l2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, x):
+        x = self.l1(x)
+        x = self.l2(x)
+        return x
+
+
+class SimSiam(nn.Module):
+    def __init__(self):
+        super(SimSiam, self).__init__()
+        backbone = resnet18(pretrained=False)
+        backbone.output_dim = backbone.fc.in_features
+        backbone.fc = torch.nn.Identity()
+
+        self.backbone = backbone
+
+        self.projector = ProjectionMLP(backbone.output_dim, 2048, 2048, 2048)
+        self.encoder = nn.Sequential(
+            self.backbone,
+            self.projector
+        )
+        self.predictor = PredictionMLP(2048, 512, 2048)
+
+    def forward(self, x1, x2):
+        f, h = self.encoder, self.predictor
+        z1, z2 = f(x1), f(x2)
+        p1, p2 = h(z1), h(z2)
+        L = D(p1, z2) / 2 + D(p2, z1) / 2
+        return {"loss": L}
+
+
+def compile_model():
+    """ Compile the pytorch model.
+
+    :return: The compiled model.
+    :rtype: torch.nn.Module
+    """
+    model = SimSiam()
+
+    return model
+
+
+def save_parameters(model, out_path):
+    """ Save model parameters to file.
+
+    :param model: The model to serialize.
+    :type model: torch.nn.Module
+    :param out_path: The path to save to.
+    :type out_path: str
+    """
+    parameters_np = [val.cpu().numpy()
+                     for _, val in model.state_dict().items()]
+    helper.save(parameters_np, out_path)
+
+
+def load_parameters(model_path):
+    """ Load model parameters from file and populate model.
+
+    :param model_path: The path to load from.
+    :type model_path: str
+    :return: The loaded model.
+    :rtype: torch.nn.Module
+    """
+    model = compile_model()
+    parameters_np = helper.load(model_path)
+
+    params_dict = zip(model.state_dict().keys(), parameters_np)
+    state_dict = collections.OrderedDict(
+        {key: torch.tensor(x) for key, x in params_dict})
+    model.load_state_dict(state_dict, strict=True)
+    return model
+
+
+def init_seed(out_path="seed.npz"):
+    """ Initialize seed model and save it to file.
+
+    :param out_path: The path to save the seed model to.
+    :type out_path: str
+    """
+    # Init and save
+    model = compile_model()
+    save_parameters(model, out_path)
+
+
+if __name__ == "__main__":
+    init_seed("../seed.npz")
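+
+# `fedn run build --path client` executes this module through the `build` entry
+# point in fedn.yaml, writing the seed model to seed.npz in the project root.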
diff --git a/examples/FedSimSiam/client/monitoring.py b/examples/FedSimSiam/client/monitoring.py
new file mode 100644
index 000000000..245e7f308
--- /dev/null
+++ b/examples/FedSimSiam/client/monitoring.py
@@ -0,0 +1,62 @@
+""" kNN monitor as in InstDisc https://arxiv.org/abs/1805.01978.
+This implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
+"""
+import torch
+import torch.nn.functional as f
+
+
+def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net = net.to(device)
+    net.eval()
+    classes = len(memory_data_loader.dataset.classes)
+    total_top1, total_num, feature_bank = 0.0, 0, []
+    with torch.no_grad():
+        # generate feature bank from the memory (training) set
+        for data, target in memory_data_loader:
+            feature = net(data.to(device, non_blocking=True))
+            feature = f.normalize(feature, dim=1)
+            feature_bank.append(feature)
+        # [D, N]
+        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
+        # [N]
+        feature_labels = torch.tensor(
+            memory_data_loader.dataset.targets, device=feature_bank.device)
+        # loop over test data to predict labels by weighted kNN search
+        for data, target in test_data_loader:
+            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
+            feature = net(data)
+            feature = f.normalize(feature, dim=1)
+
+            pred_labels = knn_predict(
+                feature, feature_bank, feature_labels, classes, k, t)
+
+            total_num += data.size(0)
+            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
+    return total_top1 / total_num
+
+
+def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
+    # compute cosine similarity between each feature vector and the feature bank ---> [B, N]
+    sim_matrix = torch.mm(feature, feature_bank)
+    # [B, K]
+    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
+    # [B, K]
+    sim_labels = torch.gather(feature_labels.expand(
+        feature.size(0), -1), dim=-1, index=sim_indices)
+    sim_weight = (sim_weight / knn_t).exp()
+
+    # counts for each class
+    one_hot_label = torch.zeros(feature.size(
+        0) * knn_k, classes, device=sim_labels.device)
+    # [B*K, C]
+    one_hot_label = one_hot_label.scatter(
+        dim=-1, index=sim_labels.view(-1, 1), value=1.0)
+    # weighted score ---> [B, C]
+    pred_scores = torch.sum(one_hot_label.view(feature.size(
+        0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
+
+    pred_labels = pred_scores.argsort(dim=-1, descending=True)
+    return pred_labels
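+
+# knn_monitor returns top-1 accuracy in [0, 1]; validate.py calls it once per
+# validation round with the loaders from data.load_knn_monitoring_dataset.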
diff --git a/examples/FedSimSiam/client/python_env.yaml b/examples/FedSimSiam/client/python_env.yaml
new file mode 100644
index 000000000..49b1ad2ec
--- /dev/null
+++ b/examples/FedSimSiam/client/python_env.yaml
@@ -0,0 +1,9 @@
+name: fedsimsiam
+build_dependencies:
+  - pip
+  - setuptools
+  - wheel==0.37.1
+dependencies:
+  - torch==2.2.0
+  - torchvision==0.17.0
+  - fedn==0.9.0
\ No newline at end of file
diff --git a/examples/FedSimSiam/client/train.py b/examples/FedSimSiam/client/train.py
new file mode 100644
index 000000000..0e7c565f6
--- /dev/null
+++ b/examples/FedSimSiam/client/train.py
@@ -0,0 +1,129 @@
+import os
+import sys
+
+import numpy as np
+import torch
+from data import load_data
+from model import load_parameters, save_parameters
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from utils import init_lrscheduler
+
+from fedn.utils.helpers.helpers import save_metadata
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+class SimSiamDataset(Dataset):
+    def __init__(self, x, y, is_train=True):
+        self.x = x
+        self.y = y
+        self.is_train = is_train
+
+    def __getitem__(self, idx):
+        x = self.x[idx]
+        x = Image.fromarray(x.astype(np.uint8))
+
+        y = self.y[idx]
+
+        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
+                                         std=[0.247, 0.243, 0.261])
+        augmentation = [
+            transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
+            transforms.RandomApply([
+                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
+            ], p=0.8),
+            transforms.RandomGrayscale(p=0.2),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize
+        ]
+
+        if self.is_train:
+            transform = transforms.Compose(augmentation)
+
+            # Two independently augmented views of the same image
+            x1 = transform(x)
+            x2 = transform(x)
+            return [x1, x2], y
+
+        else:
+            transform = transforms.Compose([transforms.ToTensor(), normalize])
+
+            x = transform(x)
+            return x, y
+
+    def __len__(self):
+        return len(self.x)
+
+
+def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
+    """ Complete a model update.
+
+    Load model parameters from in_model_path (managed by the FEDn client),
+    perform a model update, and write updated parameters
+    to out_model_path (picked up by the FEDn client).
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_model_path: The path to save the output model to.
+    :type out_model_path: str
+    :param data_path: The path to the data file.
+    :type data_path: str
+    :param batch_size: The batch size to use.
+    :type batch_size: int
+    :param epochs: The number of epochs to train.
+    :type epochs: int
+    :param lr: The learning rate to use.
+    :type lr: float
+    """
+    # Load data
+    x_train, y_train = load_data(data_path)
+
+    # Load parameters and initialize model
+    model = load_parameters(in_model_path)
+
+    trainset = SimSiamDataset(x_train, y_train, is_train=True)
+    trainloader = DataLoader(
+        trainset, batch_size=batch_size, shuffle=True)
+
+    device = torch.device(
+        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model = model.to(device)
+
+    model.train()
+
+    # Note: the optimizer and scheduler define their own learning rate schedule,
+    # so the `lr` argument above is only recorded in the metadata.
+    optimizer, lr_scheduler = init_lrscheduler(
+        model, 500, trainloader)
+
+    for epoch in range(epochs):
+        for idx, data in enumerate(trainloader):
+            images = data[0]
+            optimizer.zero_grad()
+            data_dict = model.forward(images[0].to(
+                device, non_blocking=True), images[1].to(device, non_blocking=True))
+            loss = data_dict["loss"].mean()
+            print(loss.item())
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+
+    # Metadata needed for aggregation server side
+    metadata = {
+        # num_examples are mandatory
+        "num_examples": len(x_train),
+        "batch_size": batch_size,
+        "epochs": epochs,
+        "lr": lr
+    }
+
+    # Save JSON metadata file (mandatory)
+    save_metadata(metadata, out_model_path)
+
+    # Save model update (mandatory)
+    save_parameters(model, out_model_path)
+
+
+if __name__ == "__main__":
+    train(sys.argv[1], sys.argv[2])
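+
+# The FEDn client invokes this module through the `train` entry point in
+# fedn.yaml, effectively running:
+#   python train.py <in_model_path> <out_model_path>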
diff --git a/examples/FedSimSiam/client/utils.py b/examples/FedSimSiam/client/utils.py
new file mode 100644
index 000000000..b10e0f06d
--- /dev/null
+++ b/examples/FedSimSiam/client/utils.py
@@ -0,0 +1,78 @@
+import numpy as np
+import torch
+
+
+class LrScheduler(object):
+    """ Linear warmup followed by cosine decay, stepped once per batch. """
+
+    def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
+        self.base_lr = base_lr
+        self.constant_predictor_lr = constant_predictor_lr
+        warmup_iter = iter_per_epoch * warmup_epochs
+        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
+        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
+        cosine_lr_schedule = final_lr+0.5 * \
+            (base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))
+
+        self.lr_schedule = np.concatenate(
+            (warmup_lr_schedule, cosine_lr_schedule))
+        self.optimizer = optimizer
+        self.iter = 0
+        self.current_lr = 0
+
+    def step(self):
+        for param_group in self.optimizer.param_groups:
+
+            if self.constant_predictor_lr and param_group["name"] == "predictor":
+                param_group["lr"] = self.base_lr
+            else:
+                lr = param_group["lr"] = self.lr_schedule[self.iter]
+
+        self.iter += 1
+        self.current_lr = lr
+        return lr
+
+    def get_lr(self):
+        return self.current_lr
+
+
+def get_optimizer(name, model, lr, momentum, weight_decay):
+    # Keep the predictor in its own parameter group so its learning rate can be
+    # held constant while the base learning rate follows the schedule.
+    predictor_prefix = ("module.predictor", "predictor")
+    parameters = [{
+        "name": "base",
+        "params": [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)],
+        "lr": lr
+    }, {
+        "name": "predictor",
+        "params": [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)],
+        "lr": lr
+    }]
+
+    if name == "sgd":
+        optimizer = torch.optim.SGD(
+            parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
+    else:
+        raise ValueError(f"Unknown optimizer: {name}")
+
+    return optimizer
+
+
+def init_lrscheduler(model, total_epochs, dataloader):
+    warmup_epochs = 10
+    warmup_lr = 0
+    base_lr = 0.03
+    final_lr = 0
+    momentum = 0.9
+    weight_decay = 0.0005
+    batch_size = 64
+
+    optimizer = get_optimizer(
+        "sgd", model,
+        lr=base_lr*batch_size/256,
+        momentum=momentum,
+        weight_decay=weight_decay)
+
+    lr_scheduler = LrScheduler(
+        optimizer, warmup_epochs, warmup_lr*batch_size/256,
+        total_epochs, base_lr*batch_size/256,
+        final_lr*batch_size/256,
+        len(dataloader),
+        constant_predictor_lr=True
+    )
+    return optimizer, lr_scheduler
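+
+# Example usage (mirrors train.py), stepping the scheduler once per batch:
+#   optimizer, scheduler = init_lrscheduler(model, 500, trainloader)
+#   ...
+#   loss.backward(); optimizer.step(); scheduler.step()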
diff --git a/examples/FedSimSiam/client/validate.py b/examples/FedSimSiam/client/validate.py
new file mode 100644
index 000000000..5e6c5ac53
--- /dev/null
+++ b/examples/FedSimSiam/client/validate.py
@@ -0,0 +1,63 @@
+import os
+import sys
+
+import numpy as np
+from data import load_knn_monitoring_dataset
+from model import load_parameters
+from monitoring import knn_monitor
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from fedn.utils.helpers.helpers import save_metrics
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+class Cifar10(Dataset):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],  # approx. CIFAR-10 channel means
+                                 std=[0.247, 0.243, 0.261])  # approx. CIFAR-10 channel std deviations
+        ])
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, idx):
+        x = self.x[idx]
+        x = Image.fromarray(x.astype(np.uint8))
+        x = self.transform(x)
+        y = self.y[idx]
+        return x, y
+
+
+def validate(in_model_path, out_json_path, data_path=None):
+    """ Validate the model by computing kNN accuracy on the monitoring dataset. """
+
+    memory_loader, test_loader = load_knn_monitoring_dataset(data_path)
+
+    model = load_parameters(in_model_path)
+
+    # knn_monitor selects the device (GPU if available) internally
+    knn_accuracy = knn_monitor(model.encoder, memory_loader, test_loader, k=min(
+        25, len(memory_loader.dataset)))
+
+    print("knn accuracy: ", knn_accuracy)
+
+    # JSON schema
+    report = {
+        "knn_accuracy": knn_accuracy,
+    }
+
+    # Save JSON
+    save_metrics(report, out_json_path)
+
+
+if __name__ == "__main__":
+    validate(sys.argv[1], sys.argv[2])
diff --git a/examples/FedSimSiam/docker-compose.override.yaml b/examples/FedSimSiam/docker-compose.override.yaml
new file mode 100644
index 000000000..524e39d1d
--- /dev/null
+++ b/examples/FedSimSiam/docker-compose.override.yaml
@@ -0,0 +1,35 @@
+# Compose schema version
+version: '3.4'
+
+# Overriding requirements
+
+x-env: &defaults
+  GET_HOSTS_FROM: dns
+  FEDN_PACKAGE_EXTRACT_DIR: package
+  FEDN_NUM_DATA_SPLITS: 2
+
+services:
+
+  client1:
+    extends:
+      file: ${HOST_REPO_DIR:-.}/docker-compose.yaml
+      service: client
+    environment:
+      <<: *defaults
+      FEDN_DATA_PATH: /app/package/client/data/clients/1/cifar10.pt
+    deploy:
+      replicas: 1
+    volumes:
+      - ${HOST_REPO_DIR:-.}/fedn:/app/fedn
+
+  client2:
+    extends:
+      file: ${HOST_REPO_DIR:-.}/docker-compose.yaml
+      service: client
+    environment:
+      <<: *defaults
+      FEDN_DATA_PATH: /app/package/client/data/clients/2/cifar10.pt
+    deploy:
+      replicas: 1
+    volumes:
+      - ${HOST_REPO_DIR:-.}/fedn:/app/fedn
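+
+# Note: this example requires GPU access (see the README). To expose an NVIDIA
+# GPU to a containerized client, you would typically add a device reservation
+# to the client service (a sketch; adjust to your Docker and driver setup):
+#
+#    deploy:
+#      resources:
+#        reservations:
+#          devices:
+#            - driver: nvidia
+#              count: 1
+#              capabilities: [gpu]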