diff --git a/kloppy/_providers/statsbomb.py b/kloppy/_providers/statsbomb.py index 1d750bc0..d26614ae 100644 --- a/kloppy/_providers/statsbomb.py +++ b/kloppy/_providers/statsbomb.py @@ -1,4 +1,3 @@ -import contextlib import warnings from typing import Union @@ -12,11 +11,6 @@ from kloppy.io import open_as_file, FileLike, Source -@contextlib.contextmanager -def dummy_context_mgr(): - yield None - - def load( event_data: FileLike, lineup_data: FileLike, diff --git a/kloppy/io.py b/kloppy/io.py index a3b43245..0726f99c 100644 --- a/kloppy/io.py +++ b/kloppy/io.py @@ -1,26 +1,53 @@ +"""I/O utilities for reading raw data.""" + +import bz2 import contextlib +import gzip import logging +import lzma import os import urllib.parse from dataclasses import dataclass, replace -from pathlib import PurePath -from typing import Union, IO, BinaryIO, Tuple - -from io import BytesIO +from io import BufferedWriter, BytesIO, TextIOWrapper +from typing import ( + IO, + BinaryIO, + ContextManager, + Generator, + Optional, + Tuple, + Union, +) from kloppy.config import get_config from kloppy.exceptions import InputNotFoundError from kloppy.infra.io.adapters import get_adapter - logger = logging.getLogger(__name__) -_open = open +DEFAULT_GZIP_COMPRESSION = 1 +DEFAULT_BZ2_COMPRESSION = 9 +DEFAULT_XZ_COMPRESSION = 6 + + +FilePath = Union[str, bytes, os.PathLike] +FileOrPath = Union[FilePath, IO] @dataclass(frozen=True) class Source: - data: "FileLike" + """A wrapper around a file-like object to enable optional inputs. + + Args: + data (FileLike): The file-like object. + optional (bool): Whether the file is optional. Defaults to False. + skip_if_missing (bool): Whether to skip the file if it is missing. Defaults to False. 
+ + Example: + >>> open_as_file(Source.create("example.csv", optional=True)) + """ + + data: FileOrPath optional: bool = False skip_if_missing: bool = False @@ -31,45 +58,350 @@ def create(cls, input_: "FileLike", **kwargs): return Source(data=input_, **kwargs) -FileLike = Union[str, PurePath, bytes, IO[bytes], Source] +FileLike = Union[FileOrPath, Source] + + +def _file_or_path_to_binary_stream( + file_or_path: FileOrPath, binary_mode: str +) -> Tuple[BinaryIO, bool]: + """ + Converts a file path or a file-like object to a binary stream. + + Args: + file_or_path: The file path or file-like object to convert. + binary_mode: The binary mode to open the file in. Must be one of 'rb', 'wb', or 'ab'. + + Returns: + A tuple containing the binary stream and a boolean indicating whether + a new file was opened (True) or an existing file-like object was used (False). + """ + assert binary_mode in ("rb", "wb", "ab") + + if isinstance(file_or_path, (str, bytes)) or hasattr( + file_or_path, "__fspath__" + ): + # If file_or_path is a path-like object, open it and return the binary stream + return open(os.fspath(file_or_path), binary_mode), True # type: ignore + + if isinstance(file_or_path, TextIOWrapper): + # If file_or_path is a TextIOWrapper, return its underlying binary buffer + return file_or_path.buffer, False + + if hasattr(file_or_path, "readinto") or hasattr(file_or_path, "write"): + # If file_or_path is a file-like object, return it as is + return file_or_path, False # type: ignore + + raise TypeError( + f"Unsupported type for {file_or_path}, " + f"{file_or_path.__class__.__name__}." + ) + + +def _detect_format_from_content(file_or_path: FileOrPath) -> Optional[str]: + """ + Attempts to detect file format from the content by reading the first + 6 bytes. Returns None if no format could be detected. 
+ """ + fileobj, closefd = _file_or_path_to_binary_stream(file_or_path, "rb") + try: + if not fileobj.readable(): + return None + if hasattr(fileobj, "peek"): + bs = fileobj.peek(6) + elif hasattr(fileobj, "seekable") and fileobj.seekable(): + current_pos = fileobj.tell() + bs = fileobj.read(6) + fileobj.seek(current_pos) + else: + return None + + if bs[:2] == b"\x1f\x8b": + # https://tools.ietf.org/html/rfc1952#page-6 + return "gz" + elif bs[:3] == b"\x42\x5a\x68": + # https://en.wikipedia.org/wiki/List_of_file_signatures + return "bz2" + elif bs[:6] == b"\xfd\x37\x7a\x58\x5a\x00": + # https://tukaani.org/xz/xz-file-format.txt + return "xz" + return None + finally: + if closefd: + fileobj.close() + + +def _detect_format_from_extension(filename: FilePath) -> Optional[str]: + """ + Attempt to detect file format from the filename extension. + Return None if no format could be detected. + """ + extensions = ("bz2", "xz", "gz") + + if isinstance(filename, bytes): + for ext in extensions: + if filename.endswith(b"." + ext.encode()): + return ext + if isinstance(filename, str): + for ext in extensions: + if filename.endswith("." 
+ ext): + return ext -def get_file_extension(f: FileLike) -> str: - if isinstance(f, str): - return os.path.splitext(f)[1] - elif isinstance(f, PurePath): - return os.path.splitext(f.name)[1] - elif isinstance(f, Source): - return get_file_extension(f.data) + if hasattr(filename, "name"): + return _detect_format_from_extension(filename.name) + + return None + + +def _filepath_from_path_or_filelike(file_or_path: FileOrPath) -> str: + try: + return os.fspath(file_or_path) # type: ignore + except TypeError: + pass + + if hasattr(file_or_path, "name"): + name = file_or_path.name + if isinstance(name, str): + return name + elif isinstance(name, bytes): + return name.decode() + + return "" + + +def _open( + filename: FileOrPath, + mode: str = "rb", + compresslevel: Optional[int] = None, + format: Optional[str] = None, +) -> BinaryIO: + """ + A replacement for the "open" function that can also read and write + compressed files transparently. The supported compression formats are gzip, + bzip2 and xz. Filename can be a string, a Path or a file object. + + When writing, the file format is chosen based on the file name extension: + - .gz uses gzip compression + - .bz2 uses bzip2 compression + - .xz uses xz/lzma compression + - otherwise, no compression is used + + When reading, if a file name extension is available, the format is detected + using it, but if not, the format is detected from the contents. + + mode can be: 'rb', 'ab', or 'wb'. + + compresslevel is the compression level for writing to gzip, and xz. + This parameter is ignored for the other compression formats. + If set to None, a default depending on the format is used: + gzip: 6, xz: 6. + + format overrides the autodetection of input and output formats. This can be + useful when compressed output needs to be written to a file without an + extension. Possible values are "gz", "xz", "bz2" and "raw". In case of + "raw", no compression is used. 
+ """ + if mode not in ("rb", "wb", "ab"): + raise ValueError("Mode '{}' not supported".format(mode)) + filepath = _filepath_from_path_or_filelike(filename) + + if format not in (None, "gz", "xz", "bz2", "raw"): + raise ValueError( + f"Format not supported: {format}. Choose one of: 'gz', 'xz', 'bz2'" + ) + + if format == "raw": + detected_format = None + else: + detected_format = format or _detect_format_from_extension(filepath) + if detected_format is None and "r" in mode: + detected_format = _detect_format_from_content(filename) + + if detected_format == "gz": + opened_file = _open_gz(filename, mode, compresslevel) + elif detected_format == "xz": + opened_file = _open_xz(filename, mode, compresslevel) + elif detected_format == "bz2": + opened_file = _open_bz2(filename, mode, compresslevel) else: - raise Exception("Could not determine extension") + opened_file, _ = _file_or_path_to_binary_stream(filename, mode) + + return opened_file + + +def _open_bz2( + filename: FileOrPath, + mode: str, + compresslevel: Optional[int] = None, +) -> BinaryIO: + assert mode in ("rb", "ab", "wb") + if compresslevel is None: + compresslevel = DEFAULT_BZ2_COMPRESSION + + if "r" in mode: + return bz2.open(filename, mode) # type: ignore + return BufferedWriter(bz2.open(filename, mode, compresslevel)) # type: ignore + + +def _open_xz( + filename: FileOrPath, + mode: str, + compresslevel: Optional[int] = None, +) -> BinaryIO: + assert mode in ("rb", "ab", "wb") + if compresslevel is None: + compresslevel = DEFAULT_XZ_COMPRESSION + + if "r" in mode: + return lzma.open(filename, mode) # type: ignore + return BufferedWriter(lzma.open(filename, mode, preset=compresslevel)) # type: ignore -def get_local_cache_stream(url: str, cache_dir: str) -> Tuple[BinaryIO, bool]: +def _open_gz( + filename: FileOrPath, + mode: str, + compresslevel: Optional[int] = None, +) -> BinaryIO: + assert mode in ("rb", "ab", "wb") + if compresslevel is None: + compresslevel = DEFAULT_GZIP_COMPRESSION + + if "r" in 
mode: + return gzip.open(filename, mode) # type: ignore + return BufferedWriter(gzip.open(filename, mode, compresslevel=compresslevel)) # type: ignore + + +def get_file_extension(file_or_path: FileLike) -> str: + """Determine the file extension of the given file-like object. + + If the file has compression extensions such as '.gz', '.xz', or '.bz2', + they will be stripped before determining the extension. + + Args: + file_or_path (FileLike): The file-like object whose extension needs to be determined. + + Returns: + str: The file extension, including the dot ('.') if present. + + Raises: + Exception: If the extension cannot be determined. + + Example: + >>> get_file_extension("example.xml.gz") + '.xml' + >>> get_file_extension(Path("example.txt")) + '.txt' + >>> get_file_extension(Source(data="example.csv")) + '.csv' + """ + if isinstance(file_or_path, (str, bytes)) or hasattr( + file_or_path, "__fspath__" + ): + path = os.fspath(file_or_path) # type: ignore + for ext in [".gz", ".xz", ".bz2"]: + if path.endswith(ext): + path = path[: -len(ext)] + return os.path.splitext(path)[1] + + if isinstance(file_or_path, Source): + return get_file_extension(file_or_path.data) + + raise TypeError( + f"Could not determine extension for input type: {type(file_or_path)}" + ) + + +def get_local_cache_stream( + url: str, cache_dir: str, mode: str = "rb", format: Optional[str] = None +) -> Tuple[BinaryIO, Union[bool, str]]: + """Get a stream to the local cache file for the given URL. + + Compressed files are read transparently. The supported compression formats + are gzip, bzip2 and xz. + + Args: + url (str): The URL to cache. + cache_dir (str): The directory where the cache file will be stored. + mode (str): The mode in which to open the cache file. Must be one of + 'rb', 'wb', or 'ab'. Defaults to 'rb'. + format (str): Overrides the autodetection of input and output formats. + Possible values are "gz", "xz", "bz2" and "raw". In case of "raw", + no compression is used. 
+ + Returns: + Tuple[BinaryIO, bool | str]: A tuple containing a binary stream to the + local cache file and True if the cache file already + exists and is non-empty, otherwise False. + + Note: + - If the specified cache directory does not exist, it will be created. + - If the cache file does not exist, it will be created and will be + named after the URL. + + Example: + >>> stream, exists = get_local_cache_stream("https://example.com/data", "./cache") + """ + assert mode in ("rb", "wb", "ab") + + # Ensure the cache directory exists if not os.path.exists(cache_dir): os.makedirs(cache_dir) + # Generate the local filename based on the URL filename = urllib.parse.quote_plus(url) local_filename = f"{cache_dir}/{filename}" - # Open the file in append+read mode - # this makes sure: - # 1. The file is created when it does not exist - # 2. The file is not truncated when it does exist - # 3. The file can be read - return _open(local_filename, "a+b"), ( - os.path.exists(local_filename) - and os.path.getsize(local_filename) > 0 - and local_filename + # Ensure the file exists by opening it in append-binary mode, creating it if necessary + file_exists_and_non_empty = ( + os.path.exists(local_filename) and os.path.getsize(local_filename) > 0 ) + file = _open(local_filename, mode, format=format) + + return file, file_exists_and_non_empty @contextlib.contextmanager -def dummy_context_mgr(): - yield None +def dummy_context_mgr() -> Generator[None, None, None]: + yield + + +def open_as_file(input_: FileLike) -> ContextManager[Optional[BinaryIO]]: + """Open a byte stream to the given input object. + + The following input types are supported: + - A string or `pathlib.Path` object representing a local file path. + - A string representing a URL. It should start with 'http://' or + 'https://'. + - A string representing a path to a file in an Amazon S3 cloud storage + bucket. It should start with 's3://'. + - An xml or json string containing the data.
The string should contain + a '{' or '<' character. Otherwise, it will be treated as a file path. + - A bytes object containing the data. + - A buffered binary stream that inherits from `io.BufferedIOBase`. + - A [Source](`kloppy.io.Source`) object that wraps any of the above + input types. + Args: + input_ (FileLike): The input object to be opened. -def open_as_file(input_: FileLike) -> IO: + Returns: + BinaryIO: A binary stream to the input object. + + Raises: + InputNotFoundError: If the input file is not found. + TypeError: If the input type is not supported. + + Example: + >>> with open_as_file("example.txt") as f: + ... contents = f.read() + + Note: + To support reading data from other sources, see the + [Adapter](`kloppy.io.adapters.Adapter`) class. + + If the given file path or URL ends with '.gz', '.xz', or '.bz2', the + file will be decompressed before being read. + """ if isinstance(input_, Source): if input_.data is None and input_.optional: # This saves us some additional code in every vendor specific code @@ -82,41 +414,56 @@ def open_as_file(input_: FileLike) -> IO: logging.info(f"Input {input_.data} not found. 
Skipping") return dummy_context_mgr() raise - elif isinstance(input_, str) or isinstance(input_, PurePath): - if isinstance(input_, PurePath): - input_ = str(input_) - is_path = True - else: - is_path = False - if not is_path and ("{" in input_ or "<" in input_): - return BytesIO(input_.encode("utf8")) - else: - adapter = get_adapter(input_) - if adapter: - cache_dir = get_config("cache") - if cache_dir: - stream, local_cache_file = get_local_cache_stream( - input_, cache_dir - ) - else: - stream = BytesIO() - local_cache_file = None - - if not local_cache_file: - logger.info(f"Retrieving {input_}") - adapter.read_to_stream(input_, stream) - logger.info("Retrieval complete") - else: - logger.info(f"Using local cached file {local_cache_file}") - stream.seek(0) - else: - if not os.path.exists(input_): - raise InputNotFoundError(f"File {input_} does not exist") + if isinstance(input_, str) and ("{" in input_ or "<" in input_): + # If input_ is a JSON or XML string, return it as a binary stream + return BytesIO(input_.encode("utf8")) - stream = _open(input_, "rb") - return stream - elif isinstance(input_, bytes): + if isinstance(input_, bytes): + # If input_ is a bytes object, return it as a binary stream return BytesIO(input_) - else: - return input_ + + if isinstance(input_, str) or hasattr(input_, "__fspath__"): + # If input_ is a path-like object, open it and return the binary stream + uri = _filepath_from_path_or_filelike(input_) + + adapter = get_adapter(uri) + if adapter: + cache_dir = get_config("cache") + assert cache_dir is None or isinstance(cache_dir, str) + if cache_dir: + stream, local_cache_file = get_local_cache_stream( + uri, cache_dir, "ab", format="raw" + ) + else: + stream, local_cache_file = BytesIO(), None + + if not local_cache_file: + logger.info(f"Retrieving {uri}") + adapter.read_to_stream(uri, stream) + logger.info("Retrieval complete") + else: + logger.info(f"Using local cached file {local_cache_file}") + + if cache_dir: + stream.close() + 
stream, _ = get_local_cache_stream(uri, cache_dir, "rb") + else: + stream.seek(0) + + else: + if not os.path.exists(uri): + raise InputNotFoundError(f"File {uri} does not exist") + + stream = _open(uri, "rb") + return stream + + if isinstance(input_, TextIOWrapper): + # If file_or_path is a TextIOWrapper, return its underlying binary buffer + return input_.buffer + + if hasattr(input_, "readinto"): + # If file_or_path is a file-like object, return it as is + return _open(input_) # type: ignore + + raise TypeError(f"Unsupported input type: {type(input_)}") diff --git a/kloppy/tests/test_helpers.py b/kloppy/tests/test_helpers.py index dd5036d6..0c421891 100644 --- a/kloppy/tests/test_helpers.py +++ b/kloppy/tests/test_helpers.py @@ -1,37 +1,31 @@ -import os import sys -from pathlib import Path import pytest - -from kloppy.config import config_context from pandas import DataFrame from pandas.testing import assert_frame_equal - +from kloppy import opta, statsbomb, tracab +from kloppy.config import config_context from kloppy.domain import ( - Period, - DatasetFlag, - Point, AttackingDirection, - TrackingDataset, - NormalizedPitchDimensions, + DatasetFlag, Dimension, - Orientation, - Provider, Frame, + Ground, Metadata, MetricaCoordinateSystem, - Team, - Ground, + NormalizedPitchDimensions, + Orientation, + Period, Player, PlayerData, + Point, Point3D, + Provider, + Team, + TrackingDataset, ) -from kloppy import opta, tracab, statsbomb -from kloppy.io import open_as_file - class TestHelpers: def _get_tracking_dataset(self): @@ -517,12 +511,3 @@ def test_to_df_pyarrow(self): df = dataset.to_df(engine="pandas[pyarrow]") assert isinstance(df, pd.DataFrame) assert isinstance(df.dtypes["ball_x"], pd.ArrowDtype) - - -class TestOpenAsFile: - def test_path(self): - path = Path(__file__).parent / "files/tracab_meta.xml" - with open_as_file(path) as fp: - data = fp.read() - - assert len(data) == os.path.getsize(path) diff --git a/kloppy/tests/test_io.py b/kloppy/tests/test_io.py 
new file mode 100644 index 00000000..78d3145e --- /dev/null +++ b/kloppy/tests/test_io.py @@ -0,0 +1,171 @@ +import gzip +import json +from io import BytesIO +from pathlib import Path + +import pytest +import s3fs +from moto import mock_aws + +from kloppy.exceptions import InputNotFoundError +from kloppy.io import get_file_extension, open_as_file + + +@pytest.fixture() +def filesystem_content(tmp_path: Path): + """Set up the content to be read from a file.""" + path = tmp_path / "testfile.txt" + with open(path, "w") as f: + f.write("Hello, world!") + + gz_path = tmp_path / "testfile.txt.gz" + with open(gz_path, "wb") as f: + import gzip + + with gzip.open(f, "wb") as f_out: + f_out.write(b"Hello, world!") + + xz_path = tmp_path / "testfile.txt.xz" + with open(xz_path, "wb") as f: + import lzma + + with lzma.open(f, "wb") as f_out: + f_out.write(b"Hello, world!") + + bz2_path = tmp_path / "testfile.txt.bz2" + with open(bz2_path, "wb") as f: + import bz2 + + with bz2.open(f, "wb") as f_out: + f_out.write(b"Hello, world!") + + return tmp_path + + +@pytest.fixture +def httpserver_content(httpserver): + """Set up the content to be read from a HTTP server.""" + httpserver.expect_request("/testfile.txt").respond_with_data( + "Hello, world!" 
+ ) + httpserver.expect_request("/compressed_testfile.txt").respond_with_data( + gzip.compress(b"Hello, world!"), + headers={"Content-Encoding": "gzip", "Content-Type": "text/plain"}, + ) + httpserver.expect_request("/testfile.txt.gz").respond_with_data( + gzip.compress(b"Hello, world!"), + headers={"Content-Type": "application/x-gzip"}, + ) + + +@pytest.fixture +def s3_content(): + with mock_aws(): + s3_fs = s3fs.S3FileSystem(anon=True) + s3_fs.mkdir("test-bucket", region_name="eu-central-1") + with s3_fs.open("test-bucket/testfile.txt", "wb") as f: + f.write(b"Hello, world!") + with s3_fs.open("test-bucket/testfile.txt.gz", "wb") as f: + f.write(gzip.compress(b"Hello, world!")) + yield s3_fs + s3_fs.rm("test-bucket", recursive=True) + + +class TestOpenAsFile: + """Tests for the open_as_file function.""" + + def test_bytes(self): + """It should be able to open a file from a bytes object.""" + with open_as_file(b"Hello, world!") as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" + + def test_data_string(self): + """It should be able to open a file from a string object.""" + with open_as_file('{"msg": "Hello, world!"}') as fp: + assert fp is not None + assert json.load(fp) == {"msg": "Hello, world!"} + + def test_stream(self): + """It should be able to open a file from a byte stream object.""" + data = b"Hello, world!" + with open_as_file(BytesIO(data)) as fp: + assert fp is not None + assert fp.read() == data + + def test_path_str(self, filesystem_content: Path): + """It should be able to open a file from a string path.""" + path = str(filesystem_content / "testfile.txt") + with open_as_file(path) as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" + + def test_path_obj(self, filesystem_content: Path): + """It should be able to open a file from a Path object.""" + path = filesystem_content / "testfile.txt" + with open_as_file(path) as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" 
+ + @pytest.mark.parametrize("ext", [".gz", ".xz", ".bz2"]) + def test_path_compressed(self, filesystem_content: Path, ext: str): + """It should be able to open a compressed local file.""" + path = filesystem_content / f"testfile.txt{ext}" + with open_as_file(path) as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" + + def test_path_missing(self, filesystem_content: Path): + """It should raise an error if the file is not found.""" + path = filesystem_content / "missing.txt" + with pytest.raises(InputNotFoundError): + with open_as_file(path) as fp: + pass + + def test_http(self, httpserver, httpserver_content): + """It should be able to open a file from a URL.""" + url = httpserver.url_for("/testfile.txt") + with open_as_file(url) as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" + + def test_http_compressed(self, httpserver, httpserver_content): + """It should be able to open a compressed file from a URL.""" + # If the server returns a content-encoding header, the file should be + # decompressed by the request library + url = httpserver.url_for("/compressed_testfile.txt") + with open_as_file(url) as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" + + # If the server does not set a content-type header, but the URL ends + # with .gz, the file should be decompressed by kloppy + url = httpserver.url_for("/testfile.txt.gz") + with open_as_file(url) as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" + + @pytest.mark.skip( + reason="see https://github.com/aio-libs/aiobotocore/issues/755" + ) + def test_s3(self, s3_content): + """It should be able to open a file from an S3 bucket.""" + with open_as_file("s3://test-bucket/testfile.txt") as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" 
+ + @pytest.mark.skip( + reason="see https://github.com/aio-libs/aiobotocore/issues/755" + ) + def test_s3_compressed(self, s3_content): + """It should be able to open a file from an S3 bucket.""" + with open_as_file("s3://test-bucket/testfile.txt.gz") as fp: + assert fp is not None + assert fp.read() == b"Hello, world!" + + +def test_get_file_extension(): + assert get_file_extension(Path("data.xml")) == ".xml" + assert get_file_extension("data.xml") == ".xml" + assert get_file_extension("data.xml.gz") == ".xml" + assert get_file_extension("data") == "" diff --git a/setup.py b/setup.py index 09e10111..f13c9e67 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,9 @@ def setup_package(): "polars>=0.16.6", "pyarrow", "pytest-lazy-fixture", + "s3fs", + "moto[s3]", + "pytest-httpserver", ], "development": ["pre-commit==2.6.0"], "query": ["networkx>=2.4,<3"],