Skip to content

Commit

Permalink
Replace triple single quotes with triple double quotes
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Jul 5, 2019
1 parent cb03db8 commit b9d6d77
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
]

class PythonConverter(ExprFunctor):
'''Functor for translating Relay programs into Python ASTs.'''
"""Functor for translating Relay programs into Python ASTs."""

def __init__(self, mod, target) -> None:
super().__init__()
Expand All @@ -59,12 +59,12 @@ def __init__(self, mod, target) -> None:


def convert(self, prog: Expr):
'''This method converts the passed Relay expression into a Python
"""This method converts the passed Relay expression into a Python
AST object with equivalent semantics.
The Python AST can be executed using exec(); it can be turned
into text and inspected using astor.
'''
"""
optimized = self.optimize(prog)

# start with conversion prelude (imports) and convert global defs
Expand All @@ -83,7 +83,7 @@ def convert(self, prog: Expr):


def optimize(self, prog: Expr):
'''Performs optimizations necessary to be able to generate code for prog.'''
"""Performs optimizations necessary to be able to generate code for prog."""
# unwrap tuple wrappers (some op calls produce them)
unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
assert relay.analysis.well_formed(unwrapped)
Expand All @@ -99,28 +99,28 @@ def optimize(self, prog: Expr):


def sanitize(self, name: str) -> str:
'''Removes any invalid characters (only underscores, numbers, and letters permitted)
"""Removes any invalid characters (only underscores, numbers, and letters permitted)
from the given name. Since we append a number and underscore to var names anyway,
it doesn't matter if the name is the empty string.'''
it doesn't matter if the name is the empty string."""
return re.sub(r'\W', '', name)


def generate_var_name(self, name_hint: str) -> str:
'''Generates a unique variable name starting from the hint.'''
"""Generates a unique variable name starting from the hint."""
name = '{}_var_{}'.format(self.sanitize(name_hint), self.var_no)
self.var_no += 1
return name


def generate_function_name(self, name_hint: str) -> str:
'''Generates a unique function name starting from the hint.'''
"""Generates a unique function name starting from the hint."""
name = '{}_fun_{}'.format(self.sanitize(name_hint), self.fun_no)
self.fun_no += 1
return name


def get_var_name(self, var: Expr) -> str:
'''Returns the var name for the given Realy variable.'''
"""Returns the var name for the given Realy variable."""
if var in self.var_map:
return self.var_map[var]
name = self.generate_var_name(var.name_hint)
Expand All @@ -129,15 +129,15 @@ def get_var_name(self, var: Expr) -> str:


def include_var(self, var: Expr, assign=False):
'''Returns a variable AST node for the given Relay var depending on
whether it must appear in an assignment or not.'''
"""Returns a variable AST node for the given Relay var depending on
whether it must appear in an assignment or not."""
name = self.get_var_name(var)
return Name(name, Store() if assign else Load())


def parse_name(self, name: str):
'''Given the name of a Python method with dots (e.g., 'relay.var'),
returns an appropriate AST object corresponding to that name.'''
"""Given the name of a Python method with dots (e.g., 'relay.var'),
returns an appropriate AST object corresponding to that name."""
attributes = name.split('.')
ret = Name(attributes[0], Load())
for i in range(len(attributes) - 1):
Expand All @@ -146,8 +146,8 @@ def parse_name(self, name: str):


def parse_numpy_array(self, arr):
'''Given a Numpy array, produces an appropriate Python array
or numerical literal representing its contents.'''
"""Given a Numpy array, produces an appropriate Python array
or numerical literal representing its contents."""
parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i)
if arr.ndim == 0:
return parse_single(arr.item())
Expand All @@ -161,8 +161,8 @@ def parse_numpy_array(self, arr):


def convert_fields(self, fields: [Expr]):
'''Given a list of call args or tuple fields, converts
each and returns their ASTs and their defs lists (in order).'''
"""Given a list of call args or tuple fields, converts
each and returns their ASTs and their defs lists (in order)."""
bodies = []
defs = []
for field in fields:
Expand All @@ -173,16 +173,16 @@ def convert_fields(self, fields: [Expr]):


def convert_to_thunk(self, name_hint: str, expr: Expr):
'''Wraps the passed expression in a thunk.'''
"""Wraps the passed expression in a thunk."""
body, defs = self.visit(expr)
thunk_name = self.generate_function_name(name_hint)
thunk = self.create_def(thunk_name, [], defs + [Return(body)])
return (thunk, thunk_name)


def convert_func_node(self, func: Function, name_var=None):
'''Converts the given Relay function into a Python function, with
special for named functions (locally or globally)'''
"""Converts the given Relay function into a Python function, with
special for named functions (locally or globally)"""
if name_var is None:
func_name = self.generate_function_name('_anon_func')
if isinstance(name_var, GlobalVar):
Expand All @@ -197,8 +197,8 @@ def convert_func_node(self, func: Function, name_var=None):


def convert_module(self):
'''Converts all the global functions defined in the module and returns
them as a list of definitions'''
"""Converts all the global functions defined in the module and returns
them as a list of definitions"""
defs = []
for var, func in self.mod.functions.items():
# optimize the definition so any operators used are lowered
Expand All @@ -209,12 +209,12 @@ def convert_module(self):


def create_call(self, func_name: str, arguments):
'''Creates a simple function call.'''
"""Creates a simple function call."""
return ast.Call(self.parse_name(func_name), arguments, [])


def create_def(self, func_name: str, arguments: [str], body):
'''Wrapper over function definition AST node, whose constructor is inconvenient.'''
"""Wrapper over function definition AST node, whose constructor is inconvenient."""
return ast.FunctionDef(
func_name,
ast.arguments([ast.arg(argument, None)
Expand All @@ -224,9 +224,9 @@ def create_def(self, func_name: str, arguments: [str], body):


def create_op_call(self, op: Function, relay_args, py_args):
'''Lowers the passed primitive function, registers it in TVM's
"""Lowers the passed primitive function, registers it in TVM's
global compiler, and produces a call to the lowered function in
the generated Python code.'''
the generated Python code."""

# compile the function and register globally
cc_key = compile_engine.CCacheKey(op, self.tgt)
Expand All @@ -237,8 +237,8 @@ def create_op_call(self, op: Function, relay_args, py_args):
tvm.register_func(op_name, jitted)

def convert_input(py_input, arg_type):
'''Use the types of the function arguments to determine whether we expect
a tensor or tuple (returns list of inputs to the lowered op call)'''
"""Use the types of the function arguments to determine whether we expect
a tensor or tuple (returns list of inputs to the lowered op call)"""
# equivalent: input.data
if isinstance(arg_type, relay.TensorType):
return [ast.Attribute(py_input, 'data', Load())]
Expand All @@ -254,9 +254,9 @@ def convert_input(py_input, arg_type):
return ret

def convert_output(ret_type):
'''Use the function return type to produce auxiliary variables to store outputs.
"""Use the function return type to produce auxiliary variables to store outputs.
Returns ([assignments of output vars], [extra arguments to pass to op call],
expression collecting output)'''
expression collecting output)"""
if isinstance(ret_type, relay.TensorType):
output_var_name = self.generate_var_name('_out')
output_var = Name(output_var_name, Load())
Expand Down Expand Up @@ -303,9 +303,9 @@ def convert_output(ret_type):


def create_match_check(self, pattern: Pattern, data):
'''Given an ADT match pattern and a (Python) expression pointing to
"""Given an ADT match pattern and a (Python) expression pointing to
an ADT value, this generates a Python expression that checks if the
ADT value matches the given pattern (returning True or False).'''
ADT value matches the given pattern (returning True or False)."""

# wildcard or var match everything
if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
Expand Down Expand Up @@ -340,14 +340,14 @@ def create_match_check(self, pattern: Pattern, data):


def create_match_clause_body(self, pattern: Pattern, body: Expr):
'''Given a match clause pattern and a clause body,
"""Given a match clause pattern and a clause body,
generates a Python function that when called with an ADT
that matches the pattern, returns the result of evaluating
the clause body. This function returns a function definition
and the name of the generated function.'''
and the name of the generated function."""

def collect_var_assignments(pat, val):
'''This helper function ensures that the pattern is used to
"""This helper function ensures that the pattern is used to
properly assign all subfields of the given AST for use
in the clause body
Expand All @@ -356,7 +356,7 @@ def collect_var_assignments(pat, val):
we would want to have
v = a.fields[0]
w = a.fields[2].fields[0]
'''
"""
if isinstance(pat, relay.PatternWildcard):
return []
if isinstance(pat, relay.PatternVar):
Expand Down Expand Up @@ -407,13 +407,13 @@ def visit_let(self, letexp: Expr):
# To properly account for scoping and ensure that the entire node produces an expression,
# we translate the let binding as a function that we call with the value we intend to bind.
# Yes, this is somewhat ugly.
'''
"""
let var = value in body
=======================
def let_thunk(var):
return body
let_thunk(value)
'''
"""
bind_body, bind_defs = self.visit(letexp.body)

func_name = self.generate_function_name('_let_func')
Expand Down Expand Up @@ -459,9 +459,9 @@ def visit_if(self, if_block: Expr):


def visit_constant(self, constant: Expr):
'''Proceeds by converting constant value to a numpy array
"""Proceeds by converting constant value to a numpy array
and converting it to the appropriate value in the generated
code (whether it be a Python scalar or a Numpy array)'''
code (whether it be a Python scalar or a Numpy array)"""
value = constant.data.asnumpy()
const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
[self.parse_numpy_array(value)],
Expand All @@ -476,8 +476,8 @@ def visit_function(self, func: Expr):


def visit_call(self, call: Expr):
'''For calls, we must distinguish between ordinary functions,
operators, and constructor calls.'''
"""For calls, we must distinguish between ordinary functions,
operators, and constructor calls."""
func = call.op
fields, field_defs = self.convert_fields(call.args)

Expand Down Expand Up @@ -515,11 +515,11 @@ def visit_ref_read(self, read: Expr):


def visit_ref_write(self, write: Expr):
'''For writing refs, we wrap the update in a thunk
"""For writing refs, we wrap the update in a thunk
(returning an empty tuple to match Relay's semantics)
that we execute at the right time. This ensures such assignments
can be properly nested, since assignments are statements
in Python but expressions in Relay'''
in Python but expressions in Relay"""
ref, ref_defs = self.visit(write.ref)
val, val_defs = self.visit(write.value)
thunk_name = self.generate_function_name('_ref_write_thunk')
Expand All @@ -533,11 +533,11 @@ def visit_ref_write(self, write: Expr):


def visit_match(self, match: Expr):
'''For matches, we wrap the entire expression in a thunk
"""For matches, we wrap the entire expression in a thunk
because it is easiest to implement them using if statements.
For each clause, we generate a function that checks if the
pattern matches. If yes, we call a function that assigns
the variables appropriately and invokes the clause body.'''
the variables appropriately and invokes the clause body."""
data, defs = self.visit(match.data)
data_var = self.generate_var_name('_match_data')

Expand Down Expand Up @@ -571,16 +571,16 @@ def visit_op(self, _):


def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
'''Converts the given Relay expression into a Python script (as a Python AST object).
For easiest debugging, import the astor package and use to_source().'''
"""Converts the given Relay expression into a Python script (as a Python AST object).
For easiest debugging, import the astor package and use to_source()."""
mod = mod if mod is not None else relay.Module()
converter = PythonConverter(mod, target)
return converter.convert(expr)


def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
'''Converts the given Relay expression into a Python script and
executes it.'''
"""Converts the given Relay expression into a Python script and
executes it."""
mod = mod if mod is not None else relay.Module()
py_ast = to_python(expr, mod, target)
code = compile(py_ast, '<string>', 'exec')
Expand Down

0 comments on commit b9d6d77

Please sign in to comment.