Skip to content

Commit

Permalink
Converter fix to allow unimplemented convertToOperatorDef (pytorch#13069
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: pytorch#13069

simply a new fallback

Reviewed By: ZolotukhinM

Differential Revision: D10591414

fbshipit-source-id: 1ad8f16135a6c68b2df889101f06b736a3e4f7da
  • Loading branch information
bwasti authored and facebook-github-bot committed Oct 25, 2018
1 parent ef019a2 commit e0a8665
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions caffe2/opt/converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,14 @@ OperatorDef Converter::convertToOperatorDef(
const nom::repr::NeuralNetOperator* nnOp) {
auto* annotation = nnOp->getAnnotation();
// Default to using the stored operator.
if (isa<Caffe2Annotation>(annotation)) {
if (annotation && isa<Caffe2Annotation>(annotation)) {
return dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();
}
CAFFE_THROW("TODO: Cannot yet instantiate OperatorDef from nomnigraph");
LOG(WARNING)
<< "Cannot instantiate this OperatorDef from nomnigraph, falling back";
caffe2::OperatorDef op;
op.set_type(nnOp->getName());
return op;
}

std::vector<int> getKernelShape(std::map<std::string, caffe2::Argument> argMap) {
Expand Down Expand Up @@ -156,11 +160,13 @@ class ClipConverter : public Converter {
float max = std::numeric_limits<float>::max();

if (argMap.count("min")) {
min = static_cast<float>(argMap["min"].i());
CAFFE_ENFORCE(argMap["min"].has_f(), "Invalid 'min' argument");
min = static_cast<float>(argMap["min"].f());
}

if (argMap.count("max")) {
max = static_cast<float>(argMap["max"].i());
CAFFE_ENFORCE(argMap["max"].has_f(), "Invalid 'max' argument");
max = static_cast<float>(argMap["max"].f());
}

return util::make_unique<repr::Clip>(min, max);
Expand Down Expand Up @@ -363,13 +369,14 @@ repr::NNModule convertToNNModule(
caffe2::OperatorDef convertToOperatorDef(
const repr::NNGraph::NodeRef& instrNode) {
auto *nnOp = repr::nn::get<repr::NeuralNetOperator>(instrNode);
auto op_type = nnOp->getName();
auto *annotation = nnOp->getAnnotation();
caffe2::OperatorDef op;

if (ConverterRegistry()->Has(op.type())) {
op = ConverterRegistry()->Create(op.type())->convertToOperatorDef(nnOp);
if (ConverterRegistry()->Has(op_type)) {
op = ConverterRegistry()->Create(op_type)->convertToOperatorDef(nnOp);
} else if (!annotation) {
op.set_type(nnOp->getName());
op.set_type(op_type);
} else {
if (isa<Caffe2Annotation>(annotation)) {
auto c2_annotation = dyn_cast<Caffe2Annotation>(annotation);
Expand Down

0 comments on commit e0a8665

Please sign in to comment.