From 94d4e9466807dafcc1b65fb81155feb0a60a8f3b Mon Sep 17 00:00:00 2001
From: James Reed
Date: Wed, 3 Feb 2021 16:47:41 -0800
Subject: [PATCH] [FX] Move examples from pytorch/pytorch

---
 fx/inline_function.py             | 68 +++++++++++++++++++++++++++
 fx/invert.py                      | 62 +++++++++++++++++++++++++
 fx/proxy_based_graph_creation.py  | 56 ++++++++++++++++++++++
 fx/replace_op.py                  | 63 +++++++++++++++++++++++++
 fx/subgraph_rewriter_basic_use.py | 74 +++++++++++++++++++++++++++++
 fx/wrap_output_dynamically.py     | 77 +++++++++++++++++++++++++++++++
 6 files changed, 400 insertions(+)
 create mode 100644 fx/inline_function.py
 create mode 100644 fx/invert.py
 create mode 100644 fx/proxy_based_graph_creation.py
 create mode 100644 fx/replace_op.py
 create mode 100644 fx/subgraph_rewriter_basic_use.py
 create mode 100644 fx/wrap_output_dynamically.py

diff --git a/fx/inline_function.py b/fx/inline_function.py
new file mode 100644
index 0000000000..c967bd0ee5
--- /dev/null
+++ b/fx/inline_function.py
@@ -0,0 +1,68 @@
+import torch
+from torch.fx import Proxy, symbolic_trace
+from torch.fx.node import map_arg
+
+
+'''
+How to Inline a Function Into an Existing Graph
+
+One reason you might want to inline a function is to get around FX's
+default tracing behavior. For example, unless you've defined a custom
+Tracer, the out-of-the-box implementation of ``symbolic_trace`` causes
+references to ``torch.nn`` module instances to appear as
+``call_module`` calls rather than being traced through. Let's say this
+behavior is almost what you need; the only problem is that there's a
+single module call that you want to replace with an inlined trace of the
+function. Creating a custom Tracer just for that would be overkill.
+Instead, you can accomplish this using Proxies.
+
+The following code demonstrates how to trace a module and inline it
+into an existing Graph using Proxy. We'll trace our Graph, then iterate
+through its Nodes until we find the right place to swap out the
+``call_module`` Node with an inlined trace. At that point, we'll create
+Proxies from the Node's args and kwargs. Then we'll call the
+function we want to replace with those Proxies--which will, in essence,
+"trace" that function. Finally, we'll insert the result of that call
+into our Graph. (This last step will automatically inline the function.)
+'''
+
+
+# Sample module
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        return self.relu(x) + 1.0
+
+# Symbolically trace an instance of `M`. After tracing, `self.relu` is
+# represented as a `call_module` Node. The full operation in the
+# generated `forward` function's code will appear as `self.relu(x)`
+m = symbolic_trace(M())
+
+# Insert nodes from the ReLU graph in place of the original call to
+# `self.relu`
+for node in m.graph.nodes:
+    # Find the `call_module` Node in `m` that corresponds to `self.relu`.
+    # This is the Node we want to swap out for an inlined version of the
+    # same call
+    if (node.op, node.target) == ("call_module", "relu"):
+        with m.graph.inserting_before(node):
+            # Create a Proxy from each Node in the current Node's
+            # args/kwargs
+            proxy_args = map_arg(node.args, Proxy)
+            proxy_kwargs = map_arg(node.kwargs, Proxy)
+            # Call `m.relu` with the newly-created Proxy arguments.
+            # `m.relu` is the generic version of the function; by
+            # calling it with Proxies created from Nodes in `m`, we're
+            # emitting Nodes that reference existing values in the IR.
+            # The result of this call is another Proxy, which we can
+            # hook into our existing Graph to complete the function
+            # inlining.
+            proxy_output = m.relu(*proxy_args, **proxy_kwargs)
+            # Replace the relu `call_module` Node with the inlined
+            # version of the function
+            node.replace_all_uses_with(proxy_output.node)
+            # Make sure that the old relu Node is erased
+            m.graph.erase_node(node)
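+
+# Quick illustrative check (a sketch, assuming a float tensor input):
+# recompile so the generated `forward` code reflects the modified graph,
+# then confirm the inlined graph still computes `relu(x) + 1.0`.
+m.recompile()
+x = torch.randn(3)
+assert torch.equal(m(x), torch.relu(x) + 1.0)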
diff --git a/fx/invert.py b/fx/invert.py
new file mode 100644
index 0000000000..dd9b1d3d6f
--- /dev/null
+++ b/fx/invert.py
@@ -0,0 +1,62 @@
+import torch
+import torch.fx as fx
+
+# An inverse mapping is one that takes a function f(x) and returns a function g
+# such that f(g(x)) == x. For example, since log(exp(x)) == x, exp and log are
+# inverses.
+
+invert_mapping = {}
+def add_inverse(a, b):
+    invert_mapping[a] = b
+    invert_mapping[b] = a
+inverses = [
+    (torch.sin, torch.arcsin),
+    (torch.cos, torch.arccos),
+    (torch.tan, torch.arctan),
+    (torch.exp, torch.log),
+]
+for a, b in inverses:
+    add_inverse(a, b)
+
+# The general strategy is that we walk the graph backwards, transforming each
+# node into its inverse. To do so, we swap each function's inputs and outputs
+# and look up its inverse in `invert_mapping`. Note that this transform assumes
+# that all operations take in only one input and return one output.
+def invert(model: torch.nn.Module) -> torch.nn.Module:
+    fx_model = fx.symbolic_trace(model)
+    new_graph = fx.Graph()  # We build up the inverted graph as we go
+    env = {}
+    for node in reversed(fx_model.graph.nodes):
+        if node.op == 'call_function':
+            # This creates a node in the new graph with the inverse function,
+            # and passes `env[node.name]` (i.e. the previous output node) as
+            # input.
+            new_node = new_graph.call_function(invert_mapping[node.target], (env[node.name],))
+            env[node.args[0].name] = new_node
+        elif node.op == 'output':
+            # We turn the output into an input placeholder
+            new_node = new_graph.placeholder(node.name)
+            env[node.args[0].name] = new_node
+        elif node.op == 'placeholder':
+            # We turn the input placeholder into an output
+            new_graph.output(env[node.name])
+        else:
+            raise RuntimeError("Not implemented")
+
+    new_graph.lint()
+    return fx.GraphModule(fx_model, new_graph)
+
+
+def f(x):
+    return torch.exp(torch.tan(x))
+
+res = invert(f)
+print(res.code)
+"""
+def forward(self, output):
+    log_1 = torch.log(output); output = None
+    arctan_1 = torch.arctan(log_1); log_1 = None
+    return arctan_1
+"""
+print(f(res((torch.arange(5) + 1))))  # [1., 2., 3., 4., 5.]

diff --git a/fx/proxy_based_graph_creation.py b/fx/proxy_based_graph_creation.py
new file mode 100644
index 0000000000..b08f61d4f6
--- /dev/null
+++ b/fx/proxy_based_graph_creation.py
@@ -0,0 +1,56 @@
+import torch
+from torch.fx import Proxy, Graph, GraphModule
+
+
+'''
+How to Create a Graph Using Proxy Objects Instead of Tracing
+
+It's possible to directly create a Proxy object around a raw Node. This
+can be used to create a Graph independently of symbolic tracing.
+
+The following code demonstrates how to use Proxy with a raw Node to
+append operations to a fresh Graph. We'll create two input placeholders
+(``x`` and ``y``), perform some operations on them, then add
+everything we created to the new Graph. We'll then wrap that Graph in
+a GraphModule. Doing so creates a runnable instance of ``nn.Module``
+where previously-created operations are represented in the Module's
+``forward`` function.
+
+By the end of the tutorial, we'll have added the following method to an
+empty ``nn.Module`` class.
+
+.. code-block:: python
+
+    def forward(self, x, y):
+        cat_1 = torch.cat([x, y]); x = y = None
+        tanh_1 = torch.tanh(cat_1); cat_1 = None
+        neg_1 = torch.neg(tanh_1); tanh_1 = None
+        return neg_1
+
+'''
+
+
+# Create a graph independently of symbolic tracing
+graph = Graph()
+
+# Create raw Nodes
+raw1 = graph.placeholder('x')
+raw2 = graph.placeholder('y')
+
+# Initialize Proxies using the raw Nodes
+y = Proxy(raw1)
+z = Proxy(raw2)
+
+# Create other operations using the Proxies `y` and `z`
+a = torch.cat([y, z])
+b = torch.tanh(a)
+c = torch.neg(b)
+
+# Create a new output Node and add it to the Graph. By doing this, the
+# Graph will contain all the Nodes we just created (since they're all
+# linked to the output Node)
+graph.output(c.node)
+
+# Wrap our created Graph in a GraphModule to get a final, runnable
+# `nn.Module` instance
+mod = GraphModule(torch.nn.Module(), graph)

diff --git a/fx/replace_op.py b/fx/replace_op.py
new file mode 100644
index 0000000000..6ad5be3e82
--- /dev/null
+++ b/fx/replace_op.py
@@ -0,0 +1,63 @@
+import torch
+from torch.fx import symbolic_trace
+import operator
+
+"""
+How to Replace One Op With Another
+
+1. Iterate through all Nodes in your GraphModule's Graph.
+2. Determine if the current Node should be replaced. (Suggested: match
+on the Node's ``target`` attribute).
+3. Create a replacement Node and add it to the Graph.
+4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of
+the current Node with the replacement.
+5. Delete the old Node from the graph.
+6. Call ``recompile`` on the GraphModule. This updates the generated
+Python code to reflect the new Graph state.
+
+Currently, FX does not provide any way to guarantee that replaced
+operators are syntactically valid. It's up to the user to confirm that
+any new operators will work with the existing operands.
+
+The following code demonstrates an example of replacing any instance of
+addition with a bitwise AND.
+
+To examine how the Graph evolves during op replacement, add the
+statement `print(traced.graph)` after the line you want to inspect.
+Alternatively, call `traced.graph.print_tabular()` to see the IR in a
+tabular format.
+"""
+
+# Sample module
+class M(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y, torch.add(x, y), x.add(y)
+
+# Symbolically trace an instance of the module
+traced = symbolic_trace(M())
+
+# As demonstrated in the above example, there are several different ways
+# to denote addition. The possible cases are:
+#   1. `x + y` - A `call_function` Node with target `operator.add`.
+#      We can match for equality on that `operator.add` directly.
+#   2. `torch.add(x, y)` - A `call_function` Node with target
+#      `torch.add`. Similarly, we can match this function directly.
+#   3. `x.add(y)` - The Tensor method call, whose target we can match
+#      as a string.
+
+patterns = set([operator.add, torch.add, "add"])
+
+# Go through all the nodes in the Graph
+for n in traced.graph.nodes:
+    # If the target matches one of the patterns
+    if any(n.target == pattern for pattern in patterns):
+        # Set the insert point, add the new node, and replace all uses
+        # of `n` with the new node
+        with traced.graph.inserting_after(n):
+            new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
+            n.replace_all_uses_with(new_node)
+        # Remove the old node from the graph
+        traced.graph.erase_node(n)
+
+# Don't forget to recompile!
+traced.recompile()
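+
+# Quick illustrative check (assumes integer inputs, since `torch.bitwise_and`
+# only accepts integer and boolean tensors): every former addition should now
+# behave as a bitwise AND.
+x, y = torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])
+assert all(torch.equal(out, torch.bitwise_and(x, y)) for out in traced(x, y))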
diff --git a/fx/subgraph_rewriter_basic_use.py b/fx/subgraph_rewriter_basic_use.py
new file mode 100644
index 0000000000..658815cac8
--- /dev/null
+++ b/fx/subgraph_rewriter_basic_use.py
@@ -0,0 +1,74 @@
+import torch
+from torch.fx import symbolic_trace, replace_pattern
+
+
+'''
+How to Use the FX Subgraph Rewriter
+
+For easy subgraph rewriting, FX exposes the utility function:
+
+    replace_pattern(gm : GraphModule,
+                    pattern : Callable,
+                    replacement : Callable)
+        -> None
+
+`replace_pattern` matches all possible non-overlapping sets of operators
+and their data dependencies (`pattern`) in the Graph of a GraphModule
+(`gm`), then replaces each of these matched subgraphs with another
+subgraph (`replacement`).
+
+The docstring for `replace_pattern` (located in `subgraph_rewriter.py`)
+gives an in-depth explanation as to how `pattern` and `replacement`
+should be specified, what happens during pattern matching, and other
+important technical details. This tutorial, therefore, is only meant to
+give an overview of the FX Subgraph Rewriter's basic functionality.
+Let's go rewrite a Graph!
+'''

+# Sample module
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, w1, w2):
+        val1 = torch.neg(w1)
+        m1 = torch.cat([val1, w2]).sum()
+        val2 = torch.neg(w1)
+        m2 = torch.cat([val2, w2]).sum()
+        return x + torch.max(m1) + torch.max(m2)
+
+# Symbolically trace an instance of `M`
+traced = symbolic_trace(M())
+
+# Define the pattern. The FX Subgraph Rewriter will match all
+# non-overlapping instances of the pattern in the larger graph.
+# Note that pattern matching is done based on data dependencies,
+# not Node names. Even though we're operating on Nodes named `a1` and
+# `a2` instead of `w1` and `w2`, the pattern is still a valid match
+# for the two instances of `torch.cat([torch.neg(w1), w2]).sum()` above.
+# Only operations that contribute to the single output value of the
+# pattern are considered.
+def pattern(a1, a2):
+    val1 = torch.neg(a1)
+    return torch.cat([val1, a2]).sum()
+
+# Define the replacement (same rules as the pattern)
+def replacement(w1, w2):
+    return torch.stack([w1, w2])
+
+# Replace `pattern` with `replacement` in `traced`
+replace_pattern(traced, pattern, replacement)
+
+# After calling `replace_pattern`, the generated code is:
+'''
+def forward(self, x, w1, w2):
+    stack_1 = torch.stack([w1, w2])
+    sum_1 = stack_1.sum()
+    stack_2 = torch.stack([w1, w2])
+    sum_2 = stack_2.sum()
+    max_1 = torch.max(sum_1)
+    add_1 = x + max_1
+    max_2 = torch.max(sum_2)
+    add_2 = add_1 + max_2
+    return add_2
+'''

diff --git a/fx/wrap_output_dynamically.py b/fx/wrap_output_dynamically.py
new file mode 100644
index 0000000000..c2d6b8402c
--- /dev/null
+++ b/fx/wrap_output_dynamically.py
@@ -0,0 +1,77 @@
+
+import torch
+from torch.fx import Proxy, GraphModule, Node, symbolic_trace
+
+from enum import Enum, auto
+from typing import Optional
+
+'''
+Wrap Graph Output Dynamically
+
+The following code demonstrates how to change an existing Graph based on
+parameters specified at runtime. We'll let the user specify an
+activation function from a predefined Enum list, then we'll symbolically
+trace it. Next, we'll create a Proxy from the last operation in the
+Graph. We'll call our traced activation function with this Proxy and
+insert the ``output`` Node from that call into our Graph. (This final
+step will automatically inline the entire traced function.)
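+
+For example, after calling
+``wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)``,
+the regenerated ``forward`` should look roughly like this (variable names
+and argument formatting depend on the FX version):
+
+    def forward(self, x, y):
+        cat = torch.cat([x, y]); x = y = None
+        leaky_relu = torch.nn.functional.leaky_relu(cat, 0.01, False); cat = None
+        return leaky_relu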
+'''
+
+
+# Sample module
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        y = torch.cat([x, y])
+        return y
+
+# Symbolically trace an instance of `M`
+traced = symbolic_trace(M())
+
+# Selected activation functions
+class ActivationFunction(Enum):
+    RELU = auto()
+    LEAKY_RELU = auto()
+    PRELU = auto()
+
+# Map activation function names to their implementation
+activation_functions = {
+    ActivationFunction.RELU: torch.nn.ReLU(),
+    ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(),
+    ActivationFunction.PRELU: torch.nn.PReLU(),
+}
+
+def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> GraphModule:
+    # Get output node
+    output_node: Optional[Node] = None
+    for n in reversed(m.graph.nodes):
+        if n.op == "output":
+            output_node = n
+            break
+    assert output_node
+
+    # Get the actual output (the "input" of the output node). This is
+    # the Node we want to wrap in a user-specified activation function
+    assert len(output_node.all_input_nodes) == 1
+    wrap_node = output_node.all_input_nodes[0]
+
+    # Wrap the actual output in a Proxy
+    wrap_proxy = Proxy(wrap_node)
+
+    # Get the implementation of the specified activation function and
+    # symbolically trace it
+    fn_impl = activation_functions[fn]
+    fn_impl_traced = symbolic_trace(fn_impl)
+
+    # Call the specified activation function using the Proxy wrapper for
+    # `wrap_node`. The result of this call is another Proxy, which we
+    # can hook into our existing Graph. The new Nodes must come after
+    # `wrap_node`, since they use its value as input.
+    with m.graph.inserting_after(wrap_node):
+        fn_impl_output_node = fn_impl_traced(wrap_proxy)
+        new_args = (fn_impl_output_node.node,)
+        output_node.args = new_args
+
+    # Regenerate the GraphModule's code to reflect the modified graph
+    m.recompile()
+
+    return m
+
+
+# Example call
+wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)
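+
+# Quick illustrative check: the wrapped module should now apply LeakyReLU to
+# the concatenated output. Exact equality holds because both sides run the
+# same ops with the default negative slope.
+x, y = torch.randn(3), torch.randn(3)
+assert torch.equal(traced(x, y), torch.nn.LeakyReLU()(torch.cat([x, y])))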