Commit 27d53e1: clarification of the datasets interface
xoolive committed Jan 10, 2024 (parent: fc06ceb)

Showing 3 changed files with 68 additions and 62 deletions.

src/traffic/data/datasets/__init__.py (11 additions, 62 deletions)

@@ -1,65 +1,57 @@
-import io
-from hashlib import md5
-from typing import Any, Dict
+from __future__ import annotations
+
+from typing import Any

-from ... import cache_dir, config
-from ...core import Traffic, tqdm
-from .squawk7700 import Squawk7700Dataset
+from ... import config
+from ...core import Traffic
+from ._squawk7700 import Squawk7700Dataset
+from .default import Default, Entry

-datasets = dict(
+datasets: dict[str, Entry] = dict(
    paris_toulouse_2017=dict(
        url="https://ndownloader.figshare.com/files/20291055",
        md5sum="e869a60107fdb9f092f5abbb3b43a2c0",
        filename="city_pair_dataset.parquet",
-        reader=Traffic.from_file,
    ),
    airspace_bordeaux_2017=dict(
        url="https://ndownloader.figshare.com/files/20291040",
        md5sum="f7b057f12cc735a15b93984b9ae7b8fc",
        filename="airspace_dataset.parquet",
-        reader=Traffic.from_file,
    ),
    landing_toulouse_2017=dict(
        url="https://ndownloader.figshare.com/files/24926849",
        md5sum="141e6c39211c382e5dd8ec66096b3798",
        filename="toulouse2017.parquet.gz",
-        reader=Traffic.from_file,
    ),
    landing_zurich_2019=dict(
        url="https://ndownloader.figshare.com/files/20291079",
        md5sum="c5577f450424fa74ca673ed8a168c67f",
        filename="landing_dataset.parquet",
-        reader=Traffic.from_file,
    ),
    landing_dublin_2019=dict(
        url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/94ec5814-6ee9-4cbc-88f2-ca5e1e0dfbf8",
        md5sum="73cc3b882df958cc3b5de547740a5006",
        filename="EIDW_dataset.parquet",
-        reader=Traffic.from_file,
    ),
    landing_cdg_2019=dict(
        url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/0ad60c2d-a63d-446a-976f-d61fe262c144",
        md5sum="9a2af398037fbfb66f16bf171ca7cf93",
        filename="LFPG_dataset.parquet",
-        reader=Traffic.from_file,
    ),
    landing_amsterdam_2019=dict(
        url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/901d842e-658c-40a6-98cc-64688a560f57",
        md5sum="419ab7390ee0f3deb0d46fbeecc29c57",
        filename="EHAM_dataset.parquet",
-        reader=Traffic.from_file,
    ),
    landing_heathrow_2019=dict(
        url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/b40a8064-3b90-4416-8a32-842104b21e4d",
        md5sum="161470e4e93f088cead98178408aa8d1",
        filename="EGLL_dataset.parquet",
-        reader=Traffic.from_file,
    ),
    landing_londoncity_2019=dict(
        url="https://data.4tu.nl/file/4e042fbc-4f76-4f28-ac4b-a0120558ceba/74006294-2a66-4b63-9e32-424da2f74201",
        md5sum="a7ff695355759e72703f101a1e43298c",
        filename="EGLC_dataset.parquet",
-        reader=Traffic.from_file,
    ),
)

@@ -71,55 +63,12 @@
__all__ = __all__ + list(config["datasets"].keys())


-def download_data(dataset: Dict[str, str]) -> io.BytesIO:
-    from .. import session
-
-    f = session.get(dataset["url"], stream=True)
-    buffer = io.BytesIO()
-
-    if "Content-Length" in f.headers:
-        total = int(f.headers["Content-Length"])
-        for chunk in tqdm(
-            f.iter_content(1024),
-            total=total // 1024 + 1 if total % 1024 > 0 else 0,
-            desc="download",
-        ):
-            buffer.write(chunk)
-    else:
-        buffer.write(f.content)
-
-    buffer.seek(0)
-
-    compute_md5 = md5(buffer.getbuffer()).hexdigest()
-    if compute_md5 != dataset["md5sum"]:
-        raise RuntimeError(
-            f"Error in MD5 check: {compute_md5} instead of {dataset['md5sum']}"
-        )
-
-    return buffer
-
-
-def get_dataset(dataset: Dict[str, Any]) -> Any:
-    dataset_dir = cache_dir / "datasets"
-
-    if not dataset_dir.exists():
-        dataset_dir.mkdir(parents=True)
-
-    filename = dataset_dir / dataset["filename"]
-    if not filename.exists():
-        buffer = download_data(dataset)
-        with filename.open("wb") as fh:
-            fh.write(buffer.getbuffer())
-
-    return dataset["reader"](filename)
-
-
def __getattr__(name: str) -> Any:
-    on_disk = config.get("datasets", name, fallback=None)
-    if on_disk is not None:
+    if on_disk := config.get("datasets", name, fallback=None):
        return Traffic.from_file(on_disk)

    if name not in datasets:
        raise AttributeError(f"No such dataset: {name}")

-    return get_dataset(datasets[name])
+    filename = Default().get_data(datasets[name])
+    return Traffic.from_file(filename)
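
With this change, module-level attribute access remains the public entry point: importing a dataset name triggers the download-and-cache logic through the __getattr__ hook (PEP 562). A minimal usage sketch, not part of the commit, with the dataset name taken from the dict above:

# First access downloads the file into the cache and verifies its MD5
# checksum; later accesses read the cached parquet file directly.
from traffic.data.datasets import landing_zurich_2019

print(landing_zurich_2019)  # a Traffic object
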
src/traffic/data/datasets/squawk7700.py → src/traffic/data/datasets/_squawk7700.py (file renamed without changes)

src/traffic/data/datasets/default.py (57 additions, 0 deletions)

@@ -0,0 +1,57 @@
+import hashlib
+from pathlib import Path
+from typing import TypedDict
+
+import httpx
+from tqdm.rich import tqdm
+
+from ... import cache_dir
+
+client = httpx.Client()
+
+
+class Entry(TypedDict):
+    url: str
+    md5sum: str
+    filename: str
+
+
+class Default:
+    def __init__(self) -> None:
+        cache = cache_dir / "datasets" / "default"
+        if not cache.exists():
+            cache.mkdir(parents=True)
+
+        self.cache = cache
+
+    def get_data(self, entry: Entry) -> Path:
+        if not (filename := self.cache / entry["filename"]).exists():
+            md5_hash = hashlib.md5()
+            with filename.open("wb") as file_handle:
+                with client.stream("GET", entry["url"]) as response:
+                    content_length = response.headers.get("Content-Length")
+                    if content_length is None:
+                        content = response.content
+                        file_handle.write(content)
+                        md5_hash.update(content)
+                    else:
+                        with tqdm(
+                            total=int(content_length),
+                            unit_scale=True,
+                            unit_divisor=1024,
+                            unit="B",
+                        ) as progress:
+                            n_bytes = response.num_bytes_downloaded
+                            for chunk in response.iter_bytes():
+                                file_handle.write(chunk)
+                                md5_hash.update(chunk)
+                                progress.update(
+                                    response.num_bytes_downloaded - n_bytes
+                                )
+                                n_bytes = response.num_bytes_downloaded
+
+            if md5_hash.hexdigest() != entry["md5sum"]:
+                filename.unlink()
+                raise ValueError("Mismatch in MD5 hash")
+
+        return filename
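
The new class can also be exercised directly. A sketch of the expected call, not part of the commit, reusing the URL and checksum of landing_zurich_2019 from __init__.py above:

from traffic.data.datasets.default import Default, Entry

# Illustrative entry built from values listed in the datasets dict
entry = Entry(
    url="https://ndownloader.figshare.com/files/20291079",
    md5sum="c5577f450424fa74ca673ed8a168c67f",
    filename="landing_dataset.parquet",
)

# get_data streams the download into the cache directory (with a progress
# bar when Content-Length is available), checks the MD5 digest, and on
# subsequent calls returns the cached Path without downloading again.
path = Default().get_data(entry)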
