Skip to content

Commit

Permalink
add batchsize in UT dataloader
Browse files Browse the repository at this point in the history
Signed-off-by: y <[email protected]>
  • Loading branch information
xin3he committed Sep 27, 2023
1 parent 3ae01fc commit 74209ed
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions test/algorithm/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class TestSqDepthwiseConv(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
Expand Down Expand Up @@ -142,7 +142,7 @@ class TestSqConvOpFuseAuto(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
Expand Down Expand Up @@ -182,7 +182,7 @@ class TestSqConvOpFuse(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
Expand Down Expand Up @@ -387,21 +387,21 @@ class TestSqListInput(unittest.TestCase):
def setUpClass(self):
class ListDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield [torch.rand((1, 3))]

class TupleDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield (torch.rand((1, 3)))

class ListTupleDataLoader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
input1 = torch.rand((1, 3))
Expand Down Expand Up @@ -500,7 +500,7 @@ class TestAlphaAutoLinear(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3))
Expand Down Expand Up @@ -536,7 +536,7 @@ class TestSqLinearOpFuse(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3))
Expand Down Expand Up @@ -995,7 +995,7 @@ class TestSqSkipOp_attn(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 4))
Expand Down

0 comments on commit 74209ed

Please sign in to comment.