Skip to content

Commit

Permalink
add script
Browse files Browse the repository at this point in the history
  • Loading branch information
mattiasakesson committed Aug 13, 2024
1 parent 17401ba commit 03703bb
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
99 changes: 99 additions & 0 deletions examples/mnist-pytorch-DPSGD/client/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
from math import floor

import torch
import torchvision

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)


def get_data(out_dir="data"):
# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

# Only download if not already downloaded
if not os.path.exists(f"{out_dir}/train"):
torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True, download=True)
if not os.path.exists(f"{out_dir}/test"):
torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False, download=True)


def load_data(data_path, is_train=True):
"""Load data from disk.
:param data_path: Path to data file.
:type data_path: str
:param is_train: Whether to load training or test data.
:type is_train: bool
:return: Tuple of data and labels.
:rtype: tuple
"""
print("data_path is None: ", data_path is None)
if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")

print("data path: ", data_path)
data = torch.load(data_path)

if is_train:
X = data["x_train"]
y = data["y_train"]
else:
X = data["x_test"]
y = data["y_test"]

# Normalize
X = X / 255

return X, y


def splitset(dataset, parts):
n = dataset.shape[0]
local_n = floor(n / parts)
result = []
for i in range(parts):
result.append(dataset[i * local_n : (i + 1) * local_n])
return result


def split(out_dir="data"):
n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

# Make dir
if not os.path.exists(f"{out_dir}/clients"):
os.mkdir(f"{out_dir}/clients")

# Load and convert to dict
train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True)
test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False)
data = {
"x_train": splitset(train_data.data, n_splits),
"y_train": splitset(train_data.targets, n_splits),
"x_test": splitset(test_data.data, n_splits),
"y_test": splitset(test_data.targets, n_splits),
}

# Make splits
for i in range(n_splits):
subdir = f"{out_dir}/clients/{str(i+1)}"
if not os.path.exists(subdir):
os.mkdir(subdir)
torch.save(
{
"x_train": data["x_train"][i],
"y_train": data["y_train"][i],
"x_test": data["x_test"][i],
"y_test": data["y_test"][i],
},
f"{subdir}/mnist.pt",
)


if __name__ == "__main__":
# Prepare data if not already done
if not os.path.exists(abs_path + "/data/clients/1"):
get_data()
split()
12 changes: 12 additions & 0 deletions examples/mnist-pytorch-DPSGD/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
python_env: python_env.yaml
entry_points:
build:
command: python model.py
startup:
command: python data.py
train:
command: python train.py
validate:
command: python validate.py
predict:
command: python predict.py
55 changes: 55 additions & 0 deletions examples/mnist-pytorch-DPSGD/client/validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import sys

import torch
from model import load_parameters

from data import load_data
from fedn.utils.helpers.helpers import save_metrics

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


def validate(in_model_path, out_json_path, data_path=None):
"""Validate model.
:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_json_path: The path to save the output JSON to.
:type out_json_path: str
:param data_path: The path to the data file.
:type data_path: str
"""
# Load data
x_train, y_train = load_data(data_path)
x_test, y_test = load_data(data_path, is_train=False)

# Load model
model = load_parameters(in_model_path)
model.eval()

# Evaluate
criterion = torch.nn.NLLLoss()
with torch.no_grad():
train_out = model(x_train)
training_loss = criterion(train_out, y_train)
training_accuracy = torch.sum(torch.argmax(train_out, dim=1) == y_train) / len(train_out)
test_out = model(x_test)
test_loss = criterion(test_out, y_test)
test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out)

# JSON schema
report = {
"training_loss": training_loss.item(),
"training_accuracy": training_accuracy.item(),
"test_loss": test_loss.item(),
"test_accuracy": test_accuracy.item(),
}

# Save JSON
save_metrics(report, out_json_path)


if __name__ == "__main__":
validate(sys.argv[1], sys.argv[2])

0 comments on commit 03703bb

Please sign in to comment.