
Commit

Merge pull request pytorch#882 from jamesr66a/fx
[FX] Move examples from pytorch/pytorch
James Reed authored Feb 4, 2021
2 parents 0bb43c8 + 94d4e94 commit 6e6e0d4
Showing 6 changed files with 400 additions and 0 deletions.
68 changes: 68 additions & 0 deletions fx/inline_function.py
@@ -0,0 +1,68 @@
import torch
from torch.fx import Proxy, symbolic_trace
from torch.fx.node import map_arg


'''
How to Inline a Function Into an Existing Graph
One reason you might want to inline a function is to get around FX's
default tracing behavior. For example, unless you've defined a custom
Tracer, the out-of-the-box implementation of ``symbolic_trace`` causes
references to ``torch.nn`` module instances to appear as
``call_module`` calls rather than being traced through. Let's say this
behavior is almost what you need; the only problem is that there's a
single module call that you want to replace with an inlined trace of the
function. Creating a custom Tracer for that would be overkill. Instead,
you can accomplish this using Proxies.
The following code demonstrates how to trace a module and inline it
into an existing Graph using Proxy. We'll trace our Graph, then iterate
through its Nodes until we find the right place to swap out the
``call_module`` Node with an inlined trace. At that point, we'll create
Proxies from the Node's args and kwargs. Next, we'll call the function
we want to replace with those Proxies, which will, in essence, "trace"
that function. Finally, we'll insert the result of that call into our
Graph. (This last step will automatically inline the function.)
'''


# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x) + 1.0

# Symbolically trace an instance of `M`. After tracing, `self.relu` is
# represented as a `call_module` Node. The full operation in the
# generated `forward` function's code will appear as `self.relu(x)`
m = symbolic_trace(M())

# Insert nodes from the ReLU graph in place of the original call to
# `self.relu`
for node in m.graph.nodes:
    # Find the `call_module` Node in `m` that corresponds to `self.relu`.
    # This is the Node we want to swap out for an inlined version of the
    # same call
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # Create a Proxy from each Node in the current Node's
            # args/kwargs
            proxy_args = map_arg(node.args, Proxy)
            proxy_kwargs = map_arg(node.kwargs, Proxy)
            # Call `m.relu` with the newly-created Proxy arguments.
            # `m.relu` is the generic version of the function; by
            # calling it with Proxies created from Nodes in `m`, we're
            # emitting Nodes that reference existing values in the IR.
            # The result of this call is another Proxy, which we can
            # hook into our existing Graph to complete the function
            # inlining.
            proxy_output = m.relu(*proxy_args, **proxy_kwargs)
            # Replace the relu `call_module` Node with the inlined
            # version of the function
            node.replace_all_uses_with(proxy_output.node)
            # Make sure that the old relu Node is erased
            m.graph.erase_node(node)
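
# A quick sanity check (a sketch beyond the original example, assuming the
# loop above found and inlined the `relu` call): recompile so the generated
# `forward` reflects the edited Graph, then compare against eager execution.
m.recompile()
x = torch.randn(4)
assert torch.equal(m(x), torch.relu(x) + 1.0)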
62 changes: 62 additions & 0 deletions fx/invert.py
@@ -0,0 +1,62 @@
import torch
import torch.fx as fx

# An inverse mapping is one that takes a function f(x) and returns a function g
# such that f(g(x)) == x. For example, since log(exp(x)) == x, exp and log are
# inverses.

invert_mapping = {}
def add_inverse(a, b):
    invert_mapping[a] = b
    invert_mapping[b] = a

inverses = [
    (torch.sin, torch.arcsin),
    (torch.cos, torch.arccos),
    (torch.tan, torch.arctan),
    (torch.exp, torch.log),
]
for a, b in inverses:
    add_inverse(a, b)

# The general strategy is that we walk the graph backwards, transforming each
# node into its inverse. To do so, we swap the outputs and inputs of the
# functions, and then we look up its inverse in `invert_mapping`. Note that
# this transform assumes that all operations take in only one input and return
# one output.
def invert(model: torch.nn.Module) -> torch.nn.Module:
    fx_model = fx.symbolic_trace(model)
    new_graph = fx.Graph()  # As we're building up a new graph
    env = {}
    for node in reversed(fx_model.graph.nodes):
        if node.op == 'call_function':
            # This creates a node in the new graph with the inverse function,
            # and passes `env[node.name]` (i.e. the previous output node) as
            # input.
            new_node = new_graph.call_function(invert_mapping[node.target], (env[node.name],))
            env[node.args[0].name] = new_node
        elif node.op == 'output':
            # We turn the output into an input placeholder
            new_node = new_graph.placeholder(node.name)
            env[node.args[0].name] = new_node
        elif node.op == 'placeholder':
            # We turn the input placeholder into an output
            new_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")

    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)


def f(x):
    return torch.exp(torch.tan(x))

res = invert(f)
print(res.code)
"""
def forward(self, output):
log_1 = torch.log(output); output = None
arctan_1 = torch.arctan(log_1); log_1 = None
return arctan_1
"""
print(f(res(torch.arange(5) + 1)))  # tensor([1., 2., 3., 4., 5.])
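
# A minimal round-trip check (a sketch beyond the original example): for
# inputs where every op involved is invertible, `res` undoes `f`.
x = torch.rand(5)  # values in (0, 1) keep tan/arctan and exp/log well-behaved
assert torch.allclose(res(f(x)), x, atol=1e-5)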
56 changes: 56 additions & 0 deletions fx/proxy_based_graph_creation.py
@@ -0,0 +1,56 @@
import torch
from torch.fx import Proxy, Graph, GraphModule


'''
How to Create a Graph Using Proxy Objects Instead of Tracing
It's possible to directly create a Proxy object around a raw Node. This
can be used to create a Graph independently of symbolic tracing.
The following code demonstrates how to use Proxy with a raw Node to
append operations to a fresh Graph. We'll create two parameters (``x``
and ``y``), perform some operations on those parameters, then add
everything we created to the new Graph. We'll then wrap that Graph in
a GraphModule. Doing so creates a runnable instance of ``nn.Module``
where previously-created operations are represented in the Module's
``forward`` function.
By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class.
.. code-block:: python

    def forward(self, x, y):
        cat_1 = torch.cat([x, y]); x = y = None
        tanh_1 = torch.tanh(cat_1); cat_1 = None
        neg_1 = torch.neg(tanh_1); tanh_1 = None
        return neg_1
'''


# Create a graph independently of symbolic tracing
graph = Graph()

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes
y = Proxy(raw1)
z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)
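
# A quick usage check (a sketch beyond the original example): the
# GraphModule is a runnable `nn.Module`, so we can call it directly and
# compare against the equivalent eager computation.
inp1, inp2 = torch.randn(3), torch.randn(3)
assert torch.equal(mod(inp1, inp2), torch.neg(torch.tanh(torch.cat([inp1, inp2]))))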
63 changes: 63 additions & 0 deletions fx/replace_op.py
@@ -0,0 +1,63 @@
import torch
from torch.fx import symbolic_trace
import operator

"""
How to Replace One Op With Another
1. Iterate through all Nodes in your GraphModule's Graph.
2. Determine if the current Node should be replaced. (Suggested: match
   on the Node's ``target`` attribute).
3. Create a replacement Node and add it to the Graph.
4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of
   the current Node with the replacement.
5. Delete the old Node from the graph.
6. Call ``recompile`` on the GraphModule. This updates the generated
   Python code to reflect the new Graph state.
Currently, FX does not provide any way to guarantee that replaced
operators are syntactically valid. It's up to the user to confirm that
any new operators will work with the existing operands.
The following code demonstrates an example of replacing any instance of
addition with a bitwise AND.
To examine how the Graph evolves during op replacement, add the
statement `print(traced.graph)` after the line you want to inspect.
Alternatively, call `traced.graph.print_tabular()` to see the IR in a
tabular format.
"""

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)

# Symbolically trace an instance of the module
traced = symbolic_trace(M())

# As demonstrated in the above example, there are several different ways
# to denote addition. The possible cases are:
#     1. `x + y` - A `call_function` Node with target `operator.add`.
#        We can match for equality on that `operator.add` directly.
#     2. `torch.add(x, y)` - A `call_function` Node with target
#        `torch.add`. Similarly, we can match this function directly.
#     3. `x.add(y)` - A `call_method` Node on the Tensor, whose target
#        is the string "add", which we can match against.

patterns = set([operator.add, torch.add, "add"])

# Go through all the nodes in the Graph
for n in traced.graph.nodes:
    # If the target matches one of the patterns
    if any(n.target == pattern for pattern in patterns):
        # Set the insert point, add the new node, and replace all uses
        # of `n` with the new node
        with traced.graph.inserting_after(n):
            new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
            n.replace_all_uses_with(new_node)
        # Remove the old node from the graph
        traced.graph.erase_node(n)

# Don't forget to recompile!
traced.recompile()
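
# A small demonstration (a sketch beyond the original example): with
# integer inputs, every former addition now computes a bitwise AND.
a, b = torch.tensor([6]), torch.tensor([3])
print(traced(a, b))  # each element is 6 & 3 == 2 rather than 6 + 3 == 9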
74 changes: 74 additions & 0 deletions fx/subgraph_rewriter_basic_use.py
@@ -0,0 +1,74 @@
import torch
from torch.fx import symbolic_trace, replace_pattern


'''
How to Use the FX Subgraph Rewriter
For easy subgraph rewriting, FX exposes the utility function:

    replace_pattern(gm : GraphModule,
                    pattern : Callable,
                    replacement : Callable) -> None

`replace_pattern` matches all possible non-overlapping sets of operators
and their data dependencies (`pattern`) in the Graph of a GraphModule
(`gm`), then replaces each of these matched subgraphs with another
subgraph (`replacement`).
The docstring for `replace_pattern` (located in `subgraph_rewriter.py`)
gives an in-depth explanation of how `pattern` and `replacement`
should be specified, what happens during pattern matching, and other
important technical details. This tutorial, therefore, is only meant to
give an overview of the FX Subgraph Rewriter's basic functionality.
Let's go rewrite a Graph!
'''

# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        val1 = torch.neg(w1)
        m1 = torch.cat([val1, w2]).sum()
        val2 = torch.neg(w1)
        m2 = torch.cat([val2, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

# Symbolically trace an instance of `M`
traced = symbolic_trace(M())

# Define the pattern. The FX Subgraph Rewriter will match all
# non-overlapping instances of the pattern in the larger graph.
# Note that pattern-matching is done based on data dependencies,
# not Node names. Even though we're operating on Nodes named `a1` and
# `a2` instead of `w1` and `w2`, the pattern is still a valid match
# for the two `torch.cat(...).sum()` instances above. Only
# operations that contribute to the single output value of the pattern
# are considered.
def pattern(a1, a2):
    val1 = torch.neg(a1)
    return torch.cat([val1, a2]).sum()

# Define the replacement (same rules as the pattern)
def replacement(w1, w2):
    return torch.stack([w1, w2])

# Replace `pattern` with `replacement` in `traced`
replace_pattern(traced, pattern, replacement)

# After calling `replace_pattern`, the generated code is:
'''
def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2
'''
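
# To inspect the rewritten module yourself (a quick check beyond the
# original example), print the regenerated code or the tabular IR:
print(traced.code)
traced.graph.print_tabular()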
