Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix flaky CSVIter test (#18390)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold authored May 23, 2020
1 parent 497bf7e commit 6ab6128
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,11 @@ def test_NDArrayIter_csr():
begin += batch_size


def test_LibSVMIter():
def test_LibSVMIter(tmpdir):

def check_libSVMIter_synthetic():
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
data_path = os.path.join(str(tmpdir), 'data.t')
label_path = os.path.join(str(tmpdir), 'label.t')
with open(data_path, 'w') as fout:
fout.write('1.0 0:0.5 2:1.2\n')
fout.write('-2.0\n')
Expand All @@ -342,7 +341,7 @@ def check_libSVMIter_synthetic():
fout.write('-3.0 2:1.2\n')
fout.write('4 1:1.0 2:-1.2\n')

data_dir = os.path.join(cwd, 'data')
data_dir = os.path.join(str(tmpdir), 'data')
data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path,
data_shape=(3, ), label_shape=(3, ), batch_size=3)

Expand All @@ -367,7 +366,7 @@ def check_libSVMIter_news_data():
}
batch_size = 33
num_examples = news_metadata['num_examples']
data_dir = os.path.join(os.getcwd(), 'data')
data_dir = os.path.join(str(tmpdir), 'data')
get_bz2_data(data_dir, news_metadata['name'], news_metadata['url'],
news_metadata['origin_name'])
path = os.path.join(data_dir, news_metadata['name'])
Expand All @@ -388,9 +387,8 @@ def check_libSVMIter_news_data():
data_train.reset()

def check_libSVMIter_exception():
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
data_path = os.path.join(str(tmpdir), 'data.t')
label_path = os.path.join(str(tmpdir), 'label.t')
with open(data_path, 'w') as fout:
fout.write('1.0 0:0.5 2:1.2\n')
fout.write('-2.0\n')
Expand All @@ -403,7 +401,7 @@ def check_libSVMIter_exception():
fout.write('-2.0 0:0.125\n')
fout.write('-3.0 2:1.2\n')
fout.write('4 1:1.0 2:-1.2\n')
data_dir = os.path.join(cwd, 'data')
data_dir = os.path.join(str(tmpdir), 'data')
data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path,
data_shape=(3, ), label_shape=(3, ), batch_size=3)
for batch in iter(data_train):
Expand All @@ -426,11 +424,10 @@ def test_DataBatch():
r'DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch))


def test_CSVIter():
def test_CSVIter(tmpdir):
def check_CSVIter_synthetic(dtype='float32'):
cwd = os.getcwd()
data_path = os.path.join(cwd, 'data.t')
label_path = os.path.join(cwd, 'label.t')
data_path = os.path.join(str(tmpdir), 'data.t')
label_path = os.path.join(str(tmpdir), 'label.t')
entry_str = '1'
if dtype is 'int32':
entry_str = '200000001'
Expand Down

0 comments on commit 6ab6128

Please sign in to comment.