Part 2: (NEW PREDICTION TASK API) Refactor all the prediction tasks to use the convention required by the Trainer class, i.e. compute the loss inside the forward method and return a dict with the keys "loss", "labels" and "predictions" in training/evaluation mode.
#544
Update base PredictionTask class:
⁃ Add the targets argument to the forward method
⁃ Move compute_loss() inside the forward method
⁃ Return the output (a dict with the three keys, or a torch.Tensor) depending on the training and testing flags
⁃ Update calculate_metrics() to pass the targets to the forward call
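A minimal sketch of the forward convention described above, assuming a hypothetical toy regression task (`ToyRegressionTask`, not the library's PredictionTask) that uses plain Python lists in place of torch.Tensor so it runs without PyTorch. In training or testing mode it computes the loss inside forward and returns the three-key dict; otherwise it returns only the predictions.

```python
from typing import Dict, List, Optional, Union

# Either raw predictions (inference) or the {"loss", "labels", "predictions"} dict.
Output = Union[List[float], Dict[str, object]]


class ToyRegressionTask:
    """Hypothetical stand-in for a PredictionTask: scales inputs, MSE loss."""

    def __init__(self, weight: float = 2.0):
        self.weight = weight

    def compute_loss(self, predictions: List[float], targets: List[float]) -> float:
        # Mean squared error between predictions and targets.
        return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

    def forward(
        self,
        inputs: List[float],
        targets: Optional[List[float]] = None,
        training: bool = False,
        testing: bool = False,
    ) -> Output:
        predictions = [self.weight * x for x in inputs]
        if training or testing:
            if targets is None:
                raise ValueError("targets are required in training/testing mode")
            # The loss is computed here instead of in a separate compute_loss call.
            loss = self.compute_loss(predictions, targets)
            return {"loss": loss, "labels": targets, "predictions": predictions}
        # Inference mode: return only the predictions.
        return predictions

    # Calling the task instance runs forward, mimicking nn.Module.__call__.
    __call__ = forward
```

For example, `ToyRegressionTask()([1.0, 2.0], [2.0, 5.0], training=True)` returns a dict whose `"loss"` is 0.5, while `ToyRegressionTask()([1.0, 2.0])` returns just `[2.0, 4.0]`.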
Update the Head and Model classes to support the new convention in their forward() and calculate_metrics() methods
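One way the Head and Model could pass targets and flags through is sketched below. `ToyTask`, `ToyHead` and `ToyModel` are hypothetical names, and summing the per-task losses in a multi-task head is an assumption for illustration, not something this issue specifies.

```python
from typing import Dict, List, Optional


class ToyTask:
    """Hypothetical task following the new convention: scales inputs, MSE loss."""

    def __init__(self, weight: float):
        self.weight = weight

    def __call__(self, inputs, targets=None, training=False, testing=False):
        predictions = [self.weight * x for x in inputs]
        if not (training or testing):
            return predictions
        loss = sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)
        return {"loss": loss, "labels": targets, "predictions": predictions}


class ToyHead:
    """Hypothetical Head holding several named tasks."""

    def __init__(self, tasks: Dict[str, ToyTask]):
        self.tasks = tasks

    def __call__(
        self,
        inputs: List[float],
        targets: Optional[Dict[str, List[float]]] = None,
        training: bool = False,
        testing: bool = False,
    ):
        if not (training or testing):
            # Inference: return only the per-task predictions.
            return {name: task(inputs) for name, task in self.tasks.items()}
        outputs = {
            name: task(inputs, targets[name], training=training, testing=testing)
            for name, task in self.tasks.items()
        }
        # Assumption: the head's loss is the sum of the per-task losses.
        return {
            "loss": sum(o["loss"] for o in outputs.values()),
            "labels": {name: o["labels"] for name, o in outputs.items()},
            "predictions": {name: o["predictions"] for name, o in outputs.items()},
        }


class ToyModel:
    """Hypothetical Model that forwards targets and mode flags to its head."""

    def __init__(self, head: ToyHead):
        self.head = head

    def __call__(self, inputs, targets=None, training=False, testing=False):
        return self.head(inputs, targets, training=training, testing=testing)
```

With tasks `"a"` (weight 1.0) and `"b"` (weight 2.0), calling the model on `[1.0, 2.0]` with targets `{"a": [1.0, 2.0], "b": [2.0, 5.0]}` and `training=True` gives a total loss of 0.5 (0.0 from `"a"` plus 0.5 from `"b"`).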
Update the fit method [Done]: the loss is computed inside the forward call, and a compute_metrics=True flag is added to control whether metrics are computed during training. Replace the compute_loss call loss = self.compute_loss(x, y) with:
outputs = self(x, y, training=True)
loss = outputs['loss']
if compute_metrics:
    self.calculate_metrics(outputs['predictions'], outputs['labels'])
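The updated fit loop can be sketched as a standalone function, assuming a hypothetical `ToyModel` that follows the new convention and a hypothetical mean-absolute-error `calculate_metrics`. Plain lists replace tensors, so the backward pass and optimizer step are only marked with a comment.

```python
from typing import Dict, List


class ToyModel:
    """Hypothetical model following the new convention (scales inputs by 2, MSE loss)."""

    def __call__(self, inputs, targets=None, training=False, testing=False):
        predictions = [2.0 * x for x in inputs]
        if not (training or testing):
            return predictions
        loss = sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)
        return {"loss": loss, "labels": targets, "predictions": predictions}

    def calculate_metrics(self, predictions, labels):
        # Hypothetical metric: mean absolute error.
        return sum(abs(p - l) for p, l in zip(predictions, labels)) / len(labels)


def fit(model, batches, compute_metrics: bool = True) -> Dict[str, List[float]]:
    history: Dict[str, List[float]] = {"loss": [], "metrics": []}
    for x, y in batches:
        # The loss now comes from the forward call, not from compute_loss(x, y).
        outputs = model(x, y, training=True)
        loss = outputs["loss"]
        # With PyTorch, loss.backward() and optimizer.step() would run here.
        history["loss"].append(loss)
        if compute_metrics:
            history["metrics"].append(
                model.calculate_metrics(outputs["predictions"], outputs["labels"])
            )
    return history
```

For batches `[([1.0, 2.0], [2.0, 5.0]), ([3.0], [6.0])]`, `fit` records losses `[0.5, 0.0]`; with `compute_metrics=False` the metrics list stays empty, so metric computation can be skipped during training.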
Update the failing unit tests
This part of the refactoring includes the 4 parts listed above.