From 4a5c28f368bd92bb59a7949d57bc2d9dd1b1a54c Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 5 Aug 2021 12:17:47 -0700 Subject: [PATCH] fix(//tests): use right type for masked_fill test Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- tests/core/conversion/converters/test_select.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 3b06059625..ea60aebc18 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -433,13 +433,14 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { %44 : Device = prim::Constant[value="cuda"]() %8 : bool = prim::Constant[value=0]() %7 : None = prim::Constant() + %f32_dtype: int = prim::Constant[value=11]() %1 : int = prim::Constant[value=0]() # bert.py:5:26 %2 : int = prim::Constant[value=1]() # bert.py:5:32 %33 : int = prim::Constant[value=2]() # bert.py:6:31 %3 : int[] = prim::ListConstruct(%1, %1, %2) %4 : int[] = prim::ListConstruct(%2, %2, %1) %5 : int[][] = prim::ListConstruct(%3, %4) - %9 : Tensor = aten::tensor(%5, %1, %7, %8) # bert.py:5:11 + %9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11 %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11 %mask.2 : Tensor = trt::const(%mask.1) %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11