diff --git a/src/eventio/file_types.py b/src/eventio/file_types.py index ff29ba2..78e3e44 100644 --- a/src/eventio/file_types.py +++ b/src/eventio/file_types.py @@ -1,9 +1,8 @@ import gzip try: import zstandard as zstd - has_zstd = True -except ImportError: - has_zstd = False +except ModuleNotFoundError: + zstd = None from .constants import ( SYNC_MARKER_SIZE, @@ -11,24 +10,30 @@ SYNC_MARKER_BIG_ENDIAN, ) +ZSTD_MARKER = b'\x28\xb5\x2f\xfd' +GZIP_MARKER = b'\x1f\x8b' + + +def _check_marker(path, marker): + with open(path, 'rb') as f: + marker_bytes = f.read(len(marker)) + + if len(marker_bytes) < len(marker): + return False + + return marker_bytes == marker def is_gzip(path): '''Test if a file is gzipped by reading its first two bytes and compare to the gzip marker bytes. ''' - with open(path, 'rb') as f: - marker_bytes = f.read(2) - - return marker_bytes[0] == 0x1f and marker_bytes[1] == 0x8b + return _check_marker(path, GZIP_MARKER) def is_zstd(path): '''Test if a file is compressed using zstd using its magic marker bytes ''' - with open(path, 'rb') as f: - marker_bytes = f.read(4) - - return marker_bytes == b'\x28\xb5\x2f\xfd' + return _check_marker(path, ZSTD_MARKER) def is_eventio(path): @@ -39,7 +44,7 @@ def is_eventio(path): with gzip.open(path, 'rb') as f: marker_bytes = f.read(SYNC_MARKER_SIZE) elif is_zstd(path): - if not has_zstd: + if zstd is None: raise IOError('You need the `zstandard` module to read zstd files') with open(path, 'rb') as f: cctx = zstd.ZstdDecompressor() diff --git a/tests/test_open_file.py b/tests/test_open_file.py index 0c59c88..e5d2c48 100644 --- a/tests/test_open_file.py +++ b/tests/test_open_file.py @@ -1,6 +1,7 @@ -import eventio from os import path from itertools import zip_longest +import eventio +import pytest def test_is_install_folder_a_directory(): @@ -19,6 +20,13 @@ def test_file_is_iterable(): for event in f: pass +def test_empty(tmp_path): + path = tmp_path / "empty.dat" + path.write_bytes(b"") + + with pytest.raises(ValueError, match="^File .* is not an eventio file$"): + eventio.EventIOFile(path) + def test_file_has_objects_at_expected_position(): expected = [