diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index cd06a8ba1df5f..d3cc765516bb3 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -46,7 +46,7 @@ ] class PythonConverter(ExprFunctor): - '''Functor for translating Relay programs into Python ASTs.''' + """Functor for translating Relay programs into Python ASTs.""" def __init__(self, mod, target) -> None: super().__init__() @@ -59,12 +59,12 @@ def __init__(self, mod, target) -> None: def convert(self, prog: Expr): - '''This method converts the passed Relay expression into a Python + """This method converts the passed Relay expression into a Python AST object with equivalent semantics. The Python AST can be executed using exec(); it can be turned into text and inspected using astor. - ''' + """ optimized = self.optimize(prog) # start with conversion prelude (imports) and convert global defs @@ -83,7 +83,7 @@ def convert(self, prog: Expr): def optimize(self, prog: Expr): - '''Performs optimizations necessary to be able to generate code for prog.''' + """Performs optimizations necessary to be able to generate code for prog.""" # unwrap tuple wrappers (some op calls produce them) unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog assert relay.analysis.well_formed(unwrapped) @@ -99,28 +99,28 @@ def optimize(self, prog: Expr): def sanitize(self, name: str) -> str: - '''Removes any invalid characters (only underscores, numbers, and letters permitted) + """Removes any invalid characters (only underscores, numbers, and letters permitted) from the given name. Since we append a number and underscore to var names anyway, - it doesn't matter if the name is the empty string.''' + it doesn't matter if the name is the empty string.""" return re.sub(r'\W', '', name) def generate_var_name(self, name_hint: str) -> str: - '''Generates a unique variable name starting from the hint.''' + """Generates a unique variable name starting from the hint.""" name = '{}_var_{}'.format(self.sanitize(name_hint), self.var_no) self.var_no += 1 return name def generate_function_name(self, name_hint: str) -> str: - '''Generates a unique function name starting from the hint.''' + """Generates a unique function name starting from the hint.""" name = '{}_fun_{}'.format(self.sanitize(name_hint), self.fun_no) self.fun_no += 1 return name def get_var_name(self, var: Expr) -> str: - '''Returns the var name for the given Realy variable.''' + """Returns the var name for the given Realy variable.""" if var in self.var_map: return self.var_map[var] name = self.generate_var_name(var.name_hint) @@ -129,15 +129,15 @@ def get_var_name(self, var: Expr) -> str: def include_var(self, var: Expr, assign=False): - '''Returns a variable AST node for the given Relay var depending on - whether it must appear in an assignment or not.''' + """Returns a variable AST node for the given Relay var depending on + whether it must appear in an assignment or not.""" name = self.get_var_name(var) return Name(name, Store() if assign else Load()) def parse_name(self, name: str): - '''Given the name of a Python method with dots (e.g., 'relay.var'), - returns an appropriate AST object corresponding to that name.''' + """Given the name of a Python method with dots (e.g., 'relay.var'), + returns an appropriate AST object corresponding to that name.""" attributes = name.split('.') ret = Name(attributes[0], Load()) for i in range(len(attributes) - 1): @@ -146,8 +146,8 @@ def parse_name(self, name: str): def parse_numpy_array(self, arr): - '''Given a Numpy array, produces an appropriate Python array - or numerical literal representing its contents.''' + """Given a Numpy array, produces an appropriate Python array + or numerical literal representing its contents.""" parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i) if arr.ndim == 0: return parse_single(arr.item()) @@ -161,8 +161,8 @@ def parse_numpy_array(self, arr): def convert_fields(self, fields: [Expr]): - '''Given a list of call args or tuple fields, converts - each and returns their ASTs and their defs lists (in order).''' + """Given a list of call args or tuple fields, converts + each and returns their ASTs and their defs lists (in order).""" bodies = [] defs = [] for field in fields: @@ -173,7 +173,7 @@ def convert_fields(self, fields: [Expr]): def convert_to_thunk(self, name_hint: str, expr: Expr): - '''Wraps the passed expression in a thunk.''' + """Wraps the passed expression in a thunk.""" body, defs = self.visit(expr) thunk_name = self.generate_function_name(name_hint) thunk = self.create_def(thunk_name, [], defs + [Return(body)]) @@ -181,8 +181,8 @@ def convert_to_thunk(self, name_hint: str, expr: Expr): def convert_func_node(self, func: Function, name_var=None): - '''Converts the given Relay function into a Python function, with - special for named functions (locally or globally)''' + """Converts the given Relay function into a Python function, with + special for named functions (locally or globally)""" if name_var is None: func_name = self.generate_function_name('_anon_func') if isinstance(name_var, GlobalVar): @@ -197,8 +197,8 @@ def convert_func_node(self, func: Function, name_var=None): def convert_module(self): - '''Converts all the global functions defined in the module and returns - them as a list of definitions''' + """Converts all the global functions defined in the module and returns + them as a list of definitions""" defs = [] for var, func in self.mod.functions.items(): # optimize the definition so any operators used are lowered @@ -209,12 +209,12 @@ def convert_module(self): def create_call(self, func_name: str, arguments): - '''Creates a simple function call.''' + """Creates a simple function call.""" return ast.Call(self.parse_name(func_name), arguments, []) def create_def(self, func_name: str, arguments: [str], body): - '''Wrapper over function definition AST node, whose constructor is inconvenient.''' + """Wrapper over function definition AST node, whose constructor is inconvenient.""" return ast.FunctionDef( func_name, ast.arguments([ast.arg(argument, None) @@ -224,9 +224,9 @@ def create_def(self, func_name: str, arguments: [str], body): def create_op_call(self, op: Function, relay_args, py_args): - '''Lowers the passed primitive function, registers it in TVM's + """Lowers the passed primitive function, registers it in TVM's global compiler, and produces a call to the lowered function in - the generated Python code.''' + the generated Python code.""" # compile the function and register globally cc_key = compile_engine.CCacheKey(op, self.tgt) @@ -237,8 +237,8 @@ def create_op_call(self, op: Function, relay_args, py_args): tvm.register_func(op_name, jitted) def convert_input(py_input, arg_type): - '''Use the types of the function arguments to determine whether we expect - a tensor or tuple (returns list of inputs to the lowered op call)''' + """Use the types of the function arguments to determine whether we expect + a tensor or tuple (returns list of inputs to the lowered op call)""" # equivalent: input.data if isinstance(arg_type, relay.TensorType): return [ast.Attribute(py_input, 'data', Load())] @@ -254,9 +254,9 @@ def convert_input(py_input, arg_type): return ret def convert_output(ret_type): - '''Use the function return type to produce auxiliary variables to store outputs. + """Use the function return type to produce auxiliary variables to store outputs. Returns ([assignments of output vars], [extra arguments to pass to op call], - expression collecting output)''' + expression collecting output)""" if isinstance(ret_type, relay.TensorType): output_var_name = self.generate_var_name('_out') output_var = Name(output_var_name, Load()) @@ -303,9 +303,9 @@ def convert_output(ret_type): def create_match_check(self, pattern: Pattern, data): - '''Given an ADT match pattern and a (Python) expression pointing to + """Given an ADT match pattern and a (Python) expression pointing to an ADT value, this generates a Python expression that checks if the - ADT value matches the given pattern (returning True or False).''' + ADT value matches the given pattern (returning True or False).""" # wildcard or var match everything if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)): @@ -340,14 +340,14 @@ def create_match_check(self, pattern: Pattern, data): def create_match_clause_body(self, pattern: Pattern, body: Expr): - '''Given a match clause pattern and a clause body, + """Given a match clause pattern and a clause body, generates a Python function that when called with an ADT that matches the pattern, returns the result of evaluating the clause body. This function returns a function definition - and the name of the generated function.''' + and the name of the generated function.""" def collect_var_assignments(pat, val): - '''This helper function ensures that the pattern is used to + """This helper function ensures that the pattern is used to properly assign all subfields of the given AST for use in the clause body @@ -356,7 +356,7 @@ def collect_var_assignments(pat, val): we would want to have v = a.fields[0] w = a.fields[2].fields[0] - ''' + """ if isinstance(pat, relay.PatternWildcard): return [] if isinstance(pat, relay.PatternVar): @@ -407,13 +407,13 @@ def visit_let(self, letexp: Expr): # To properly account for scoping and ensure that the entire node produces an expression, # we translate the let binding as a function that we call with the value we intend to bind. # Yes, this is somewhat ugly. - ''' + """ let var = value in body ======================= def let_thunk(var): return body let_thunk(value) - ''' + """ bind_body, bind_defs = self.visit(letexp.body) func_name = self.generate_function_name('_let_func') @@ -459,9 +459,9 @@ def visit_if(self, if_block: Expr): def visit_constant(self, constant: Expr): - '''Proceeds by converting constant value to a numpy array + """Proceeds by converting constant value to a numpy array and converting it to the appropriate value in the generated - code (whether it be a Python scalar or a Numpy array)''' + code (whether it be a Python scalar or a Numpy array)""" value = constant.data.asnumpy() const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()), [self.parse_numpy_array(value)], @@ -476,8 +476,8 @@ def visit_function(self, func: Expr): def visit_call(self, call: Expr): - '''For calls, we must distinguish between ordinary functions, - operators, and constructor calls.''' + """For calls, we must distinguish between ordinary functions, + operators, and constructor calls.""" func = call.op fields, field_defs = self.convert_fields(call.args) @@ -515,11 +515,11 @@ def visit_ref_read(self, read: Expr): def visit_ref_write(self, write: Expr): - '''For writing refs, we wrap the update in a thunk + """For writing refs, we wrap the update in a thunk (returning an empty tuple to match Relay's semantics) that we execute at the right time. This ensures such assignments can be properly nested, since assignments are statements - in Python but expressions in Relay''' + in Python but expressions in Relay""" ref, ref_defs = self.visit(write.ref) val, val_defs = self.visit(write.value) thunk_name = self.generate_function_name('_ref_write_thunk') @@ -533,11 +533,11 @@ def visit_ref_write(self, write: Expr): def visit_match(self, match: Expr): - '''For matches, we wrap the entire expression in a thunk + """For matches, we wrap the entire expression in a thunk because it is easiest to implement them using if statements. For each clause, we generate a function that checks if the pattern matches. If yes, we call a function that assigns - the variables appropriately and invokes the clause body.''' + the variables appropriately and invokes the clause body.""" data, defs = self.visit(match.data) data_var = self.generate_var_name('_match_data') @@ -571,16 +571,16 @@ def visit_op(self, _): def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): - '''Converts the given Relay expression into a Python script (as a Python AST object). - For easiest debugging, import the astor package and use to_source().''' + """Converts the given Relay expression into a Python script (as a Python AST object). + For easiest debugging, import the astor package and use to_source().""" mod = mod if mod is not None else relay.Module() converter = PythonConverter(mod, target) return converter.convert(expr) def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): - '''Converts the given Relay expression into a Python script and - executes it.''' + """Converts the given Relay expression into a Python script and + executes it.""" mod = mod if mod is not None else relay.Module() py_ast = to_python(expr, mod, target) code = compile(py_ast, '', 'exec')