forked from pytorch/audio
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request pytorch#882 from jamesr66a/fx
[FX] Move examples from pytorch/pytorch
- Loading branch information
Showing
6 changed files
with
400 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
from torch.fx import Proxy, symbolic_trace | ||
from torch.fx.node import map_arg | ||
|
||
|
||
''' | ||
How to Inline a Function Into an Existing Graph | ||
One reason you might want to inline a function is to get around FX's | ||
default tracing behavior. For example, unless you've defined a custom | ||
Tracer, the out-of-the-box implementation of ``symbolic_trace`` causes | ||
references to ``torch.nn`` module instances to appear as | ||
``call_module`` calls rather than being traced through. Let's say this | ||
behavior is almost what you need; the only problem is that there's a | ||
single module call that you want to replace with an inlined trace of the | ||
function. Creating a custom Tracer would be too much. Instead, you can | ||
accomplish this using Proxies. | ||
The following code demonstrates how to trace a module and inline it | ||
into an existing Graph using Proxy. We'll trace our Graph, then iterate | ||
through its Nodes until we find the right place to swap out the | ||
``call_module`` Node with an inlined trace. At that point, we'll create | ||
Proxies from the Node's args and kwargs. Finally, we'll call the | ||
function we want to replace with those Proxies--which will, in essence, | ||
"trace" that function. Finally, we'll insert the result of that call | ||
into our Graph. (This last step will automatically inline the function.) | ||
''' | ||
|
||
|
||
# Sample module | ||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.relu = torch.nn.ReLU() | ||
|
||
def forward(self, x): | ||
return self.relu(x) + 1.0 | ||
|
||
# Symbolically trace an instance of `M`. After tracing, `self.relu` is | ||
# represented as a `call_module` Node. The full operation in the | ||
# generated `forward` function's code will appear as `self.relu(x)` | ||
m = symbolic_trace(M()) | ||
|
||
# Insert nodes from the ReLU graph in place of the original call to | ||
# `self.relu` | ||
for node in m.graph.nodes: | ||
# Find `call_module` Node in `m` that corresponds to `self.relu`. | ||
# This is the Node we want to swap out for an inlined version of the | ||
# same call | ||
if (node.op, node.target) == ("call_module", "relu"): | ||
with m.graph.inserting_before(node): | ||
# Create a Proxy from each Node in the current Node's | ||
# args/kwargs | ||
proxy_args = map_arg(node.args, Proxy) | ||
proxy_kwargs = map_arg(node.kwargs, Proxy) | ||
# Call `m.relu` with the newly-created Proxy arguments. | ||
# `m.relu` is the generic version of the function; by | ||
# calling it with Proxies created from Nodes in `m`, we're | ||
# emitting Nodes that reference exiting values in the IR. | ||
# The result of this call is another Proxy, which we can | ||
# hook into our existing Graph to complete the function | ||
# inlining. | ||
proxy_output = m.relu(*proxy_args, **proxy_kwargs) | ||
# Replace the relu `call_module` node with the inlined | ||
# version of the function | ||
node.replace_all_uses_with(proxy_output.node) | ||
# Make sure that the old relu Node is erased | ||
m.graph.erase_node(node) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import torch | ||
import torch.fx as fx | ||
|
||
# An inverse mapping is one that takes a function f(x) and returns a function g | ||
# such that f(g(x)) == x. For example,since log(exp(x)) == x, exp and log are | ||
# inverses. | ||
|
||
invert_mapping = {} | ||
def add_inverse(a, b): | ||
invert_mapping[a] = b | ||
invert_mapping[b] = a | ||
inverses = [ | ||
(torch.sin, torch.arcsin), | ||
(torch.cos, torch.arccos), | ||
(torch.tan, torch.arctan), | ||
(torch.exp, torch.log), | ||
] | ||
for a, b in inverses: | ||
add_inverse(a, b) | ||
|
||
# The general strategy is that we walk the graph backwards, transforming each | ||
# node into its inverse. To do so, we swap the outputs and inputs of the | ||
# functions, and then we look up its inverse in `invert_mapping`. Note that | ||
# this transform assumes that all operations take in only one input and return | ||
# one output. | ||
def invert(model: torch.nn.Module) -> torch.nn.Module: | ||
fx_model = fx.symbolic_trace(model) | ||
new_graph = fx.Graph() # As we're building up a new graph | ||
env = {} | ||
for node in reversed(fx_model.graph.nodes): | ||
if node.op == 'call_function': | ||
# This creates a node in the new graph with the inverse function, | ||
# and passes `env[node.name]` (i.e. the previous output node) as | ||
# input. | ||
new_node = new_graph.call_function(invert_mapping[node.target], (env[node.name],)) | ||
env[node.args[0].name] = new_node | ||
elif node.op == 'output': | ||
# We turn the output into an input placeholder | ||
new_node = new_graph.placeholder(node.name) | ||
env[node.args[0].name] = new_node | ||
elif node.op == 'placeholder': | ||
# We turn the input placeholder into an output | ||
new_graph.output(env[node.name]) | ||
else: | ||
raise RuntimeError("Not implemented") | ||
|
||
new_graph.lint() | ||
return fx.GraphModule(fx_model, new_graph) | ||
|
||
|
||
def f(x): | ||
return torch.exp(torch.tan(x)) | ||
|
||
res = invert(f) | ||
print(res.code) | ||
""" | ||
def forward(self, output): | ||
log_1 = torch.log(output); output = None | ||
arctan_1 = torch.arctan(log_1); log_1 = None | ||
return arctan_1 | ||
""" | ||
print(f(res((torch.arange(5) + 1)))) # [1., 2., 3., 4, 5.] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
from torch.fx import Proxy, Graph, GraphModule | ||
|
||
|
||
''' | ||
How to Create a Graph Using Proxy Objects Instead of Tracing | ||
It's possible to directly create a Proxy object around a raw Node. This | ||
can be used to create a Graph independently of symbolic tracing. | ||
The following code demonstrates how to use Proxy with a raw Node to | ||
append operations to a fresh Graph. We'll create two parameters (``x`` | ||
and ``y``), perform some operations on those parameters, then add | ||
everything we created to the new Graph. We'll then wrap that Graph in | ||
a GraphModule. Doing so creates a runnable instance of ``nn.Module`` | ||
where previously-created operations are represented in the Module's | ||
``forward`` function. | ||
By the end of the tutorial, we'll have added the following method to an | ||
empty ``nn.Module`` class. | ||
.. code-block:: python | ||
def forward(self, x, y): | ||
cat_1 = torch.cat([x, y]); x = y = None | ||
tanh_1 = torch.tanh(cat_1); cat_1 = None | ||
neg_1 = torch.neg(tanh_1); tanh_1 = None | ||
return neg_1 | ||
''' | ||
|
||
|
||
# Create a graph independently of symbolic tracing | ||
graph = Graph() | ||
|
||
# Create raw Nodes | ||
raw1 = graph.placeholder('x') | ||
raw2 = graph.placeholder('y') | ||
|
||
# Initialize Proxies using the raw Nodes | ||
y = Proxy(raw1) | ||
z = Proxy(raw2) | ||
|
||
# Create other operations using the Proxies `y` and `z` | ||
a = torch.cat([y, z]) | ||
b = torch.tanh(a) | ||
c = torch.neg(b) | ||
|
||
# Create a new output Node and add it to the Graph. By doing this, the | ||
# Graph will contain all the Nodes we just created (since they're all | ||
# linked to the output Node) | ||
graph.output(c.node) | ||
|
||
# Wrap our created Graph in a GraphModule to get a final, runnable | ||
# `nn.Module` instance | ||
mod = GraphModule(torch.nn.Module(), graph) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
from torch.fx import symbolic_trace | ||
import operator | ||
|
||
""" | ||
How to Replace One Op With Another | ||
1. Iterate through all Nodes in your GraphModule's Graph. | ||
2. Determine if the current Node should be replaced. (Suggested: match | ||
on the Node's ``target`` attribute). | ||
3. Create a replacement Node and add it to the Graph. | ||
4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of | ||
the current Node with the replacement. | ||
5. Delete the old Node from the graph. | ||
6. Call ``recompile`` on the GraphModule. This updates the generated | ||
Python code to reflect the new Graph state. | ||
Currently, FX does not provide any way to guarantee that replaced | ||
operators are syntactically valid. It's up to the user to confirm that | ||
any new operators will work with the existing operands. | ||
The following code demonstrates an example of replacing any instance of | ||
addition with a bitwise AND. | ||
To examine how the Graph evolves during op replacement, add the | ||
statement `print(traced.graph)` after the line you want to inspect. | ||
Alternatively, call `traced.graph.print_tabular()` to see the IR in a | ||
tabular format. | ||
""" | ||
|
||
# Sample module | ||
class M(torch.nn.Module): | ||
def forward(self, x, y): | ||
return x + y, torch.add(x, y), x.add(y) | ||
|
||
# Symbolically trace an instance of the module | ||
traced = symbolic_trace(M()) | ||
|
||
# As demonstrated in the above example, there are several different ways | ||
# to denote addition. The possible cases are: | ||
# 1. `x + y` - A `call_function` Node with target `operator.add`. | ||
# We can match for equality on that `operator.add` directly. | ||
# 2. `torch.add(x, y)` - A `call_function` Node with target | ||
# `torch.add`. Similarly, we can match this function directly. | ||
# 3. `x.add(y)` - The Tensor method call, whose target we can match | ||
# as a string. | ||
|
||
patterns = set([operator.add, torch.add, "add"]) | ||
|
||
# Go through all the nodes in the Graph | ||
for n in traced.graph.nodes: | ||
# If the target matches one of the patterns | ||
if any(n.target == pattern for pattern in patterns): | ||
# Set the insert point, add the new node, and replace all uses | ||
# of `n` with the new node | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
# Remove the old node from the graph | ||
traced.graph.erase_node(n) | ||
|
||
# Don't forget to recompile! | ||
traced.recompile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import torch | ||
from torch.fx import symbolic_trace, replace_pattern | ||
|
||
|
||
''' | ||
How to Use the FX Subgraph Rewriter | ||
For easy subgraph rewriting, FX exposes the utility function: | ||
replace_pattern(gm : GraphModule, | ||
pattern : Callable, | ||
replacement : Callable) | ||
-> None | ||
`replace_pattern` matches all possible non-overlapping sets of operators | ||
and their data dependencies (`pattern`) in the Graph of a GraphModule | ||
(`gm`), then replaces each of these matched subgraphs with another | ||
subgraph (`replacement). | ||
The docstring for `replace_pattern` (located in `subgraph_rewriter.py`) | ||
gives an in-depth explanation as to how `pattern` and `replacement` | ||
should be specified, what happens during pattern matching, and other | ||
important technical details. This tutorial, therefore, is only meant to | ||
give an overview as to the FX Subgraph Rewriter's basic functionality. | ||
Let's go rewrite a Graph! | ||
''' | ||
|
||
# Sample module | ||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, w1, w2): | ||
val1 = torch.neg(w1) | ||
m1 = torch.cat([val1, w2]).sum() | ||
val2 = torch.neg(w1) | ||
m2 = torch.cat([val2, w2]).sum() | ||
return x + torch.max(m1) + torch.max(m2) | ||
|
||
# Symbolically trace an instance of `M` | ||
traced = symbolic_trace(M()) | ||
|
||
# Define the pattern. The FX Subgraph Rewriter will match all | ||
# non-overlapping instances of the pattern in the larger graph. | ||
# Note that Pattern-matching is done based on data dependencies, | ||
# not Node names. Even though we're operating on Nodes named `a1` and | ||
# `a2` instead of `w1` and `w2`, the pattern is still a valid match | ||
# for the two instances of `torch.cat([w1, w2]).sum()` above. Only | ||
# operations that contribute to the single output value of the pattern | ||
# are considered | ||
def pattern(a1, a2): | ||
val1 = torch.neg(a1) | ||
return torch.cat([val1, a2]).sum() | ||
|
||
# Define the replacement (same rules as the pattern) | ||
def replacement(w1, w2): | ||
return torch.stack([w1, w2]) | ||
|
||
# Replace `pattern` with `replacement` in `traced` | ||
replace_pattern(traced, pattern, replacement) | ||
|
||
# After calling `replace_pattern`, the generated code is: | ||
''' | ||
def forward(self, x, w1, w2): | ||
stack_1 = torch.stack([w1, w2]) | ||
sum_1 = stack_1.sum() | ||
stack_2 = torch.stack([w1, w2]) | ||
sum_2 = stack_2.sum() | ||
max_1 = torch.max(sum_1) | ||
add_1 = x + max_1 | ||
max_2 = torch.max(sum_2) | ||
add_2 = add_1 + max_2 | ||
return add_2 | ||
''' |
Oops, something went wrong.