diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index bcbbc04ca9..1705dd06db 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -231,26 +231,7 @@ def aten_ops_cat(
     )


-def embedding_param_validator(embedding_node: Node) -> bool:
-    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
-    sparse = args_bounds_check(embedding_node.args, 4)
-
-    if scale_grad_by_freq is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}."
-        )
-        return False
-
-    if sparse is not None:
-        _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.")
-        return False
-
-    return True
-
-
-@dynamo_tensorrt_converter(
-    torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
-)
+@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
 def aten_ops_embedding(
     ctx: ConversionContext,
     target: Target,
@@ -265,22 +246,19 @@ def aten_ops_embedding(
         name,
         input=args[1],
         weight=args[0],
-        # args[2] is the padding index, which is useful for training only
-        scale_grad_by_freq=args_bounds_check(args, 3),
-        sparse=args_bounds_check(args, 4),
     )


 def embedding_bag_validator(node: Node) -> bool:
-    mode = args_bounds_check(node.args, 4, 0)
-    indices = node.args[1].meta.get("tensor_meta")
+    if not one_user_validator(node):
+        return False
+    meta = node.args[1].meta
+    indices = meta.get("tensor_meta")
+    if indices is None:
+        indices = meta.get("val")
     if indices is None:
         return False
-    return (
-        bool(node.args[2].op == "get_attr")
-        and (mode == 0 or mode == 1 or mode == 2)
-        and len(indices.shape) == 1
-    )
+    return len(indices.shape) == 1  # currently only 1D indices are supported


 @dynamo_tensorrt_converter(
@@ -293,7 +271,6 @@ def embedding_bag_validator(node: Node) -> bool:
     {
         0: (TRTTensor,),
         1: (TRTTensor,),
-        2: (np.ndarray, torch.Tensor),
     }
 )
 def aten_ops_embedding_bag(
@@ -311,12 +288,9 @@ def aten_ops_embedding_bag(
         weight=args[0],
         indices=args[1],
         offsets=args[2],
-        scale_grad_by_freq=args_bounds_check(args, 3, False),
         mode=args_bounds_check(args, 4, 0),
-        sparse=args_bounds_check(args, 5, False),
         per_sample_weights=args_bounds_check(args, 6, None),
         include_last_offset=args_bounds_check(args, 7, False),
-        # padding index is useful for training only
     )


diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 8f11e7fb91..a263440128 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -5,6 +5,7 @@

 import numpy as np
 import torch
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Argument, Target
 from torch_tensorrt import _enums
@@ -530,3 +531,111 @@ def flatten_dims(
     new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])

     return new_shape
+
+
+def append(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    original_tensor: TRTTensor,
+    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
+    dim: int = 0,
+) -> TRTTensor:
+    """
+    Append a new value at the end of the original tensor along the specified dimension (default 0).
+    For example, if the original tensor is [1, 2, 3], the new value is 4, and the dim is 0,
+    the new tensor will be [1, 2, 3, 4].
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        target (Target): Target of calling node
+        source_ir (Optional[SourceIR]): SourceIR of calling converter
+        name (str): Name of the calling layer
+        original_tensor (TRTTensor): A TRTTensor to append the new value to
+        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to append
+        dim (int, optional): Dimension along which to append the new value. Defaults to 0.
+
+    Returns:
+        TRTTensor: A new TRTTensor that is the result of appending the new value to the original tensor
+    """
+    if isinstance(new_value, (int, float)):
+        new_value = np.array([new_value])
+    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)
+
+    return impl.cat.cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_concat",
+        [original_tensor, new_value],
+        get_positive_dim(dim, len(original_tensor.shape)),
+    )
+
+
+def set_item(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    original_tensor: TRTTensor,
+    index: int,
+    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
+) -> TRTTensor:
+    """
+    Set a new value to the original tensor at the specified index. For example,
+    if the original tensor is [1, 2, 3], the new value is 4, and the index is 1,
+    the new tensor will be [1, 4, 3].
+    If the index is out of bounds, the new value will be appended to the end.
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        target (Target): Target of calling node
+        source_ir (Optional[SourceIR]): SourceIR of calling converter
+        name (str): Name of the calling layer
+        original_tensor (TRTTensor): A TRTTensor to set the new value to
+        index (int): The index at which to set the new value
+        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to set
+
+    Returns:
+        TRTTensor: A new TRTTensor that is the result of setting the new value to the original tensor
+    """
+    if isinstance(new_value, (int, float)):
+        new_value = np.array([new_value])
+    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)
+
+    len_original_tensor = original_tensor.shape[0]
+    index = get_positive_dim(index, len_original_tensor)
+
+    front_tensor = impl.slice.slice_op(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_slice_front",
+        original_tensor,
+        dim=0,
+        start=0,
+        stop=index,
+        step=1,
+    )
+    rear_tensor = impl.slice.slice_op(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_slice_rear",
+        original_tensor,
+        dim=0,
+        start=index + 1,
+        stop=len_original_tensor,
+        step=1,
+    )
+
+    ans = impl.cat.cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_concat",
+        [front_tensor, new_value, rear_tensor],
+        0,
+    )
+    return ans
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
index ee3354ae08..f4e98ac3ee 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
@@ -1,16 +1,23 @@
 import functools
+import time
 from typing import Optional, Sequence, Tuple, Union

 import numpy as np
+import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    append,
+    cast_trt_tensor,
+    get_trt_tensor,
+    set_item,
+    to_numpy,
+)
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-
-import tensorrt as trt
+from torch_tensorrt.fx.types import TRTTensor


 def embedding(
@@ -18,25 +25,29 @@
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: trt.ITensor,
-    weight: trt.ITensor,
-    scale_grad_by_freq: bool,
-    sparse: bool,
-) -> trt.ITensor:
+    input: TRTTensor,
+    weight: TRTTensor,
+) -> TRTTensor:
     indices_tensor = input
     embedding_tensor = weight
+    if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
+        raise RuntimeError(
+            "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
+        )
     indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
     embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")

     # unsupported parameters
-    # ignore padding_idx since it is meaningful for training only
+    # ignore padding_idx, scale_grad_by_freq, and sparse
+    # since they are meaningful for training only

-    if scale_grad_by_freq:
-        raise RuntimeError(
-            "Currently we don't support scale gradient by word frequency."
-        )
+    # useful for training only
+    # if scale_grad_by_freq:
+    #     raise RuntimeError(
+    #         "Currently we don't support scale gradient by word frequency."
+    #     )

-    if sparse:
-        raise RuntimeError("Currently we don't support sparse gradient.")
+    # if sparse:
+    #     raise RuntimeError("Currently we don't support sparse gradient.")

     # Implement embedding lookup with gather layer
     gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0)
@@ -44,34 +55,16 @@
     return gather_layer.get_output(0)


-def embedding_bag(
+def embedding_bag_with_traversable_offsets(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    weight: trt.ITensor,
-    indices: trt.ITensor,
-    offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
-    scale_grad_by_freq: bool,
+    embed: TRTTensor,
+    offsets_list: Union[torch.Tensor, np.ndarray, Sequence[int]],
     mode: int,
-    sparse: bool,
-    per_sample_weights: Optional[trt.ITensor],
     include_last_offset: bool,
-) -> Tuple[trt.ITensor, trt.ITensor, trt.ITensor, trt.ITensor]:
-    """
-    This function is for calculating embedding bags.
-
-    In PyTorch, `offsets` is only used when input is 1D. If input is 2D of shape (B, N),
-    it will be treated as B bags (sequences) each of fixed length N, and this will return
-    B values aggregated in a way depending on the mode. `offsets` is ignored and required
-    to be None in this case.
-
-    However, according to the schema, `offsets` is required for input with any dimensions.
-    Accordingly, this function flattens N-D input to 1D and then to calculate embedding bags.
-    """
-
-    # TODO: support 2D inputs
-    # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))
+) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
     reduce_name = ""
     if mode == 0:  # sum
         reduce_op = functools.partial(
@@ -93,6 +86,270 @@ def embedding_bag(
         )
         reduce_name = "max"

+    offsets: np.ndarray = to_numpy(offsets_list)
+    len_embed = embed.shape[0]
+
+    if include_last_offset:
+        # modify the last index of offsets to the end index
+        # however, pytorch doc says if `include_last_offset` is True, the size of offsets
+        # is equal to the number of bags + 1. The last element is the size of the input,
+        # or the ending index position of the last bag (sequence).
+        offsets.itemset(-1, len_embed)
+    else:
+        # add the end index to offsets
+        offsets = np.append(offsets, len_embed)
+
+    zero_tensor = get_trt_tensor(
+        ctx, np.zeros((1, embed.shape[1]), dtype=np.float32), f"{name}_zero_tensor"
+    )
+
+    # separately reduce embeddings for different bags
+    reduced_embed_bags = []
+    len_offsets = offsets.shape[0]
+    for i in range(len_offsets - 1):
+        if offsets[i] < offsets[i + 1]:
+            sliced_embed = impl.slice.slice_op(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_slice_embed_{i}",
+                embed,
+                0,
+                int(offsets[i]),
+                int(offsets[i + 1]),
+                1,
+            )
+            reduced_one_bag = reduce_op(
+                name=f"{name}_{reduce_name}_{i}",
+                input_val=sliced_embed,
+                dim=0,
+                keepdim=True,
+            )
+            reduced_embed_bags.append(reduced_one_bag)
+        else:  # offsets[i] == offsets[i + 1]
+            reduced_embed_bags.append(zero_tensor)
+
+    out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed_bags, 0)
+    return out, None, None, None
+
+
+def embedding_bag_with_ITensor_offsets(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    embed: TRTTensor,
+    offsets: TRTTensor,
+    mode: int,
+    include_last_offset: bool,
+) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
+    len_embed = embed.shape[0]
+
+    if include_last_offset:
+        # modify the last index of offsets to the end index
+        # however, pytorch doc says if `include_last_offset` is True, the size of offsets
+        # is equal to the number of bags + 1. The last element is the size of the input,
+        # or the ending index position of the last bag (sequence).
+        offsets = set_item(
+            ctx, target, source_ir, f"{name}_set_item", offsets, -1, len_embed
+        )
+    else:
+        # add the end index to `offsets`
+        offsets = append(ctx, target, source_ir, f"{name}_append", offsets, len_embed)
+
+    # create a placeholder tensor, whose shape is the same as an embedding
+    # if mode is 0 (sum) or 1 (mean), the placeholder tensor is filled with zeros
+    # if mode is 2 (max), the placeholder tensor is filled with negative infinity
+    placeholder_tensor = (
+        get_trt_tensor(
+            ctx,
+            np.full(embed.shape, -np.inf, dtype=np.float32),
+            f"{name}_negative_inf_tensor",
+        )
+        if mode == 2
+        else get_trt_tensor(
+            ctx, np.zeros(embed.shape, dtype=np.float32), f"{name}_zero_tensors"
+        )
+    )
+
+    # prepare some tensors for future use
+    zero_tensor = get_trt_tensor(
+        ctx, np.zeros((embed.shape[1],), dtype=np.float32), f"{name}_zero_tensor"
+    )
+    constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0")
+    constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1")
+
+    # Use two for loops to calculate the embedding of each bag
+    ###### Outer loop: traverse offsets ######
+    loop1 = ctx.net.add_loop()
+    trip_limit1 = ctx.net.add_constant(
+        shape=(),
+        weights=trt.Weights(np.array([offsets.shape[0] - 1], dtype=np.dtype("i"))),
+    ).get_output(0)
+    loop1.add_trip_limit(trip_limit1, trt.TripLimit.COUNT)
+
+    rec1_i_tensor = loop1.add_recurrence(constant_1)
+    set_layer_name(rec1_i_tensor, target, f"{name}_rec1_i_tensor", source_ir)
+    i_tensor = rec1_i_tensor.get_output(0)
+
+    start = ctx.net.add_gather(offsets, constant_0, 0).get_output(0)
+    rec1_start = loop1.add_recurrence(start)
+    set_layer_name(rec1_start, target, f"{name}_rec1_start", source_ir)
+    start = rec1_start.get_output(0)
+
+    end = ctx.net.add_gather(offsets, constant_1, 0).get_output(0)
+    rec1_end = loop1.add_recurrence(end)
+    set_layer_name(rec1_end, target, f"{name}_rec1_end", source_ir)
+    end = rec1_end.get_output(0)
+
+    ###### Inner loop: traverse indices ######
+    loop2 = ctx.net.add_loop()
+    trip_limit2 = ctx.net.add_constant(
+        shape=(), weights=trt.Weights(np.array([len_embed], dtype=np.dtype("i")))
+    ).get_output(0)
+    loop2.add_trip_limit(trip_limit2, trt.TripLimit.COUNT)
+    rec2_j_tensor = loop2.add_recurrence(constant_0)
+    set_layer_name(rec2_j_tensor, target, f"{name}_rec2_j_tensor", source_ir)
+    j_tensor = rec2_j_tensor.get_output(0)
+
+    # create a TRT Select layer
+    cond1 = impl.elementwise.ge(
+        ctx, target, source_ir, f"{name}_ge_{time.time()}", j_tensor, start
+    )
+    cond2 = impl.elementwise.lt(
+        ctx, target, source_ir, f"{name}_lt_{time.time()}", j_tensor, end
+    )
+    condition1 = impl.elementwise.logical_and(
+        ctx, target, source_ir, f"{name}_and_{time.time()}", cond1, cond2
+    )
+    next_j = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_j_tensor_add_1_{time.time()}", j_tensor, 1
+    )
+    rec2_j_tensor.set_input(1, next_j)
+    loop_out2 = loop2.add_loop_output(condition1, trt.LoopOutput.CONCATENATE)
+    loop_out2.set_input(1, trip_limit2)
+    ####### Inner loop end #######
+
+    select_layer1 = ctx.net.add_select(
+        loop_out2.get_output(0), embed, placeholder_tensor
+    )
+    one_bag = select_layer1.get_output(0)
+
+    # reduce one_bag along dim=0; the result is the embedding of this bag
+    if mode == 0:  # sum
+        reduced_one_bag = impl.reduce.sum(
+            ctx,
+            target,
+            source_ir,
+            name=f"{name}_sum_bag{time.time()}",
+            input_val=one_bag,
+            dim=0,
+            keepdim=False,
+        )
+
+    # Since one_bag includes many zeros, directly calculating the mean would give incorrect results
+    elif mode == 1:  # mean
+        reduced_one_bag = impl.reduce.sum(
+            ctx,
+            target,
+            source_ir,
+            name=f"{name}_sum_bag{time.time()}",
+            input_val=one_bag,
+            dim=0,
+            keepdim=False,
+        )
+        diff = impl.elementwise.sub(
+            ctx, target, source_ir, f"{name}_diff_bag{time.time()}", end, start
+        )
+        reduced_one_bag = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_div_bag{time.time()}",
+            reduced_one_bag,
+            diff,
+        )
+
+    elif mode == 2:  # max
+        reduced_one_bag = impl.reduce.max(
+            ctx,
+            target,
+            source_ir,
+            name=f"{name}_max_bag{time.time()}",
+            input_val=one_bag,
+            dim=0,
+            keepdim=False,
+            return_indices=False,
+        )
+
+    # create a TRT conditional layer
+    conditional_layer1 = ctx.net.add_if_conditional()
+    condition2 = impl.elementwise.eq(
+        ctx, target, source_ir, f"{name}_condition2_eq_{time.time()}", start, end
+    )
+    condition2 = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_condition2_eq_{time.time()}",
+        condition2,
+        [],
+    )
+    # set the combined condition to the conditional layer
+    conditional_layer1.set_condition(condition2)
+    # if true, run this subgraph
+    true_sg = conditional_layer1.add_input(zero_tensor)
+    # if false, run this subgraph
+    false_sg = conditional_layer1.add_input(reduced_one_bag)
+
+    reduced_one_bag_layer = conditional_layer1.add_output(
+        true_sg.get_output(0), false_sg.get_output(0)
+    )
+
+    # reset the variables for the next iteration of the outer loop
+    next_i = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_i_tensor_add_1_{time.time()}", i_tensor, 1
+    )
+    rec1_i_tensor.set_input(1, next_i)
+    rec1_start.set_input(1, end)
+    rec1_end.set_input(1, ctx.net.add_gather(offsets, next_i, 0).get_output(0))
+
+    loop_out1 = loop1.add_loop_output(
+        reduced_one_bag_layer.get_output(0), trt.LoopOutput.CONCATENATE
+    )
+    loop_out1.set_input(1, trip_limit1)
+    reduced_embed_bags = loop_out1.get_output(0)
+    ####### Outer loop end #######
+    return reduced_embed_bags, None, None, None
+
+
+def embedding_bag(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    weight: TRTTensor,
+    indices: TRTTensor,
+    offsets: TRTTensor,
+    mode: int,
+    per_sample_weights: Optional[TRTTensor],  # for sum mode only
+    include_last_offset: bool,
+) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
+    """
+    This function is for calculating embedding bags.
+
+    In PyTorch, `offsets` is only used when the input is 1D. If the input is 2D of shape (B, N),
+    it will be treated as B bags (sequences), each of fixed length N, and this will return
+    B values aggregated in a way depending on the mode. `offsets` is ignored and required
+    to be None in this case.
+
+    However, according to the schema, `offsets` is required for input with any dimensions.
+    Accordingly, this function flattens N-D input to 1D and then calculates the embedding bags.
+    """
+
+    # TODO: support 2D inputs
+    # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))
+
     # calculate embedding
     embed = embedding(
@@ -101,8 +358,9 @@
         f"{name}_embedding",
         indices,
         weight,
-        scale_grad_by_freq,
-        sparse,
+    )
+    embed = cast_trt_tensor(
+        ctx, embed, torch.float, f"{name}_cast_embed_to_fp32", target, source_ir
     )

     # give weights to embedding
@@ -130,43 +388,12 @@
             per_sample_weights,
         )

-    offsets = to_numpy(offsets)
-
-    if include_last_offset is False:
-        # add the end index to offsets
-        offsets = np.append(offsets, indices.shape[0])
+    if isinstance(offsets, TRTTensor):
+        return embedding_bag_with_ITensor_offsets(
+            ctx, target, source_ir, name, embed, offsets, mode, include_last_offset
+        )
     else:
-        # modify the last index of offsets to the end index
-        # however, pytorch doc says if `include_last_offset` is True, the size of offsets
-        # is equal to the number of bags + 1. The last element is the size of the input,
-        # or the ending index position of the last bag (sequence).
-        offsets[-1] = indices.shape[0]  # type: ignore[index]
-
-    # separately reduce embeddings for different bags
-    reduced_embed = []
-    len_offsets = len(offsets)
-    for i in range(len_offsets - 1):
-        if offsets[i] < offsets[i + 1]:
-            sliced_embed = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_embed_{i}",
-                embed,
-                0,
-                int(offsets[i]),
-                int(offsets[i + 1]),
-                1,
-            )
-            reduced_sliced_embed = reduce_op(
-                name=f"{name}_{reduce_name}_{i}",
-                input_val=sliced_embed,
-                dim=0,
-                keepdim=True,
-            )
-            reduced_embed.append(reduced_sliced_embed)
-
-    out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed, 0)
-    # out = reduce_op(input_val=embed, dim=1, keepdim=False)  # Note: This implementation doesn't work for N-dim
-
-    return out, None, None, None
+        # this branch has lower time complexity
+        return embedding_bag_with_traversable_offsets(
+            ctx, target, source_ir, name, embed, offsets, mode, include_last_offset
+        )
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 84d9af5939..bfb3d9545c 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -35,7 +35,6 @@
     aten.diagonal_backward,
     aten.dot,
     aten.elu_backward,
-    aten._embedding_bag,
     aten.embedding_dense_backward,
     aten.empty_like,
     aten._euclidean_dist.default,
diff --git a/tests/py/dynamo/conversion/test_embedding_aten.py b/tests/py/dynamo/conversion/test_embedding_aten.py
index 0ce4c5b49b..c04d89ff9e 100644
--- a/tests/py/dynamo/conversion/test_embedding_aten.py
+++ b/tests/py/dynamo/conversion/test_embedding_aten.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn as nn
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
@@ -14,11 +13,13 @@ class TestEmbeddingConverter(DispatchTestCase):
                 test_name="1d_indices",
                 indices_tensor=torch.tensor([3, 1, 2], dtype=torch.int32),
                 weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=False,
             ),
             param(
                 test_name="2d_indices",
                 indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]], dtype=torch.int32),
                 weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=True,
             ),
             param(
                 test_name="3d_indices",
@@ -26,6 +27,7 @@ class TestEmbeddingConverter(DispatchTestCase):
                     [[[0, 1], [2, 3]], [[3, 4], [4, 0]]], dtype=torch.int32
                 ),
                 weights_tensor=torch.randn((5, 10), dtype=torch.float32),
+                sparse=True,
            ),
         ]
     )
@@ -38,7 +40,7 @@ def test_embedding(
         max_norm=None,
         norm_type=2.0,
         scale_grad_by_freq=None,
-        sparse=None,
+        sparse=False,
     ):
         class TestEmbedding(torch.nn.Module):
             def forward(self, indices, weights):
diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py
index 6d7b05f0e1..2154937b43 100644
--- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py
+++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py
@@ -8,12 +8,65 @@ class TestEmbeddingBagConverter(DispatchTestCase):
     @parameterized.expand(
         [
+            # mode=0: sum, mode=1: mean, mode=2: max
             # 1D input
             param(
                 test_name="1d_indices_1",
-                weight=torch.randn((10, 3), dtype=torch.float32),
-                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
-                offsets=torch.tensor([0, 3], dtype=torch.int32),
+                weight=torch.randn((10, 2), dtype=torch.float16),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=True,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_2",
+                weight=torch.randn((10, 2), dtype=torch.float16),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=True,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_3",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=2,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_4",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=torch.randn((8,), dtype=torch.float16),
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_5",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
                 scale_grad_by_freq=False,
                 mode=1,
                 sparse=False,
@@ -22,22 +75,150 @@ class TestEmbeddingBagConverter(DispatchTestCase):
                 padding_idx=-1,
             ),
             param(
-                test_name="1d_indices_2",
-                weight=torch.randn((10, 3), dtype=torch.float32),
-                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
-                offsets=torch.tensor([0, 5], dtype=torch.int32),
+                test_name="1d_indices_6",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=2,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_7",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
                 scale_grad_by_freq=False,
                 mode=0,
                 sparse=False,
-                per_sample_weights=torch.randn((6,)),
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_8",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+        ]
+    )
+    def test_embedding_bag_with_traversable_offsets(
+        self,
+        test_name,
+        weight,
+        indices,
+        offsets,
+        scale_grad_by_freq,
+        mode,
+        sparse,
+        per_sample_weights,
+        include_last_offset,
+        padding_idx,
+    ):
+        class TestEmbeddingBag(torch.nn.Module):
+            def forward(self, weight, indices):
+                return torch.ops.aten._embedding_bag.default(
+                    weight,
+                    indices,
+                    offsets,
+                    scale_grad_by_freq,
+                    mode,
+                    sparse,
+                    per_sample_weights,
+                    include_last_offset,
+                    padding_idx,
+                )[0]
+
+        self.run_test(
+            TestEmbeddingBag(),
+            inputs=[weight, indices],
+            precision=weight.dtype,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            # mode=0: sum, mode=1: mean, mode=2: max
+            # 1D input
+            param(
+                test_name="1d_indices_1",
+                weight=torch.randn((10, 2), dtype=torch.float32),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=True,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_2",
+                weight=torch.randn((10, 2), dtype=torch.float32),
+                indices=torch.tensor(
+                    [1, 2, 4, 5, 4, 3, 2, 6, 8, 1, 2], dtype=torch.int32
+                ),
+                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=True,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
             param(
                 test_name="1d_indices_3",
-                weight=torch.randn((10, 3), dtype=torch.float32),
+                weight=torch.randn((10, 4), dtype=torch.float32),
                 indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
-                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=2,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_4",
+                weight=torch.randn((10, 4), dtype=torch.float32),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 2, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=torch.randn((8,), dtype=torch.float32),
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_5",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_6",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 5, 5], dtype=torch.int32),
                 scale_grad_by_freq=False,
                 mode=2,
                 sparse=False,
@@ -45,6 +226,30 @@ class TestEmbeddingBagConverter(DispatchTestCase):
                 include_last_offset=False,
                 padding_idx=-1,
             ),
+            param(
+                test_name="1d_indices_7",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=True,
+                padding_idx=-1,
+            ),
+            param(
+                test_name="1d_indices_8",
+                weight=torch.randn((10, 4), dtype=torch.float16),
+                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
+                offsets=torch.tensor([0, 8, 8], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=1,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
             # 2D input
             # param(
             #     test_name="2d_indices_1",
@@ -103,7 +308,7 @@ class TestEmbeddingBagConverter(DispatchTestCase):
             #     ),
         ]
     )
-    def test_embedding_bag(
+    def test_embedding_bag_with_ITensor_offsets(
         self,
         test_name,
         weight,
@@ -117,7 +322,7 @@ def test_embedding_bag(
         padding_idx,
     ):
         class TestEmbeddingBag(torch.nn.Module):
-            def forward(self, weight, indices):
+            def forward(self, weight, indices, offsets):
                 return torch.ops.aten._embedding_bag.default(
                     weight,
                     indices,
@@ -132,7 +337,71 @@ def forward(self, weight, indices):

         self.run_test(
             TestEmbeddingBag(),
-            inputs=[weight, indices],
+            inputs=[weight, indices, offsets],
+            precision=weight.dtype,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            param(
+                test_name="dynamic_offsets_1",
+                weight=torch.range(0, 29, dtype=torch.float32).reshape(15, 2),
+                indices=torch.tensor([i for i in range(15)], dtype=torch.int32),
+                offsets=torch.tensor([0, 2], dtype=torch.int32),
+                scale_grad_by_freq=False,
+                mode=0,
+                sparse=False,
+                per_sample_weights=None,
+                include_last_offset=False,
+                padding_idx=-1,
+            ),
+        ]
+    )
+    def test_embedding_bag_with_dynamic_offsets(
+        self,
+        test_name,
+        weight,
+        indices,
+        offsets,
+        scale_grad_by_freq,
+        mode,
+        sparse,
+        per_sample_weights,
+        include_last_offset,
+        padding_idx,
+    ):
+        class TestEmbeddingBag(torch.nn.Module):
+            def forward(self, weight, indices, offsets):
+                offsets_list = []
+                end = torch.randint(8, 14, (1,))[0]
+                for i in range(3, 0, -1):
+                    rand_tensor = torch.arange(5, end, step=i, dtype=torch.int32)
+                    offsets_list.append(
+                        torch.ops.aten.cat.default((offsets, rand_tensor))
+                    )
+
+                res = []
+                for one_offsets in offsets_list:
+                    output = torch.ops.aten._embedding_bag.default(
+                        weight,
+                        indices,
+                        one_offsets,
+                        scale_grad_by_freq,
+                        mode,
+                        sparse,
+                        per_sample_weights,
+                        include_last_offset,
+                        padding_idx,
+                    )[0]
+                    res.append(output)
+
+                return res
+
+        self.run_test(
+            TestEmbeddingBag(),
+            inputs=[weight, indices, offsets],
+            precision=weight.dtype,
             enable_passes=True,
         )
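
Usage note (illustrative, not part of the patch): every test above calls torch.ops.aten._embedding_bag.default with the same nine positional arguments, and the converter only honors mode, per_sample_weights (sum mode), and include_last_offset. The sketch below, with made-up example tensors, shows that call pattern and the expected output shape when include_last_offset=False (one reduced row per entry of `offsets`).

# Illustrative sketch only -- mirrors the argument order used in the tests above.
# The tensor values here are hypothetical examples, not values taken from the patch.
import torch

weight = torch.randn((10, 4), dtype=torch.float32)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32)
offsets = torch.tensor([0, 2, 8], dtype=torch.int32)  # 3 bags: [0, 2), [2, 8), [8, end)

out, _, _, _ = torch.ops.aten._embedding_bag.default(
    weight,
    indices,
    offsets,
    False,  # scale_grad_by_freq (training-only; ignored by the converter)
    0,      # mode: 0 = sum, 1 = mean, 2 = max
    False,  # sparse (training-only; ignored by the converter)
    None,   # per_sample_weights (sum mode only)
    False,  # include_last_offset
    -1,     # padding_idx (training-only)
)
print(out.shape)  # torch.Size([3, 4]) -- one reduced row per bag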