Skip to content

Commit

Permalink
feat: add crnn resnet native template
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Jun 22, 2022
1 parent d1fd335 commit ec1f8ad
Show file tree
Hide file tree
Showing 13 changed files with 807 additions and 76 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ if (USE_TORCH)
backends/torch/torchinputconns.cc
backends/torch/native/templates/nbeats.cc
backends/torch/native/templates/vit.cc
backends/torch/native/templates/crnn_head.cc
backends/torch/native/templates/crnn.cc
backends/torch/native/templates/visformer.cc
backends/torch/native/templates/ttransformer.cc
Expand Down
5 changes: 4 additions & 1 deletion src/backends/torch/native/native_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,16 @@ namespace dd

if (tdef.find("vit") != std::string::npos)
{

return new ViT(inputc, template_params);
}
else if (tdef.find("visformer") != std::string::npos)
{
return new Visformer(inputc, template_params);
}
else if (tdef.find("crnn") != std::string::npos)
{
return new CRNN(inputc, template_params);
}
else if (VisionModelsFactory::is_vision_template(tdef))
{
return VisionModelsFactory::from_template(tdef, template_params,
Expand Down
11 changes: 10 additions & 1 deletion src/backends/torch/native/native_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "./templates/vit.h"
#include "./templates/visformer.h"
#include "./templates/ttransformer.h"
#include "./templates/crnn.hpp"
#include "../torchinputconns.h"
#include "apidata.h"
#include "templates/vision_models.h"
Expand All @@ -47,7 +48,8 @@ namespace dd
if (tdef.find("nbeats") != std::string::npos
|| tdef.find("vit") != std::string::npos
|| tdef.find("visformer") != std::string::npos
|| tdef.find("ttransformer") != std::string::npos)
|| tdef.find("ttransformer") != std::string::npos
|| tdef.find("crnn") != std::string::npos)
return true;
else if (VisionModelsFactory::is_vision_template(tdef))
return true;
Expand All @@ -62,6 +64,13 @@ namespace dd
return false;
}

static bool is_ctc(const std::string &tdef)
{
if (tdef.find("crnn") != std::string::npos)
return true;
return false;
}

/**
* if tdef is a vision template, returns true if it supports greyscale
* input.
Expand Down
Loading

0 comments on commit ec1f8ad

Please sign in to comment.