Respect configurable rounds_per_epoch configuration parameter during training
JMGaljaard committed May 30, 2022
1 parent 8b8d3a0 commit 4852ed1
Showing 3 changed files with 51 additions and 37 deletions.
78 changes: 43 additions & 35 deletions fltk/core/client.py
@@ -57,45 +57,53 @@ def _event_loop(self):
time.sleep(0.1)
self.logger.info('Exiting node')

def train(self, num_epochs: int):
def train(self, num_epochs: int, round_id: int):
"""
Function implementing federated learning training loop.
@param num_epochs: Number of epochs to run.
Function implementing the federated learning training loop, allowing a client to run for a configurable number of
epochs on its local dataset. Note that only the final statistics of a run are sent to the caller (i.e. the Federator).
@param num_epochs: Number of epochs to run during a communication round's training loop.
@type num_epochs: int
@return: Final running loss statistic and acquired parameters of the locally trained network.
@param round_id: Global communication round ID to be used during training.
@type round_id: int
@return: Final running loss statistic and acquired parameters of the locally trained network. Note that
intermediate information is only logged to stdout.
@rtype: Tuple[float, Dict[str, torch.Tensor]]
"""
start_time = time.time()

running_loss = 0.0
final_running_loss = 0.0
if self.distributed:
self.dataset.train_sampler.set_epoch(num_epochs)

number_of_training_samples = len(self.dataset.get_train_loader())
self.logger.info(f'{self.id}: Number of training samples: {number_of_training_samples}')

for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
inputs, labels = inputs.to(self.device), labels.to(self.device)

# zero the parameter gradients
self.optimizer.zero_grad()

outputs = self.net(inputs)
loss = self.loss_function(outputs, labels)

loss.backward()
self.optimizer.step()
running_loss += loss.item()
# Mark logging update step
if i % self.config.log_interval == 0:
self.logger.info(
f'[{self.id}] [{num_epochs:d}, {i:5d}] loss: {running_loss / self.config.log_interval:.3f}')
final_running_loss = running_loss / self.config.log_interval
running_loss = 0.0
end_time = time.time()
duration = end_time - start_time
self.logger.info(f'Train duration is {duration} seconds')
for local_epoch in range(num_epochs):
effective_epoch = round_id * num_epochs + local_epoch
progress = f'[RD-{round_id}][LE-{local_epoch}][EE-{effective_epoch}]'
if self.distributed:
# Update the sampler's epoch so that, when a client participates in multiple communication
# rounds, the same ordering of data does not re-occur during training.
self.dataset.train_sampler.set_epoch(effective_epoch)

training_cardinality = len(self.dataset.get_train_loader())
self.logger.info(f'{progress}{self.id}: Number of training samples: {training_cardinality}')

for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
inputs, labels = inputs.to(self.device), labels.to(self.device)

# zero the parameter gradients
self.optimizer.zero_grad()

outputs = self.net(inputs)
loss = self.loss_function(outputs, labels)

loss.backward()
self.optimizer.step()
running_loss += loss.item()
# Mark logging update step
if i % self.config.log_interval == 0:
self.logger.info(
f'[{self.id}] [{local_epoch}/{num_epochs:d}, {i:5d}] loss: {running_loss / self.config.log_interval:.3f}')
final_running_loss = running_loss / self.config.log_interval
running_loss = 0.0
end_time = time.time()
duration = end_time - start_time
self.logger.info(f'{progress} Train duration is {duration} seconds')

return final_running_loss, self.get_nn_parameters(),

@@ -148,7 +156,7 @@ def test(self) -> Tuple[float, float, np.array]:
def get_client_datasize(self): # pylint: disable=missing-function-docstring
return len(self.dataset.get_train_sampler())

def exec_round(self, num_epochs: int) -> Tuple[Any, Any, Any, Any, float, float, float, np.array]:
def exec_round(self, num_epochs: int, round_id: int) -> Tuple[Any, Any, Any, Any, float, float, float, np.array]:
"""
Function as access point for the Federator Node to kick off a remote learning round on a client.
@param num_epochs: Number of epochs to run
@@ -157,9 +165,9 @@ def exec_round(self, num_epochs: int) -> Tuple[Any, Any, Any, Any, float, float,
training make-span, testing make-span, and confusion matrix.
@rtype: Tuple[Any, Any, Any, Any, float, float, float, np.array]
"""
self.logger.info(f"[EXEC] running {num_epochs} locally...")
start = time.time()

loss, weights = self.train(num_epochs)
loss, weights = self.train(num_epochs, round_id)
time_mark_between = time.time()
accuracy, test_loss, test_conf_matrix = self.test()

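
Aside on the change above (not part of the commit; values are illustrative): the new `effective_epoch = round_id * num_epochs + local_epoch` bookkeeping gives every local epoch a globally unique number across communication rounds, so `DistributedSampler.set_epoch` never receives the same value twice and the shuffling order differs from round to round. A minimal, self-contained sketch:

num_epochs = 2  # local epochs per communication round (illustrative value)
for round_id in range(3):
    for local_epoch in range(num_epochs):
        # Unique across rounds: 0,1 | 2,3 | 4,5 -> set_epoch() reshuffles each round.
        effective_epoch = round_id * num_epochs + local_epoch
        print(f'round={round_id} local_epoch={local_epoch} effective_epoch={effective_epoch}')
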
2 changes: 1 addition & 1 deletion fltk/core/federator.py
@@ -284,7 +284,7 @@ def training_cb(fut: torch.Future, client_ref: LocalClient, client_weights, clie
client_ref.exp_data.append(c_record)

for client in selected_clients:
future = self.message_async(client.ref, Client.exec_round, num_epochs)
future = self.message_async(client.ref, Client.exec_round, num_epochs, com_round_id)
cb_factory(future, training_cb, client, client_weights, client_sizes, num_epochs)
self.logger.info(f'Request sent to client {client.name}')
training_futures.append(future)
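
For context (illustrative sketch only, hypothetical names; not the repository's `message_async` API): the Federator change just threads `com_round_id` through the asynchronous call to `Client.exec_round`, with a callback recording the result once the future resolves. The same pattern with plain `torch.futures`:

import torch

def fake_exec_round(num_epochs: int, round_id: int) -> float:
    # Stand-in for Client.exec_round; returns a dummy loss value.
    return 0.1 * round_id + 0.01 * num_epochs

def dispatch_round(com_round_id: int, num_epochs: int, clients):
    futures = []
    for client in clients:
        fut = torch.futures.Future()
        # Attach a callback that records the client's result once the future resolves.
        fut.then(lambda f, c=client: print(f'round {com_round_id}: {c} -> loss {f.wait():.3f}'))
        # In the real Federator this would be a remote call; here it resolves immediately.
        fut.set_result(fake_exec_round(num_epochs, com_round_id))
        futures.append(fut)
    return futures

dispatch_round(com_round_id=3, num_epochs=1, clients=['client-a', 'client-b'])
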
8 changes: 7 additions & 1 deletion fltk/util/task/config/parameter.py
@@ -5,6 +5,7 @@
# noinspection PyUnresolvedReferences
from typing import List, Optional, OrderedDict, Any, Union, Tuple, Type, Dict, MutableMapping, T

import deprecate
from dataclasses_json import dataclass_json, LetterCase, config
# noinspection PyProtectedMember
from torch.nn.modules.loss import _Loss
@@ -224,14 +225,19 @@ class LearningParameters:
Dataclass containing configuration parameters for the learning process itself. This includes the Federated learning
parameters as well as some system parameters like cuda.
"""
total_epochs: int
_total_epochs: int = field(metadata=config(field_name='total_epochs'))
cuda: bool
rounds: Optional[int] = None
epochs_per_round: Optional[int] = None
clients_per_round: Optional[int] = None
aggregation: Optional[Aggregations] = None
data_sampler: Optional[SamplerConfiguration] = None

@property
def total_epochs(self):
logging.warning('By default `total_epochs` is not used during Federated Learning. This attribute will be '
'changed in a coming release.')
return self._total_epochs

@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass(frozen=True)
class ExperimentConfiguration:
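
A standalone sketch of the serialization behaviour relied on above (hypothetical class name; assumes the `dataclasses_json` package): `config(field_name='total_epochs')` keeps the serialized key stable even though the dataclass field is now stored as `_total_epochs`, while the property provides the warning-wrapped accessor:

import logging
from dataclasses import dataclass, field

from dataclasses_json import dataclass_json, config


@dataclass_json
@dataclass(frozen=True)
class LearningParametersSketch:
    # Stored privately, but still (de)serialized under the original key 'total_epochs'.
    _total_epochs: int = field(metadata=config(field_name='total_epochs'))

    @property
    def total_epochs(self) -> int:
        logging.warning('`total_epochs` is not used during Federated Learning by default.')
        return self._total_epochs


params = LearningParametersSketch.from_json('{"total_epochs": 5}')
print(params.total_epochs)  # 5 (plus the warning emitted to the log)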
