feat(perf): Async requests, use gzip instead of .zip
- Uses [niquests](https://niquests.readthedocs.io/en/latest/index.html) for async
- Also, see notes in `_write_rezip`
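- For reference, the shape of the async pattern (a minimal sketch with hypothetical `fetch_all`/`urls` names; the script's own helpers are `_request_async`/`_request_all_async`):

      import asyncio
      import niquests

      async def fetch_all(urls: list[str]) -> list[bytes]:
          # One shared session; the requests run concurrently
          async with niquests.AsyncSession() as session:
              responses = await asyncio.gather(*(session.get(url) for url in urls))
          return [r.content for r in responses]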
dangotbanned committed Dec 12, 2024
1 parent 2b1be70 commit aede7f6
152 changes: 99 additions & 53 deletions scripts/flights2.py
@@ -1,24 +1,27 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "niquests",
# "polars",
# ]
# ///
from __future__ import annotations

import asyncio
import datetime as dt
import gzip
import io
import logging
import tomllib
import warnings
import zipfile
from collections import defaultdict, deque
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from enum import StrEnum
from functools import cached_property
from pathlib import Path
from typing import IO, TYPE_CHECKING
from urllib import request

import niquests
import polars as pl
from polars import col
from polars import selectors as cs
@@ -100,7 +103,8 @@ class DateTimeFormat(StrEnum):
"On_Time_Reporting_Carrier_On_Time_Performance_1987_present_"
)
ZIP: Literal[".zip"] = ".zip"
PATTERN_ZIP: LiteralString = f"*{REPORTING_PREFIX}*{ZIP}"
GZIP: Literal[".csv.gz"] = ".csv.gz"
PATTERN_GZIP: LiteralString = f"*{REPORTING_PREFIX}*{GZIP}"

COLUMNS_DEFAULT: Sequence[Columns] = (
"date",
@@ -204,27 +208,25 @@ def monthly(self) -> pl.Expr:
return pl.date_range(self.start, self.end, interval="1mo").alias("date")

@cached_property
def file_names(self) -> Sequence[str]:
"""Returns the file names of all sources the input would require."""
def file_stems(self) -> Sequence[str]:
"""Returns the file stems of all sources the input would require."""
date = col("date")
year, month = (date.dt.year().alias("year"), date.dt.month().alias("month"))
# Slightly different when only handling a single range
name_parts = pl.lit(REPORTING_PREFIX), year, pl.lit("_"), month, pl.lit(ZIP)
return tuple(
pl.select(self.monthly)
.lazy()
.select(pl.concat_str(*name_parts).sort_by(date))
.select(_file_stem_source(year, month).sort_by(date))
.collect()
.to_series()
.to_list()
)

def __eq__(self, other: Any, /) -> bool:
"""Two ``DateRange``s are equivalent if they would require the same files."""
return isinstance(other, DateRange) and self.file_names == other.file_names
return isinstance(other, DateRange) and self.file_stems == other.file_stems

def __hash__(self) -> int:
return hash(self.file_names)
return hash(self.file_stems)


class Spec:
@@ -333,21 +335,23 @@ def __init__(self, input_dir: Path, /) -> None:
def add_dependency(self, spec: Spec, /) -> None:
d_range: DateRange = spec.range
if d_range not in self._mapping:
self._frames[d_range] = self._extract(d_range).pipe(_clean_source)
self._frames[d_range] = self._scan(d_range).pipe(_clean_source)
self._mapping[d_range].append(spec)

def iter_tasks(self) -> Iterator[tuple[Spec, pl.LazyFrame]]:
for d_range, frame in self._frames.items():
for spec in self._mapping[d_range]:
yield spec, frame

def _extract(self, d_range: DateRange, /) -> pl.LazyFrame:
"""Combining zipped, monthly data into a single table."""
some: list[bytes] = []
for name in d_range.file_names:
for fp in zipfile.Path(Path(self.input_dir / name)).glob("*.csv"):
some.append(fp.read_bytes())
return _scan_csv(some)
def _scan(self, d_range: DateRange, /) -> pl.LazyFrame:
"""Lazily read all required files."""
# NOTE: files from `2001` have unused columns that break reading losslessly
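# ("utf8-lossy" replaces invalid bytes instead of erroring; the final select keeps only SCAN_SCHEMA's columns)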
return pl.scan_csv(
[self.input_dir / f"{stem}{GZIP}" for stem in d_range.file_stems],
try_parse_dates=True,
schema_overrides=SCAN_SCHEMA,
encoding="utf8-lossy",
).select(SCAN_SCHEMA.names())

def __len__(self) -> int:
return len(self._frames)
@@ -422,48 +426,42 @@ def ranges(self) -> pl.LazyFrame:
return pl.select(pl.concat(spec.range.monthly for spec in self)).lazy()

@property
def _required_names(self) -> set[str]:
def _required_stems(self) -> set[str]:
date = col("date")
name_parts = pl.lit(REPORTING_PREFIX), "year", pl.lit("_"), "month", pl.lit(ZIP)
return set(
self.ranges.select(
date.dt.year().alias("year"), date.dt.month().alias("month")
)
.unique()
.select(pl.concat_str(*name_parts))
.select(_file_stem_source("year", "month"))
.collect()
.to_series()
.to_list()
)

def _download_zip(self, name: str, /) -> Path:
"""Request a single month's data and write."""
fp = self.input_dir / name
msg = f"Requesting {name!r} ..."
logger.debug(msg)
with request.urlopen(f"{ROUTE_ZIP}{name}") as response:
fp.touch()
fp.write_bytes(response.read())
msg = f"Downloaded {name!r}."
logger.debug(msg)
return fp
@property
def _existing_stems(self) -> set[str]:
return {_without_suffixes(fp.name) for fp in self.input_dir.glob(PATTERN_GZIP)}

async def _download_into_gzip(self, names: Iterable[str], /) -> list[Path]:
"""Request, write missing data."""
it = (
_write_rezip_async(self.input_dir, buf)
for buf in await _request_all_async(names)
)
return await asyncio.gather(*it)

def download_sources(self) -> None:
"""Detect and download any missing monthly flights data - which are required by specs."""
logger.info("Detecting required sources ...")
existing = {fp.name for fp in self.input_dir.glob(PATTERN_ZIP)}
missing = self._required_names - existing
if missing:
msg_missing = f"Missing:\n {'\n '.join(sorted(missing))}"
logger.info(msg_missing)
if missing := self._required_stems - self._existing_stems:
msg = f"Missing {len(missing)} sources"
logger.info(msg)
if len(missing) >= 5:
warnings.warn("Downloads may exceed 100MB", stacklevel=2)
logger.warning("Downloads may exceed 100MB")
if len(missing) >= 11:
warnings.warn(
"Total number of rows will exceed 5_000_000", stacklevel=2
)
for name in missing:
self._download_zip(name)
logger.warning("Total number of rows will exceed 5_000_000")
asyncio.run(self._download_into_gzip(missing))
logger.info("Successfully downloaded all missing sources.")
else:
logger.info("Sources already downloaded.")
Expand All @@ -485,14 +483,62 @@ def run(self) -> None:
logger.info("Finished job.")


def _scan_csv(source: PlScanCsv, /) -> pl.LazyFrame:
# NOTE: files from `2001` have unused columns that break reading losslessly
return pl.scan_csv(
source,
try_parse_dates=True,
schema_overrides=SCAN_SCHEMA,
encoding="utf8-lossy",
).select(SCAN_SCHEMA.names())
async def _request_async(session: niquests.AsyncSession, name: str, /) -> io.BytesIO:
name = f"{_without_suffixes(name)}{ZIP}"
msg = f"Requesting {name!r} ..."
logger.info(msg)
response = await session.get(name)
if response.ok and (content := response.content):
buf = io.BytesIO()
buf.write(content)
msg = f"Successful {name!r}"
logger.info(msg)
return buf
msg = f"Failed for {name!r}"
raise NotImplementedError(msg)


async def _request_all_async(names: Iterable[str], /) -> list[io.BytesIO]:
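# One session shared by all requests; it stays open until the gather completes.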
async with niquests.AsyncSession(base_url=ROUTE_ZIP) as session:
    return await asyncio.gather(*(_request_async(session, name) for name in names))


async def _write_rezip_async(input_dir: Path, buf: io.BytesIO, /) -> Path:
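# `asyncio.to_thread` off-loads the blocking re-zip so the event loop stays free.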
return await asyncio.to_thread(_write_rezip, input_dir, buf)


def _write_rezip(input_dir: Path, buf: io.BytesIO, /) -> Path:
"""
Extract the inner csv from a zip file, writing it to a gzipped csv of the same name.

Notes
-----
- ``.read_bytes()`` is the only expensive op here
- The end result (gzip, single file) can be scanned in parallel by ``polars``
- It is also slightly smaller than the zipped directory
"""
zip_csv = next(zipfile.Path(zipfile.ZipFile(buf)).glob("*.csv"))
stem = zip_csv.at.replace("(", "").replace(")", "")
gzipped: Path = (input_dir / stem).with_suffix(".csv.gz")
gzipped.touch()
msg = f"Writing {gzipped.as_posix()!r}"
logger.debug(msg)
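# mtime=0 pins the timestamp stored in the gzip header, so repeated runs produce identical bytes.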
with gzip.GzipFile(gzipped, mode="wb", mtime=0) as f:
f.write(zip_csv.read_bytes())
return gzipped


def _file_stem_source[T: (str, pl.Expr)](year: T, month: T, /) -> pl.Expr:
"""Returns an expression that composes the file stem for a single month."""
return pl.concat_str(pl.lit(REPORTING_PREFIX), year, pl.lit("_"), month)


def _without_suffixes[T: (str, Path)](source: T, /) -> T:
"""Ensure all suffixes (not just the last) are removed."""
if isinstance(source, str):
return source.removesuffix("".join(Path(source).suffixes))
return Path(str(source).removesuffix("".join(source.suffixes)))


def _clean_source(ldf: pl.LazyFrame, /) -> pl.LazyFrame:
@@ -546,7 +592,7 @@ def _transform_temporal(


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
repo_root = Path(__file__).parent.parent
source_toml = repo_root / "_data" / "flights.toml"
temp_out = repo_root / "data" / "_flights"
