From 74209ede2b5c0de6d7dee5d554e16c9ea7ab878f Mon Sep 17 00:00:00 2001 From: y Date: Tue, 26 Sep 2023 23:30:30 -0700 Subject: [PATCH] add batchsize in UT dataloader Signed-off-by: y --- test/algorithm/test_smooth_quant.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py index 1b8d7a927c4..74e4a2b8919 100644 --- a/test/algorithm/test_smooth_quant.py +++ b/test/algorithm/test_smooth_quant.py @@ -65,7 +65,7 @@ class TestSqDepthwiseConv(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3, 1, 1)) @@ -142,7 +142,7 @@ class TestSqConvOpFuseAuto(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3, 1, 1)) @@ -182,7 +182,7 @@ class TestSqConvOpFuse(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3, 1, 1)) @@ -387,21 +387,21 @@ class TestSqListInput(unittest.TestCase): def setUpClass(self): class ListDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield [torch.rand((1, 3))] class TupleDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield (torch.rand((1, 3))) class ListTupleDataLoader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): input1 = torch.rand((1, 3)) @@ -500,7 +500,7 @@ class TestAlphaAutoLinear(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3)) @@ -536,7 +536,7 @@ class TestSqLinearOpFuse(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3)) @@ -995,7 +995,7 @@ class TestSqSkipOp_attn(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 4))