Skip to content

Commit

Permalink
Code style
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 19, 2023
1 parent 9dd78cf commit c67e51c
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions examples/torch/classification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,20 +728,19 @@ def validate(val_loader, model, criterion, config, epoch=0, log_validation_info=
)
)

if is_main_process():
if log_validation_info:
config.tb.add_scalar("val/loss", losses.avg, len(val_loader) * epoch)
config.tb.add_scalar("val/top1", top1.avg, len(val_loader) * epoch)
config.tb.add_scalar("val/top5", top5.avg, len(val_loader) * epoch)
config.mlflow.safe_call("log_metric", "val/loss", float(losses.avg), epoch)
config.mlflow.safe_call("log_metric", "val/top1", float(top1.avg), epoch)
config.mlflow.safe_call("log_metric", "val/top5", float(top5.avg), epoch)

logger.info(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}\n".format(top1=top1, top5=top5))

if config.metrics_dump is not None:
acc = top1.avg / 100
write_metrics(acc, config.metrics_dump)
if is_main_process() and log_validation_info:
config.tb.add_scalar("val/loss", losses.avg, len(val_loader) * epoch)
config.tb.add_scalar("val/top1", top1.avg, len(val_loader) * epoch)
config.tb.add_scalar("val/top5", top5.avg, len(val_loader) * epoch)
config.mlflow.safe_call("log_metric", "val/loss", float(losses.avg), epoch)
config.mlflow.safe_call("log_metric", "val/top1", float(top1.avg), epoch)
config.mlflow.safe_call("log_metric", "val/top5", float(top5.avg), epoch)

logger.info(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}\n".format(top1=top1, top5=top5))

if is_main_process() and config.metrics_dump is not None:
acc = top1.avg / 100
write_metrics(acc, config.metrics_dump)

return top1.avg, top5.avg, losses.avg

Expand Down

0 comments on commit c67e51c

Please sign in to comment.