Skip to content

Commit

Permalink
Add visitor function for tensor shape (but won't dispatch through sta…
Browse files Browse the repository at this point in the history
…ndard entry point
  • Loading branch information
slyubomirsky authored and jroesch committed Aug 16, 2018
1 parent 731361f commit 9db7b3e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
8 changes: 8 additions & 0 deletions relay/include/relay/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TensorTypeNode* op,
Args... args) TYPE_FUNCTOR_DEFAULT;

/*!
* \brief Because tensor shapes are not types per se, there is a separate call for those
* \param op The shape
* \param args Additional arguments
* \return The result of the call
*/
virtual R VisitShape_(const ShapeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;

virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
return R();
Expand Down
2 changes: 2 additions & 0 deletions relay/include/relay/type_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ struct TypeVisitor : TypeFunctor<void(const Type& n, Args ...)> {
this->VisitType(op->boundType, args...);
}
void VisitType_(const TensorTypeNode* op, Args... args) override {}

void VisitShape_(const ShapeNode *op, Args... args) override {}
};

} // namespace relay
Expand Down

0 comments on commit 9db7b3e

Please sign in to comment.