Skip to content

Commit

Permalink
Updated some stuff in models (#1115)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShahriarSS authored and fmassa committed Jul 15, 2019
1 parent d84fee6 commit 8d580a1
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 27 deletions.
6 changes: 6 additions & 0 deletions test/test_cpp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def test_resnext50_32x4d(self):
def test_resnext101_32x8d(self):
process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')

def test_wide_resnet50_2(self):
process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, 'WideResNet50_2')

def test_wide_resnet101_2(self):
process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, 'WideResNet101_2')

def test_squeezenet1_0(self):
process_model(models.squeezenet1_0(self.pretrained), self.image,
_C_tests.forward_squeezenet1_0, 'Squeezenet1.0')
Expand Down
18 changes: 18 additions & 0 deletions test/test_models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ torch::Tensor forward_resnext101_32x8d(
torch::Tensor x) {
return forward_model<ResNext101_32x8d>(input_path, x);
}
torch::Tensor forward_wide_resnet50_2(
const std::string& input_path,
torch::Tensor x) {
return forward_model<WideResNet50_2>(input_path, x);
}
torch::Tensor forward_wide_resnet101_2(
const std::string& input_path,
torch::Tensor x) {
return forward_model<WideResNet101_2>(input_path, x);
}

torch::Tensor forward_squeezenet1_0(
const std::string& input_path,
Expand Down Expand Up @@ -168,6 +178,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"forward_resnext101_32x8d",
&forward_resnext101_32x8d,
"forward_resnext101_32x8d");
m.def(
"forward_wide_resnet50_2",
&forward_wide_resnet50_2,
"forward_wide_resnet50_2");
m.def(
"forward_wide_resnet101_2",
&forward_wide_resnet101_2,
"forward_wide_resnet101_2");

m.def(
"forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
Expand Down
4 changes: 4 additions & 0 deletions torchvision/csrc/convert_models/convert_models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ int main(int argc, const char* argv[]) {
"resnext50_32x4d_python.pt", "resnext50_32x4d_cpp.pt");
convert_and_save_model<ResNext101_32x8d>(
"resnext101_32x8d_python.pt", "resnext101_32x8d_cpp.pt");
convert_and_save_model<WideResNet50_2>(
"wide_resnet50_2_python.pt", "wide_resnet50_2_cpp.pt");
convert_and_save_model<WideResNet101_2>(
"wide_resnet101_2_python.pt", "wide_resnet101_2_cpp.pt");

convert_and_save_model<SqueezeNet1_0>(
"squeezenet1_0_python.pt", "squeezenet1_0_cpp.pt");
Expand Down
55 changes: 40 additions & 15 deletions torchvision/csrc/models/mobilenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;

int64_t make_divisible(
double value,
int64_t divisor,
c10::optional<int64_t> min_value = {}) {
if (!min_value.has_value())
min_value = divisor;
auto new_value = std::max(
min_value.value(), (int64_t(value + divisor / 2) / divisor) * divisor);
if (new_value < .9 * value)
new_value += divisor;
return new_value;
}

struct ConvBNReLUImpl : torch::nn::SequentialImpl {
ConvBNReLUImpl(
int64_t in_planes,
Expand Down Expand Up @@ -69,28 +82,40 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {

TORCH_MODULE(MobileNetInvertedResidual);

MobileNetV2Impl::MobileNetV2Impl(int64_t num_classes, double width_mult) {
MobileNetV2Impl::MobileNetV2Impl(
int64_t num_classes,
double width_mult,
std::vector<std::vector<int64_t>> inverted_residual_settings,
int64_t round_nearest) {
using Block = MobileNetInvertedResidual;
int64_t input_channel = 32;
int64_t last_channel = 1280;

std::vector<std::vector<int64_t>> inverted_residual_settings = {
// t, c, n, s
{1, 16, 1, 1},
{6, 24, 2, 2},
{6, 32, 3, 2},
{6, 64, 4, 2},
{6, 96, 3, 1},
{6, 160, 3, 2},
{6, 320, 1, 1},
};

input_channel = int64_t(input_channel * width_mult);
this->last_channel = int64_t(last_channel * std::max(1.0, width_mult));
if (inverted_residual_settings.empty())
inverted_residual_settings = {
// t, c, n, s
{1, 16, 1, 1},
{6, 24, 2, 2},
{6, 32, 3, 2},
{6, 64, 4, 2},
{6, 96, 3, 1},
{6, 160, 3, 2},
{6, 320, 1, 1},
};

if (inverted_residual_settings[0].size() != 4) {
std::cerr << "inverted_residual_settings should contain 4-element vectors";
assert(false);
}

input_channel = make_divisible(input_channel * width_mult, round_nearest);
this->last_channel =
make_divisible(last_channel * std::max(1.0, width_mult), round_nearest);
features->push_back(ConvBNReLU(3, input_channel, 3, 2));

for (auto setting : inverted_residual_settings) {
auto output_channel = int64_t(setting[1] * width_mult);
auto output_channel =
make_divisible(setting[1] * width_mult, round_nearest);

for (int64_t i = 0; i < setting[2]; ++i) {
auto stride = i == 0 ? setting[3] : 1;
Expand Down
6 changes: 5 additions & 1 deletion torchvision/csrc/models/mobilenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
int64_t last_channel;
torch::nn::Sequential features, classifier;

MobileNetV2Impl(int64_t num_classes = 1000, double width_mult = 1.0);
MobileNetV2Impl(
int64_t num_classes = 1000,
double width_mult = 1.0,
std::vector<std::vector<int64_t>> inverted_residual_settings = {},
int64_t round_nearest = 8);

torch::Tensor forward(torch::Tensor x);
};
Expand Down
10 changes: 10 additions & 0 deletions torchvision/csrc/models/resnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,15 @@ ResNext101_32x8dImpl::ResNext101_32x8dImpl(
bool zero_init_residual)
: ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 32, 8) {}

WideResNet50_2Impl::WideResNet50_2Impl(
int64_t num_classes,
bool zero_init_residual)
: ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}

WideResNet101_2Impl::WideResNet101_2Impl(
int64_t num_classes,
bool zero_init_residual)
: ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}

} // namespace models
} // namespace vision
14 changes: 14 additions & 0 deletions torchvision/csrc/models/resnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,18 @@ struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
bool zero_init_residual = false);
};

struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet50_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet101_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

template <typename Block>
struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
Expand All @@ -229,6 +241,8 @@ TORCH_MODULE(ResNet101);
TORCH_MODULE(ResNet152);
TORCH_MODULE(ResNext50_32x4d);
TORCH_MODULE(ResNext101_32x8d);
TORCH_MODULE(WideResNet50_2);
TORCH_MODULE(WideResNet101_2);

} // namespace models
} // namespace vision
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/models/squeezenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
Fire(384, 64, 256, 256),
Fire(512, 64, 256, 256));
} else {
std::cerr << "Wrong version number is passed th SqueeseNet constructor!"
<< std::endl;
std::cerr << "Unsupported SqueezeNet version " << version
<< ". 1_0 or 1_1 expected" << std::endl;
assert(false);
}

Expand Down
18 changes: 9 additions & 9 deletions torchvision/csrc/models/vgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,36 +79,36 @@ torch::Tensor VGGImpl::forward(torch::Tensor x) {
}

// clang-format off
static std::unordered_map<char, std::vector<int>> cfg = {
static std::unordered_map<char, std::vector<int>> cfgs = {
{'A', {64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
{'B', {64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
{'D', {64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1}},
{'E', {64, 64, -1, 128, 128, -1, 256, 256, 256, 256, -1, 512, 512, 512, 512, -1, 512, 512, 512, 512, -1}}};
// clang-format on

VGG11Impl::VGG11Impl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['A']), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['A']), num_classes, initialize_weights) {}

VGG13Impl::VGG13Impl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['B']), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['B']), num_classes, initialize_weights) {}

VGG16Impl::VGG16Impl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['D']), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['D']), num_classes, initialize_weights) {}

VGG19Impl::VGG19Impl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['E']), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['E']), num_classes, initialize_weights) {}

VGG11BNImpl::VGG11BNImpl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['A'], true), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['A'], true), num_classes, initialize_weights) {}

VGG13BNImpl::VGG13BNImpl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['B'], true), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['B'], true), num_classes, initialize_weights) {}

VGG16BNImpl::VGG16BNImpl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['D'], true), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['D'], true), num_classes, initialize_weights) {}

VGG19BNImpl::VGG19BNImpl(int64_t num_classes, bool initialize_weights)
: VGGImpl(makeLayers(cfg['E'], true), num_classes, initialize_weights) {}
: VGGImpl(makeLayers(cfgs['E'], true), num_classes, initialize_weights) {}

} // namespace models
} // namespace vision

0 comments on commit 8d580a1

Please sign in to comment.