Skip to content

Commit

Permalink
[Refactor][Relay] Refactor Relay Python to use new FFI (apache#5077)
Browse files Browse the repository at this point in the history
* refactor relay python

* revert relay/ir/*.py to relay

* Address comments

* remove direct access to analysis and transform namespace
  • Loading branch information
zhiics authored and Trevor Morris committed Apr 16, 2020
1 parent 54155b9 commit 5eb34cb
Show file tree
Hide file tree
Showing 63 changed files with 349 additions and 466 deletions.
4 changes: 0 additions & 4 deletions docs/api/python/relay/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ tvm.relay.base
--------------
.. automodule:: tvm.relay.base

.. autofunction:: tvm.relay.base.register_relay_node

.. autofunction:: tvm.relay.base.register_relay_attr_node

.. autoclass:: tvm.relay.base.RelayNode
:members:

Expand Down
46 changes: 23 additions & 23 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,50 @@
import os
from sys import setrecursionlimit

from . import call_graph
from . import base
from . import ty
from . import expr
from . import type_functor
from . import expr_functor
from . import adt
from . import analysis
from . import prelude
from . import loops
from . import scope_builder
from . import parser

from . import transform
from . import analysis
from .analysis import alpha_equal
from .build_module import build, create_executor, optimize
from .transform import build_config
from . import prelude
from . import parser
from . import debug
from . import param_dict
from . import feature
from .backend import vm

# Root operators
from .op import Op
from .op import nn
from .op import image
from .op import annotation
from .op import vision
from .op import contrib
from .op.reduce import *
from .op.tensor import *
from .op.transform import *
from .op.algorithm import *
from . import nn
from . import annotation
from . import vision
from . import contrib
from . import image
from . import frontend
from . import backend
from . import quantize

# Dialects
from . import qnn

from .scope_builder import ScopeBuilder
# Load Memory pass
from . import memory_alloc

# Required to traverse large programs
setrecursionlimit(10000)

# Span
Span = base.Span
SourceName = base.SourceName

# Type
Type = ty.Type
Expand Down Expand Up @@ -98,6 +97,7 @@
RefWrite = expr.RefWrite

# ADT
Pattern = adt.Pattern
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
Expand All @@ -111,9 +111,6 @@
var = expr.var
const = expr.const
bind = expr.bind
module_pass = transform.module_pass
function_pass = transform.function_pass
alpha_equal = analysis.alpha_equal

# TypeFunctor
TypeFunctor = type_functor.TypeFunctor
Expand All @@ -125,6 +122,15 @@
ExprVisitor = expr_functor.ExprVisitor
ExprMutator = expr_functor.ExprMutator

# Prelude
Prelude = prelude.Prelude

# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder

module_pass = transform.module_pass
function_pass = transform.function_pass

# Parser
fromtext = parser.fromtext

Expand All @@ -139,9 +145,3 @@
ModulePass = transform.ModulePass
FunctionPass = transform.FunctionPass
Sequential = transform.Sequential

# Feature
Feature = feature.Feature

# CallGraph
CallGraph = call_graph.CallGraph
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the passes for Relay program analysis."""
"""FFI APIs for Relay program IR."""
import tvm._ffi

tvm._ffi._init_api("relay._analysis", __name__)
tvm._ffi._init_api("relay.ir", __name__)
30 changes: 16 additions & 14 deletions python/tvm/relay/adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Algebraic data types in Relay."""
from tvm.ir import Constructor, TypeData
from tvm.runtime import Object
import tvm._ffi

from .base import RelayNode, register_relay_node, Object
from . import _make
from .base import RelayNode
from . import _ffi_api
from .ty import Type
from .expr import ExprWithOp, RelayExpr, Call

Expand All @@ -28,7 +30,7 @@ class Pattern(RelayNode):
"""Base type for pattern matching constructs."""


@register_relay_node
@tvm._ffi.register_object("relay.PatternWildcard")
class PatternWildcard(Pattern):
"""Wildcard pattern in Relay: Matches any ADT and binds nothing."""

Expand All @@ -44,10 +46,10 @@ def __init__(self):
wildcard: PatternWildcard
a wildcard pattern.
"""
self.__init_handle_by_constructor__(_make.PatternWildcard)
self.__init_handle_by_constructor__(_ffi_api.PatternWildcard)


@register_relay_node
@tvm._ffi.register_object("relay.PatternVar")
class PatternVar(Pattern):
"""Variable pattern in Relay: Matches anything and binds it to the variable."""

Expand All @@ -63,10 +65,10 @@ def __init__(self, var):
pv: PatternVar
A variable pattern.
"""
self.__init_handle_by_constructor__(_make.PatternVar, var)
self.__init_handle_by_constructor__(_ffi_api.PatternVar, var)


@register_relay_node
@tvm._ffi.register_object("relay.PatternConstructor")
class PatternConstructor(Pattern):
"""Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively."""

Expand All @@ -88,10 +90,10 @@ def __init__(self, constructor, patterns=None):
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns)
self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns)


@register_relay_node
@tvm._ffi.register_object("relay.PatternTuple")
class PatternTuple(Pattern):
"""Constructor pattern in Relay: Matches a tuple, binds recursively."""

Expand All @@ -111,10 +113,10 @@ def __init__(self, patterns=None):
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternTuple, patterns)
self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns)


@register_relay_node
@tvm._ffi.register_object("relay.Clause")
class Clause(Object):
"""Clause for pattern matching in Relay."""

Expand All @@ -133,10 +135,10 @@ def __init__(self, lhs, rhs):
clause: Clause
The Clause.
"""
self.__init_handle_by_constructor__(_make.Clause, lhs, rhs)
self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs)


@register_relay_node
@tvm._ffi.register_object("relay.Match")
class Match(ExprWithOp):
"""Pattern matching expression in Relay."""

Expand All @@ -160,4 +162,4 @@ def __init__(self, data, clauses, complete=True):
match: tvm.relay.Expr
The match expression.
"""
self.__init_handle_by_constructor__(_make.Match, data, clauses, complete)
self.__init_handle_by_constructor__(_ffi_api.Match, data, clauses, complete)
28 changes: 28 additions & 0 deletions python/tvm/relay/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the analysis passes."""
# Analysis passes
from .analysis import *

# Call graph
from . import call_graph

# Feature
from . import feature

CallGraph = call_graph.CallGraph
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
"""FFI APIs for Relay program analysis."""
import tvm._ffi

tvm._ffi._init_api("relay._base", __name__)
tvm._ffi._init_api("relay.analysis", __name__)
Loading

0 comments on commit 5eb34cb

Please sign in to comment.