diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index a28bdcad0..da82e60db 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -343,7 +343,7 @@ def create_moe_param_groups(model): # We simply have to loop over our data iterator, and feed the inputs to the # network and optimize. -for epoch in range(2): # loop over the dataset multiple times +for epoch in range(args.epochs): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader):