diff --git a/src/traffic/data/datasets/__init__.py b/src/traffic/data/datasets/__init__.py
index 15c2f96c..8f2c5a9f 100644
--- a/src/traffic/data/datasets/__init__.py
+++ b/src/traffic/data/datasets/__init__.py
@@ -1,65 +1,57 @@
-import io
-from hashlib import md5
-from typing import Any, Dict
+from __future__ import annotations
 
-from ... import cache_dir, config
-from ...core import Traffic, tqdm
-from .squawk7700 import Squawk7700Dataset
+from typing import Any
 
-datasets = dict(
+from ... import config
+from ...core import Traffic
+from ._squawk7700 import Squawk7700Dataset
+from .default import Default, Entry
+
+datasets: dict[str, Entry] = dict(
     paris_toulouse_2017=dict(
         url="https://ndownloader.figshare.com/files/20291055",
         md5sum="e869a60107fdb9f092f5abbb3b43a2c0",
         filename="city_pair_dataset.parquet",
-        reader=Traffic.from_file,
     ),
     airspace_bordeaux_2017=dict(
         url="https://ndownloader.figshare.com/files/20291040",
         md5sum="f7b057f12cc735a15b93984b9ae7b8fc",
         filename="airspace_dataset.parquet",
-        reader=Traffic.from_file,
     ),
     landing_toulouse_2017=dict(
         url="https://ndownloader.figshare.com/files/24926849",
         md5sum="141e6c39211c382e5dd8ec66096b3798",
         filename="toulouse2017.parquet.gz",
-        reader=Traffic.from_file,
     ),
     landing_zurich_2019=dict(
         url="https://ndownloader.figshare.com/files/20291079",
         md5sum="c5577f450424fa74ca673ed8a168c67f",
         filename="landing_dataset.parquet",
-        reader=Traffic.from_file,
     ),
     landing_dublin_2019=dict(
         url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/94ec5814-6ee9-4cbc-88f2-ca5e1e0dfbf8",
         md5sum="73cc3b882df958cc3b5de547740a5006",
         filename="EIDW_dataset.parquet",
-        reader=Traffic.from_file,
     ),
     landing_cdg_2019=dict(
         url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/0ad60c2d-a63d-446a-976f-d61fe262c144",
         md5sum="9a2af398037fbfb66f16bf171ca7cf93",
         filename="LFPG_dataset.parquet",
-        reader=Traffic.from_file,
     ),
     landing_amsterdam_2019=dict(
         url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/901d842e-658c-40a6-98cc-64688a560f57",
         md5sum="419ab7390ee0f3deb0d46fbeecc29c57",
         filename="EHAM_dataset.parquet",
-        reader=Traffic.from_file,
     ),
     landing_heathrow_2019=dict(
         url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/b40a8064-3b90-4416-8a32-842104b21e4d",
         md5sum="161470e4e93f088cead98178408aa8d1",
         filename="EGLL_dataset.parquet",
-        reader=Traffic.from_file,
     ),
     landing_londoncity_2019=dict(
         url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/74006294-2a66-4b63-9e32-424da2f74201",
         md5sum="a7ff695355759e72703f101a1e43298c",
         filename="EGLC_dataset.parquet",
-        reader=Traffic.from_file,
     ),
 )
 
@@ -71,55 +63,12 @@
 __all__ = __all__ + list(config["datasets"].keys())
 
 
-def download_data(dataset: Dict[str, str]) -> io.BytesIO:
-    from .. import session
-
-    f = session.get(dataset["url"], stream=True)
-    buffer = io.BytesIO()
-
-    if "Content-Length" in f.headers:
-        total = int(f.headers["Content-Length"])
-        for chunk in tqdm(
-            f.iter_content(1024),
-            total=total // 1024 + 1 if total % 1024 > 0 else 0,
-            desc="download",
-        ):
-            buffer.write(chunk)
-    else:
-        buffer.write(f.content)
-
-    buffer.seek(0)
-
-    compute_md5 = md5(buffer.getbuffer()).hexdigest()
-    if compute_md5 != dataset["md5sum"]:
-        raise RuntimeError(
-            f"Error in MD5 check: {compute_md5} instead of {dataset['md5sum']}"
-        )
-
-    return buffer
-
-
-def get_dataset(dataset: Dict[str, Any]) -> Any:
-    dataset_dir = cache_dir / "datasets"
-
-    if not dataset_dir.exists():
-        dataset_dir.mkdir(parents=True)
-
-    filename = dataset_dir / dataset["filename"]
-    if not filename.exists():
-        buffer = download_data(dataset)
-        with filename.open("wb") as fh:
-            fh.write(buffer.getbuffer())
-
-    return dataset["reader"](filename)
-
-
 def __getattr__(name: str) -> Any:
-    on_disk = config.get("datasets", name, fallback=None)
-    if on_disk is not None:
+    if on_disk := config.get("datasets", name, fallback=None):
         return Traffic.from_file(on_disk)
 
     if name not in datasets:
         raise AttributeError(f"No such dataset: {name}")
 
-    return get_dataset(datasets[name])
+    filename = Default().get_data(datasets[name])
+    return Traffic.from_file(filename)
diff --git a/src/traffic/data/datasets/squawk7700.py b/src/traffic/data/datasets/_squawk7700.py
similarity index 100%
rename from src/traffic/data/datasets/squawk7700.py
rename to src/traffic/data/datasets/_squawk7700.py
diff --git a/src/traffic/data/datasets/default.py b/src/traffic/data/datasets/default.py
new file mode 100644
index 00000000..6596757c
--- /dev/null
+++ b/src/traffic/data/datasets/default.py
@@ -0,0 +1,57 @@
+import hashlib
+from pathlib import Path
+from typing import TypedDict
+
+import httpx
+from tqdm.rich import tqdm
+
+from ... import cache_dir
+
+client = httpx.Client()
+
+
+class Entry(TypedDict):
+    url: str
+    md5sum: str
+    filename: str
+
+
+class Default:
+    def __init__(self) -> None:
+        cache = cache_dir / "datasets" / "default"
+        if not cache.exists():
+            cache.mkdir(parents=True)
+
+        self.cache = cache
+
+    def get_data(self, entry: Entry) -> Path:
+        if not (filename := self.cache / entry["filename"]).exists():
+            md5_hash = hashlib.md5()
+            with filename.open("wb") as file_handle:
+                with client.stream("GET", entry["url"]) as response:
+                    content_length = response.headers.get("Content-Length")
+                    if content_length is None:
+                        content = response.read()
+                        file_handle.write(content)
+                        md5_hash.update(content)
+                    else:
+                        with tqdm(
+                            total=int(content_length),
+                            unit_scale=True,
+                            unit_divisor=1024,
+                            unit="B",
+                        ) as progress:
+                            n_bytes = response.num_bytes_downloaded
+                            for chunk in response.iter_bytes():
+                                file_handle.write(chunk)
+                                md5_hash.update(chunk)
+                                progress.update(
+                                    response.num_bytes_downloaded - n_bytes
+                                )
+                                n_bytes = response.num_bytes_downloaded
+
+            if md5_hash.hexdigest() != entry["md5sum"]:
+                filename.unlink()
+                raise ValueError("Mismatch in MD5 hash")
+
+        return filename
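
Usage note: dataset access now goes through the module-level __getattr__
(PEP 562), so nothing is downloaded at import time. A minimal sketch of the
intended call path, under the default configuration (landing_zurich_2019 is
one of the entries registered in the datasets mapping above):

    # First access triggers __getattr__: the file is streamed into
    # cache_dir / "datasets" / "default", its MD5 checksum is verified,
    # and a Traffic instance is returned. Later accesses hit the cache.
    from traffic.data.datasets import landing_zurich_2019

    # Equivalent explicit form, using the names introduced in this diff:
    from traffic.core import Traffic
    from traffic.data.datasets import datasets
    from traffic.data.datasets.default import Default

    path = Default().get_data(datasets["landing_zurich_2019"])  # pathlib.Path
    t = Traffic.from_file(path)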
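
Configuration override: __getattr__ first consults the user configuration
(config.get("datasets", name, fallback=None)), so a dataset name can be
mapped to a file already on disk, bypassing the download entirely. A
hypothetical configparser-style entry (the path is illustrative):

    [datasets]
    landing_zurich_2019 = /home/user/data/landing_dataset.parquet

With such an entry in place, Traffic.from_file is called directly on that
path.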