From d2dc957b3f0938cbab125b43f1f1a7ed334b4239 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Mon, 18 Mar 2024 23:48:29 -0700 Subject: [PATCH] Fixed a couple of dynamic shape detection issues (#996) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/996 This diff fixed a couple of issues for dynamic shape detection: * handled cases where sample tensors may not have any dynamic dimension * added two lowering configs, exposed as parameters of find_batch_size_dim, to guide dynamic shape detection (1) can_non_first_dim_be_dynamic: specifies if any dimension other than the first can be dynamic (2) can_dim_value_one_be_dynamic: specifies if dimension value one is allowed to appear at any dynamic dimension Reviewed By: hl475 Differential Revision: D54831007 fbshipit-source-id: cc17db2dd386748fec5e4d491fc58ae975b16b7f --- fx2ait/fx2ait/tensor_spec.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/fx2ait/fx2ait/tensor_spec.py b/fx2ait/fx2ait/tensor_spec.py index 1db415571..3c4703594 100644 --- a/fx2ait/fx2ait/tensor_spec.py +++ b/fx2ait/fx2ait/tensor_spec.py @@ -470,7 +470,12 @@ def from_input_list_with_batch_size_jagged_tensor( @classmethod # pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any` - def find_batch_size_dim(cls, inputs: Any) -> []: + def find_batch_size_dim( + cls, + inputs: Any, + can_non_first_dim_be_dynamic: bool = True, + can_dim_value_one_be_dynamic: bool = True, + ) -> []: if isinstance(inputs, torch.Tensor) or len(inputs) <= 1: return [0] shapes = [i.shape for i in inputs] @@ -484,7 +489,9 @@ def find_batch_size_dim(cls, inputs: Any) -> []: # Dedup shape value for single tensor first_dims.add(shape[0]) seen_dims = set() - for i, dim in enumerate(shape): + valid_len = len(shape) if can_non_first_dim_be_dynamic else 1 + for i in range(valid_len): + dim = shape[i] if dim not in seen_dims: frequency_map[dim] = frequency_map.get(dim, 0) + 1 position_scores[dim] = position_scores.get(dim, 0) + i @@ -501,7 +508,18 @@ def 
find_batch_size_dim(cls, inputs: Any) -> []: frequency_map.items(), key=lambda x: (-x[1], position_scores[x[0]]), ) - batch_size = sorted_frequency[0][0] + if len(sorted_frequency) > 1: + if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1: + # It's often that dim value one indicates a non-dynamic dimension. + # If the user says so, we pick the second most frequent value. + batch_size = sorted_frequency[1][0] + else: + batch_size = sorted_frequency[0][0] + else: + if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1: + batch_size = -1 + else: + batch_size = sorted_frequency[0][0] else: # no dims to sort: no batch_size batch_size = -1 @@ -511,6 +529,8 @@ def find_batch_size_dim(cls, inputs: Any) -> []: # Default batch size dim = -1, indicate no batch_size dim = -1 for index, val in enumerate(i.shape): + if not can_non_first_dim_be_dynamic and index > 0: + break if val == batch_size: dim = index break