Question on ResNet C++ example with Cifar10 dataset #14106
Replies: 5 comments
-
@leleamol Could you please help with this guy? |
Beta Was this translation helpful? Give feedback.
-
@xiaolin-cheng , |
Beta Was this translation helpful? Give feedback.
-
Hi @leleamol, I used the training code provided in the example resnet.cpp. Thank you! |
Beta Was this translation helpful? Give feedback.
-
@xiaolin-cheng, what learning rate are you using? It can be that you are using a learning rate that is too high or too low. |
Beta Was this translation helpful? Give feedback.
-
@mxnet-label-bot add [Pending Requester Info] |
Beta Was this translation helpful? Give feedback.
-
Hello all,
I modified the ResNet C++ example into a ResNet18 model (4 levels, 2 blocks per level) for the CIFAR-10 dataset, but unfortunately it does not train: the validation accuracy always stays at ~0.1. Only ResNet6 (1 level, 2 blocks) and ResNet10 (2 levels, 2 blocks) worked. My code is below. Could you help me find what I did wrong? Thank you very much!
/**
 * Builds a bias-free convolution followed by batch normalization and an
 * optional ReLU activation.
 *
 * The following variable symbols are created from `name`:
 * "<name>_w", "<name>_gamma", "<name>_beta", "<name>_mmean", "<name>_mvar".
 * The operators themselves are named "<name>", "<name>_bn" and "<name>_relu".
 *
 * @param name        Prefix used for every symbol created here.
 * @param data        Input symbol fed to the convolution.
 * @param num_filter  Forwarded to ConvolutionNoBias as the filter count.
 * @param kernel      Convolution kernel shape.
 * @param stride      Convolution stride.
 * @param pad         Convolution padding.
 * @param with_relu   When true the batch-norm output is wrapped in a ReLU;
 *                    otherwise the batch-norm output is returned directly.
 * @param bn_momentum Forwarded to BatchNorm.
 * @return The ReLU symbol if `with_relu` is set, else the batch-norm symbol.
 */
Symbol getConv(const std::string & name, Symbol data, int num_filter, Shape kernel, Shape stride, Shape pad, bool with_relu, mx_float bn_momentum)
{
    // Convolution stage.
    // NOTE(review): the trailing arguments Shape(1, 1), 1, 512 are presumably
    // dilation, group count and workspace size -- confirm against the
    // ConvolutionNoBias declaration, which is not part of this snippet.
    Symbol weight(name + "_w");
    Symbol convolved = ConvolutionNoBias(name, data, weight,
                                         kernel, num_filter, stride, Shape(1, 1),
                                         pad, 1, 512);

    // Batch-normalization stage with its own parameter / statistics variables.
    Symbol bn_gamma(name + "_gamma");
    Symbol bn_beta(name + "_beta");
    Symbol moving_mean(name + "_mmean");
    Symbol moving_var(name + "_mvar");
    // NOTE(review): 2e-5 and the final `false` are presumably eps and
    // fix_gamma -- confirm against the BatchNorm declaration.
    Symbol normalized = BatchNorm(name + "_bn", convolved, bn_gamma, bn_beta,
                                  moving_mean, moving_var,
                                  2e-5, bn_momentum, false);

    if (!with_relu) {
        return normalized;
    }
    return Activation(name + "_relu", normalized, "relu");
}
/**
 * Builds one residual block: two 3x3 conv/batch-norm stages whose output is
 * added to a shortcut branch and passed through a ReLU.
 *
 * @param name        Prefix for the symbols created in this block
 *                    ("<name>_conv1", "<name>_conv2", "<name>_proj",
 *                    "<name>_proj_w", "<name>_relu").
 * @param data        Input symbol of the block.
 * @param num_filter  Filter count used by both convolutions and by the
 *                    projection shortcut.
 * @param dim_match   True  -> first convolution uses stride 1 and the
 *                             shortcut is `data` itself.
 *                    False -> first convolution uses stride 2 and the
 *                             shortcut is a 2x2, stride-2 bias-free
 *                             convolution of `data`.
 * @param bn_momentum Forwarded to getConv for both conv stages.
 * @return ReLU applied to (shortcut + second conv stage).
 */
Symbol makeBlock(const std::string & name, Symbol data, int num_filter, bool dim_match, mx_float bn_momentum)
{
    // Only the first convolution of a non-matching block uses stride 2.
    const Shape first_stride = dim_match ? Shape(1, 1) : Shape(2, 2);

    Symbol first_stage = getConv(name + "_conv1", data, num_filter,
                                 Shape(3, 3), first_stride, Shape(1, 1),
                                 true, bn_momentum);

    // No ReLU on the second stage: the activation is applied after the sum.
    Symbol second_stage = getConv(name + "_conv2", first_stage, num_filter,
                                  Shape(3, 3), Shape(1, 1), Shape(1, 1),
                                  false, bn_momentum);

    // Identity shortcut by default; replaced by a projection when the first
    // convolution used stride 2.
    Symbol shortcut = data;
    if (!dim_match) {
        Symbol proj_weight(name + "_proj_w");
        shortcut = ConvolutionNoBias(name + "_proj", data, proj_weight,
                                     Shape(2, 2), num_filter,
                                     Shape(2, 2), Shape(1, 1), Shape(0, 0),
                                     1, 512);
    }

    Symbol summed = shortcut + second_stage;
    return Activation(name + "_relu", summed, "relu");
}
/**
 * Stacks `num_level` levels of `num_block` residual blocks on top of `data`.
 *
 * Blocks are named "level<L>_block<B>" with 1-based L and B. Level 0 uses
 * `num_filter` filters and every following level uses twice as many as the
 * previous one. The first block of every level except the first is built
 * with dim_match == false; all other blocks are built with
 * dim_match == true.
 *
 * @param data        Input symbol of the first block.
 * @param num_level   Number of levels.
 * @param num_block   Number of blocks per level.
 * @param num_filter  Filter count of the first level.
 * @param bn_momentum Forwarded to makeBlock.
 * @return Output symbol of the last block, or `data` unchanged when no block
 *         is created.
 */
Symbol getBody(Symbol data, int num_level, int num_block, int num_filter, mx_float bn_momentum)
{
    // Filter count is tracked as an integer that doubles per level. This
    // replaces num_filter * std::pow(2, level), which went through double
    // and relied on an implicit double -> int narrowing at the call site.
    int level_filters = num_filter;
    for (int level = 0; level < num_level; ++level) {
        if (level > 0) {
            level_filters *= 2;
        }
        const std::string level_prefix = "level" + std::to_string(level + 1);
        for (int block = 0; block < num_block; ++block) {
            const bool dim_match = (level == 0 || block > 0);
            data = makeBlock(level_prefix + "_block" + std::to_string(block + 1),
                             data, level_filters, dim_match, bn_momentum);
        }
    }
    return data;
}
/**
 * Assembles the full network symbol: a conv + max-pool stem, the residual
 * body, a global average pool, a fully connected layer and a softmax output.
 *
 * Input variables are named "data" and "data_label"; the classifier weights
 * are "fc_w" and "fc_b".
 *
 * @param num_class   Number of outputs of the fully connected layer.
 * @param num_level   Number of levels in the residual body.
 * @param num_block   Number of residual blocks per level.
 * @param num_filter  Filter count of the stem convolution and of the first
 *                    body level.
 * @param bn_momentum Batch-norm momentum forwarded to every conv stage.
 * @return The SoftmaxOutput symbol named "softmax".
 */
Symbol ResNetSymbol(int num_class, int num_level = 2, int num_block = 2, int num_filter = 64, mx_float bn_momentum = 0.9)
{
    // data and label
    Symbol data = Symbol::Variable("data");
    Symbol data_label = Symbol::Variable("data_label");

    //===== top =====//
    Symbol conv = getConv("conv0", data, num_filter,
                          Shape(7, 7), Shape(2, 2), Shape(3, 3),
                          true, bn_momentum);
    Symbol max_pool = Pooling("max_pool", conv, Shape(3, 3), PoolingPoolType::kMax,
                              false, false, PoolingPoolingConvention::kValid,
                              Shape(2, 2), Shape(1, 1));

    //===== body =====//
    // The body consumes the max-pooled stem output. Previously `conv` was
    // passed here, which left `max_pool` dangling and dropped the pooling
    // stage from the graph.
    // NOTE(review): this halves the spatial size entering the body once more;
    // confirm the input resolution still supports the chosen num_level.
    Symbol body = getBody(max_pool, num_level, num_block, num_filter, bn_momentum);

    //===== pool and fc =====//
    // NOTE(review): the first boolean is presumably global_pool, in which
    // case the 7x7 kernel is not what determines the pooled extent -- confirm
    // against the Pooling declaration.
    Symbol avg_pool = Pooling("avg_pool", body, Shape(7, 7), PoolingPoolType::kAvg,
                              true, false, PoolingPoolingConvention::kValid);
    Symbol flatten = Flatten("flatten", avg_pool);

    Symbol fc_w("fc_w"), fc_b("fc_b");
    Symbol fc = FullyConnected("fc", flatten, fc_w, fc_b, num_class);

    return SoftmaxOutput("softmax", fc, data_label);
}
Beta Was this translation helpful? Give feedback.
All reactions