Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export control-flow Ops in ppe.onnx.export #648

Merged
merged 17 commits on
Mar 8, 2023
1 change: 1 addition & 0 deletions pytorch_pfn_extras/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pytorch_pfn_extras.onnx._grad import grad # NOQA
from pytorch_pfn_extras.onnx.load import load_model # NOQA
from pytorch_pfn_extras.onnx._helper import no_grad # NOQA
import pytorch_pfn_extras.onnx._lax as lax # NOQA
available = True
except ImportError:
available = False
29 changes: 28 additions & 1 deletion pytorch_pfn_extras/onnx/_as_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,39 @@ def trace(
_outputs.outputs = None


def as_output(name: str, value: torch.Tensor) -> torch.Tensor:
# Add Identity function to prevent constant folding in torch.onnx
class _ExplicitIdentity(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx: Any,
x: torch.Tensor,
) -> torch.Tensor:
return x.clone()

@staticmethod
def backward( # type: ignore
ctx: Any,
dx: torch.Tensor,
) -> torch.Tensor:
return dx

@staticmethod
def symbolic(g, x): # type: ignore
return g.op("Identity", x)


def as_output(
    name: str, value: torch.Tensor, add_identity: bool = True
) -> torch.Tensor:
    """Register ``value`` as a named extra output of the traced ONNX graph.

    Outside of an active trace (no ``_outputs.outputs`` registry) this is a
    no-op that returns ``value`` unchanged.  When ``add_identity`` is true,
    the tensor is wrapped in explicit Identity ops so torch.onnx cannot
    constant-fold it away.

    Args:
        name: output name to register the tensor under.
        value: tensor to expose as a graph output.
        add_identity: wrap ``value`` in ``_ExplicitIdentity`` (default True).

    Returns:
        The (possibly Identity-wrapped) tensor for further computation.
    """
    # TorchScript cannot register extra outputs; warn the user and bail out.
    if torch.jit.is_scripting():  # type: ignore[no-untyped-call]
        warnings.warn(
            '`as_output` seen in TorchScript compilation. The value is no '
            'longer an output in the exported onnx.')
        return value
    if getattr(_outputs, "outputs", None) is None:
        # Not inside a trace: nothing to record.
        return value
    if add_identity:
        # First Identity keeps the registered output node from being folded.
        value = _ExplicitIdentity.apply(value)
    _outputs.outputs.add(name, value)
    if add_identity:
        # Second Identity shields the tensor handed back to the caller too.
        value = _ExplicitIdentity.apply(value)
    return value
2 changes: 1 addition & 1 deletion pytorch_pfn_extras/onnx/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def grad(
for i, input in enumerate(inputs):
input_name = f"Gradient_x_{i}_{n_grad_call}"
input_names.append(input_name)
inputs_l[i] = as_output(input_name, input)
inputs_l[i] = as_output(input_name, input, add_identity=False)

class _Gradient(torch.autograd.Function):
@staticmethod
Expand Down
Loading