diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py index 9412b6f9371b..8189a0a29bb9 100755 --- a/example/image-classification/common/fit.py +++ b/example/image-classification/common/fit.py @@ -238,7 +238,7 @@ def fit(args, network, data_loader, **kwargs): # AlexNet will not converge using Xavier initializer = mx.init.Normal() # VGG will not trend to converge using Xavier-Gaussian - elif 'vgg' in args.network: + elif args.network and 'vgg' in args.network: initializer = mx.init.Xavier() else: initializer = mx.init.Xavier(