[FX] Changes done internally at Facebook #1204

Merged 1 commit on Jul 25, 2022
7 changes: 1 addition & 6 deletions py/torch_tensorrt/fx/input_tensor_spec.py
@@ -7,12 +7,7 @@


 def generate_input_specs(inputs, lower_setting, additional_inputs=None):
-    # AIT lower setting doesn't have explicit_batch_dimension field and
-    # we just return None.
-    if not hasattr(lower_setting, "explicit_batch_dimension"):
-        return None
-
-    # dynamic_batch is TRT only flag. It does not exist in AIT lower setting
+    # dynamic_batch is TRT only flag.
     if (
         not lower_setting.explicit_batch_dimension
         or lower_setting.dynamic_batch is False
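Reviewer note: with the AIT guard removed, `generate_input_specs` now assumes a TRT-style lower setting that always carries `explicit_batch_dimension` and `dynamic_batch`. A minimal sketch of a call under that assumption (the `LowerSetting` import path and constructor fields here are assumptions, not confirmed by this diff):

```python
import torch

from torch_tensorrt.fx.input_tensor_spec import generate_input_specs
from torch_tensorrt.fx.lower_setting import LowerSetting  # assumed import path

# Assumed fields: a TRT LowerSetting carrying both TRT-only flags.
setting = LowerSetting(explicit_batch_dimension=True, dynamic_batch=False)
specs = generate_input_specs([torch.randn(1, 3, 224, 224)], setting)
```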
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/lower.py
@@ -232,7 +232,7 @@ def __call__(
             x.half() if x is not None and x.dtype == torch.float32 else x
             for x in inputs
         )
-        pm = self.lower_pass_manager_builder.build_lower_pipeline(
+        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
             inputs, additional_inputs
         )

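For context, the unchanged lines above this call site downcast fp32 inputs to fp16 before the pipeline is built. A standalone, runnable sketch of that conversion expression (the sample tensors are made up):

```python
import torch

inputs = (torch.randn(2, 3), None, torch.ones(2, 3, dtype=torch.int64))

# Mirrors the context above: downcast fp32 tensors to fp16, pass everything
# else (None placeholders, non-fp32 tensors) through unchanged.
inputs = tuple(
    x.half() if x is not None and x.dtype == torch.float32 else x
    for x in inputs
)
assert inputs[0].dtype == torch.float16 and inputs[2].dtype == torch.int64
```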
50 changes: 47 additions & 3 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -121,7 +121,7 @@ def _split_pass(self) -> PassManager:
         )
         return PassManager.build_from_passlist(passes)

-    def _lower_pass(self) -> PassManager:
+    def _trt_lower_pass(self) -> PassManager:
         def lower_func(split_result: SplitResult) -> nn.Module:
             if (
                 hasattr(self.lower_setting, "explicit_batch_dimension")
@@ -169,7 +169,51 @@ def lower_func(split_result: SplitResult) -> nn.Module:

         return PassManager.build_from_passlist([lower_func])

-    def build_lower_pipeline(
+    def _default_lower_pass(self) -> PassManager:
+        def lower_func(split_result: SplitResult) -> nn.Module:
+
+            for submod_name, submod_inputs in split_result.submodule_inputs.items():
+                submod = getattr(split_result.split_module, submod_name)
+
+                LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
+
+                # Only acc submodules will be lowered.
+                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    print("Now lowering submodule", submod_name)
+                    lowering_start_time = datetime.datetime.now()
+
+                    lowered_module = self._lower_func(
+                        submod, submod_inputs, self.lower_setting, submod_name
+                    )
+                    setattr(split_result.split_module, submod_name, lowered_module)
+                    LOWER_SPLIT_POST_OBSERVER.observe(
+                        submod_name, lowered_module, submod_inputs
+                    )
+                    print(
+                        f"Lowering submodule {submod_name} elapsed time",
+                        datetime.datetime.now() - lowering_start_time,
+                    )
+
+            return split_result.split_module
+
+        return PassManager.build_from_passlist([lower_func])
+
+    def build_trt_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
+        self._input = input
+        self._additional_input = additional_input
+        passes = []
+
+        passes.append(self._const_fold_pass())
+        passes.append(self.graph_optimization_pass())
+        passes.append(self._split_pass())
+        passes.append(self._trt_lower_pass())
+
+        pm = PassManager.build_from_passlist(passes)
+        return pm
+
+    def build_default_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
         self._input = input
@@ -179,7 +223,7 @@ def build_lower_pipeline(
         passes.append(self._const_fold_pass())
         passes.append(self.graph_optimization_pass())
         passes.append(self._split_pass())
-        passes.append(self._lower_pass())
+        passes.append(self._default_lower_pass())

         pm = PassManager.build_from_passlist(passes)
         return pm
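The two `build_*_lower_pipeline` methods now share the const-fold, graph-optimization, and split passes and differ only in the final lowering pass they append. A self-contained sketch of that shape, using placeholder passes rather than the real ones (the `PassManager` import path matches the one this file tree appears to use; treat it as an assumption):

```python
from torch.fx.passes.pass_manager import PassManager

def const_fold(mod): return mod      # placeholder for _const_fold_pass
def graph_opt(mod): return mod       # placeholder for graph_optimization_pass
def split(mod): return mod           # placeholder for _split_pass
def trt_lower(mod): return mod       # placeholder for _trt_lower_pass
def default_lower(mod): return mod   # placeholder for _default_lower_pass

# Same front half of the pipeline, different tail.
shared = [const_fold, graph_opt, split]
pm_trt = PassManager.build_from_passlist(shared + [trt_lower])
pm_default = PassManager.build_from_passlist(shared + [default_lower])
```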
5 changes: 2 additions & 3 deletions py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
@@ -528,7 +528,6 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     return cat_node


-@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
 @register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
 @register_acc_op_mapping(op_and_target=("call_method", "clamp"))
 @register_acc_op
@@ -1743,7 +1742,7 @@ def quantized_conv2d(
     dilation,
     groups,
     padding_mode,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.nn.quantized.functional.conv2d(
@@ -2041,7 +2040,7 @@ def quantized_batch_norm2d(
     weight,
     bias,
     eps,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.ops.quantized.batch_norm2d(
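Note that both quantized ops gain a default of `None` while their bodies still read `acc_out_ty.qparams` unconditionally, so the new default relaxes the signature but not the actual requirement. A toy illustration of that contract (all names below are made up, not the acc_ops code):

```python
from dataclasses import dataclass

@dataclass
class FakeOutTy:  # hypothetical stand-in for the acc_out_ty metadata object
    qparams: dict

def quantized_op(x, acc_out_ty=None):
    # The default makes the argument optional at call time, but the body
    # still requires it: calling with acc_out_ty=None raises AttributeError.
    return x, acc_out_ty.qparams

quantized_op(1.0, acc_out_ty=FakeOutTy(qparams={"scale": 0.1, "zero_point": 0}))
```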