Skip to content

Commit

Permalink
Update formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
vahidk committed Aug 18, 2024
1 parent 4b02596 commit 0a412c6
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 191 deletions.
6 changes: 3 additions & 3 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest
import sys
import unittest

if __name__ == '__main__':
if __name__ == "__main__":
loader = unittest.TestLoader()
tests = loader.discover('tests')
tests = loader.discover("tests")
testRunner = unittest.TextTestRunner()
result = testRunner.run(tests)
# Exit with a non-zero status code if tests failed
Expand Down
29 changes: 14 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
import os
import sys

from distutils.core import setup
from setuptools import find_packages

from setuptools import find_packages

# List of runtime dependencies required by this built package
install_requires = []
if sys.version_info <= (2, 7):
install_requires += ['future', 'typing']
install_requires += ['numpy', 'protobuf', 'crc32c']
install_requires += ["future", "typing"]
install_requires += ["numpy", "protobuf", "crc32c"]

# read the contents of README file
this_directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_directory, 'README.md')) as f:
with open(os.path.join(this_directory, "README.md")) as f:
long_description = f.read()

setup(
name='tfrecord',
version='1.14.5',
description='TFRecord reader',
name="tfrecord",
version="1.14.5",
description="TFRecord reader",
long_description=long_description,
long_description_content_type='text/markdown',
author='Vahid Kazemi',
author_email='[email protected]',
url='https://github.com/vahidk/tfrecord',
long_description_content_type="text/markdown",
author="Vahid Kazemi",
author_email="[email protected]",
url="https://github.com/vahidk/tfrecord",
packages=find_packages(),
license='MIT',
license="MIT",
install_requires=install_requires,
extras_require={
'torch': ['torch'],
"torch": ["torch"],
},
test_suite='tests',
test_suite="tests",
)
8 changes: 2 additions & 6 deletions tests/test_read_and_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def test_write_and_read_integers(self):

self.assertEqual(len(records), 1)
example = list(example_loader(filename, None))
np.testing.assert_array_equal(
example[0]["int_key"], np.array([123], dtype=np.int64)
)
np.testing.assert_array_equal(example[0]["int_key"], np.array([123], dtype=np.int64))

os.remove(filename)

Expand All @@ -47,9 +45,7 @@ def test_write_and_read_floats(self):

self.assertEqual(len(records), 1)
example = list(example_loader(filename, None))
np.testing.assert_array_equal(
example[0]["float_key"], np.array([1.23], dtype=np.float32)
)
np.testing.assert_array_equal(example[0]["float_key"], np.array([1.23], dtype=np.float32))

os.remove(filename)

Expand Down
28 changes: 6 additions & 22 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,24 @@
import unittest
from unittest.mock import mock_open, patch

import numpy as np
from tfrecord.reader import (
example_loader,
sequence_loader,
tfrecord_iterator,
process_feature,
)

from tfrecord import example_pb2
from tfrecord.reader import process_feature


class TestFeatureProcessing(unittest.TestCase):

def setUp(self):
self.feature_bytes = example_pb2.Feature(
bytes_list=example_pb2.BytesList(value=[b"test"])
)
self.feature_float = example_pb2.Feature(
float_list=example_pb2.FloatList(value=[1.0])
)
self.feature_int = example_pb2.Feature(
int64_list=example_pb2.Int64List(value=[1])
)
self.feature_bytes = example_pb2.Feature(bytes_list=example_pb2.BytesList(value=[b"test"]))
self.feature_float = example_pb2.Feature(float_list=example_pb2.FloatList(value=[1.0]))
self.feature_int = example_pb2.Feature(int64_list=example_pb2.Int64List(value=[1]))

def test_process_feature_bytes(self):
result = process_feature(
self.feature_bytes, "byte", {"byte": "bytes_list"}, "key"
)
result = process_feature(self.feature_bytes, "byte", {"byte": "bytes_list"}, "key")
self.assertEqual(result, b"test")

def test_process_feature_float(self):
result = process_feature(
self.feature_float, "float", {"float": "float_list"}, "key"
)
result = process_feature(self.feature_float, "float", {"float": "float_list"}, "key")
np.testing.assert_array_equal(result, np.array([1.0], dtype=np.float32))

def test_process_feature_int(self):
Expand Down
9 changes: 5 additions & 4 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest
import tempfile
import os
import numpy as np
import tempfile
import unittest

from tfrecord.reader import tfrecord_iterator
from tfrecord.writer import TFRecordWriter
Expand Down Expand Up @@ -33,7 +32,9 @@ def test_tfrecord_writer_write_sequence_example(self):

iterator = tfrecord_iterator(filename)
records = list(iterator)
self.assertTrue(records[0].tobytes().startswith(b"\n\x12\n\x10\n\x03key\x12\t\n\x07\n\x05value"))
self.assertTrue(
records[0].tobytes().startswith(b"\n\x12\n\x10\n\x03key\x12\t\n\x07\n\x05value")
)
os.remove(filename)


Expand Down
17 changes: 8 additions & 9 deletions tfrecord/iterator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def cycle(iterator_fn: typing.Callable) -> typing.Iterable[typing.Any]:
yield element


def sample_iterators(iterators: typing.List[typing.Iterator],
ratios: typing.List[int],
infinite: bool = True) -> typing.Iterable[typing.Any]:
def sample_iterators(
iterators: typing.List[typing.Iterator], ratios: typing.List[int], infinite: bool = True
) -> typing.Iterable[typing.Any]:
"""Retrieve info generated from the iterator(s) according to their
sampling ratios.
Expand All @@ -28,7 +28,7 @@ def sample_iterators(iterators: typing.List[typing.Iterator],
ratios: list of int
The ratios with which to sample each iterator.
infinite: bool, optional, default=True
Whether the returned iterator should be infinite or not
Expand All @@ -55,9 +55,7 @@ def sample_iterators(iterators: typing.List[typing.Iterator],
ratios = ratios / ratios.sum()



def shuffle_iterator(iterator: typing.Iterator,
queue_size: int) -> typing.Iterable[typing.Any]:
def shuffle_iterator(iterator: typing.Iterator, queue_size: int) -> typing.Iterable[typing.Any]:
"""Shuffle elements contained in an iterator.
Params:
Expand All @@ -80,8 +78,9 @@ def shuffle_iterator(iterator: typing.Iterator,
for _ in range(queue_size):
buffer.append(next(iterator))
except StopIteration:
warnings.warn("Number of elements in the iterator is less than the "
f"queue size (N={queue_size}).")
warnings.warn(
"Number of elements in the iterator is less than the " f"queue size (N={queue_size})."
)
while buffer:
index = np.random.randint(len(buffer))
try:
Expand Down
Loading

0 comments on commit 0a412c6

Please sign in to comment.