Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix import error with python3
Browse files Browse the repository at this point in the history
fix alexnet and googlenet pooling pad error
remove first pool pad.
remove pad
  • Loading branch information
yajiedesign committed Jan 16, 2017
1 parent a0ec859 commit 0eb840a
Show file tree
Hide file tree
Showing 17 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion example/image-classification/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def validate(self, attrs):
sys.exit(1)
try:
#check if the network exists
importlib.import_module('symbol.'+ args[0])
importlib.import_module('symbols.'+ args[0])
batch_size = int(args[1])
img_size = int(args[2])
return Network(name=args[0], batch_size=batch_size, img_size=img_size)
Expand Down
2 changes: 1 addition & 1 deletion example/image-classification/benchmark_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_symbol(network, batch_size):
if 'resnet' in network:
num_layers = int(network.split('-')[1])
network = 'resnet'
net = import_module('symbol.'+network)
net = import_module('symbols.'+network)
sym = net.get_symbol(num_classes = 1000,
image_shape = ','.join([str(i) for i in image_shape]),
num_layers = num_layers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
with convolutions." arXiv preprint arXiv:1409.4842 (2014).
"""

import find_mxnet
import mxnet as mx

def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):
Expand Down
2 changes: 1 addition & 1 deletion example/image-classification/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def download_cifar10():

# load network
from importlib import import_module
net = import_module('symbol.'+args.network)
net = import_module('symbols.'+args.network)
sym = net.get_symbol(**vars(args))

# train
Expand Down
2 changes: 1 addition & 1 deletion example/image-classification/train_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

# load network
from importlib import import_module
net = import_module('symbol.'+args.network)
net = import_module('symbols.'+args.network)
sym = net.get_symbol(**vars(args))

# train
Expand Down
2 changes: 1 addition & 1 deletion example/image-classification/train_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_mnist_iter(args, kv):

# load network
from importlib import import_module
net = import_module('symbol.'+args.network)
net = import_module('symbols.'+args.network)
sym = net.get_symbol(**vars(args))

# train
Expand Down

0 comments on commit 0eb840a

Please sign in to comment.