Skip to content

Commit

Permalink
fix(//tests): use right type for masked_fill test
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 5, 2021
1 parent 0647d17 commit 4a5c28f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,14 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
%44 : Device = prim::Constant[value="cuda"]()
%8 : bool = prim::Constant[value=0]()
%7 : None = prim::Constant()
%f32_dtype: int = prim::Constant[value=11]()
%1 : int = prim::Constant[value=0]() # bert.py:5:26
%2 : int = prim::Constant[value=1]() # bert.py:5:32
%33 : int = prim::Constant[value=2]() # bert.py:6:31
%3 : int[] = prim::ListConstruct(%1, %1, %2)
%4 : int[] = prim::ListConstruct(%2, %2, %1)
%5 : int[][] = prim::ListConstruct(%3, %4)
%9 : Tensor = aten::tensor(%5, %1, %7, %8) # bert.py:5:11
%9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11
%mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
%mask.2 : Tensor = trt::const(%mask.1)
%34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11
Expand Down

0 comments on commit 4a5c28f

Please sign in to comment.