diff --git a/relay/include/relay/type_functor.h b/relay/include/relay/type_functor.h index 2ca82fbc12a6e..3dbca1513d94e 100644 --- a/relay/include/relay/type_functor.h +++ b/relay/include/relay/type_functor.h @@ -68,6 +68,14 @@ class TypeFunctor { virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + /*! + * \brief Because tensor shapes are not types per se, there is a separate call for those + * \param op The shape + * \param args Additional arguments + * \return The result of the call + */ + virtual R VisitShape_(const ShapeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); return R(); diff --git a/relay/include/relay/type_visitor.h b/relay/include/relay/type_visitor.h index 1045b0cbebc71..ac6526f3d7b62 100644 --- a/relay/include/relay/type_visitor.h +++ b/relay/include/relay/type_visitor.h @@ -23,6 +23,8 @@ struct TypeVisitor : TypeFunctor { this->VisitType(op->boundType, args...); } void VisitType_(const TensorTypeNode* op, Args... args) override {} + + void VisitShape_(const ShapeNode *op, Args... args) override {} }; } // namespace relay