Skip to content

Commit

Permalink
rename pl to poll and clean up AirbaseClient.request
Browse files Browse the repository at this point in the history
  • Loading branch information
avaldebe committed Aug 29, 2024
1 parent 8d41ffa commit 992da91
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 93 deletions.
114 changes: 40 additions & 74 deletions airbase/airbase.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import sys
import warnings
from datetime import datetime
from itertools import chain
from itertools import chain, product
from pathlib import Path
from typing import TypedDict

Expand All @@ -19,8 +18,8 @@


class PollutantDict(TypedDict):
pl: str
shortpl: int
poll: str
id: int


class AirbaseClient:
Expand All @@ -30,7 +29,7 @@ def __init__(self) -> None:
:example:
>>> client = AirbaseClient()
>>> r = client.request(["NL", "DE"], pl=["O3", "NO2"])
>>> r = client.request(["NL", "DE"], poll=["O3", "NO2"])
>>> r.download_to_directory("data/raw")
Generating CSV download links...
100%|██████████| 4/4 [00:09<00:00, 2.64s/it]
Expand All @@ -50,8 +49,7 @@ def __init__(self) -> None:
def request(
self,
country: str | list[str] | None = None,
pl: str | list[str] | None = None,
shortpl: str | list[str] | None = None,
poll: str | list[str] | None = None,
year_from: str = "2013",
year_to: str = CURRENT_YEAR,
source: str = "All",
Expand All @@ -62,9 +60,9 @@ def request(
"""
Initialize an AirbaseRequest for a query.
Pollutants can be specified either by name (`pl`) or by code
(`shortpl`). If no pollutants are specified, data for all
available pollutants will be requested. If a pollutant is not
Pollutants can be specified by name/notation (`poll`).
If no pollutants are specified, data for all
available pollutants will be requested. If a poll is not
available for a country, then we simply do not try to download
those CSVs.
Expand All @@ -79,13 +77,8 @@ def request(
country. Will raise ValueError if a country is not available
on the server. If None, data for all countries will be
requested. See `self.all_countries`.
:param pl: (optional) The pollutant(s) to request data
:param poll: (optional) The pollutant(s) to request data
for. Must be one of the pollutants in `self.all_pollutants`.
Cannot be used in conjunction with `shortpl`.
:param shortpl: (optional). The pollutant code(s) to
request data for. Will be applied to each country requested.
Cannot be used in conjunction with `pl`.
Deprecated, will be removed on v1.
:param year_from: (optional) The first year of data. Can
not be earlier than 2013. Default 2013.
:param year_to: (optional) The last year of data. Can not be
Expand All @@ -108,7 +101,7 @@ def request(
:example:
>>> client = AirbaseClient()
>>> r = client.request(["NL", "DE"], pl=["O3", "NO2"])
>>> r = client.request(["NL", "DE"], poll=["O3", "NO2"])
>>> r.download_to_directory("data/raw")
Generating CSV download links...
100%|██████████| 4/4 [00:09<00:00, 2.64s/it]
Expand All @@ -123,35 +116,27 @@ def request(
country = self.countries
else:
country = string_safe_list(country)
self._validate_country(country)

if shortpl is not None:
warnings.warn(
"the shortpl option has been deprecated and will be removed on v1. "
"Use client.request([client._pollutants_ids[p] for p in shortpl], ...) instead.",
DeprecationWarning,
stacklevel=2,
)
unknown = sorted(set(country) - set(self.countries))
if unknown:
raise ValueError(
f"Unknown country code(s) {', '.join(unknown)}."
)

if pl is not None and shortpl is not None:
raise ValueError("You cannot specify both 'pl' and 'shortpl'")

# construct shortpl form pl if applicable
if pl is not None:
pl_list = string_safe_list(pl)
try:
shortpl = list(
map(
str,
chain.from_iterable(
self._pollutants_ids[p] for p in pl_list
),
)
# construct shortpl form poll
shortpl: list[int] | None
try:
if poll is None:
shortpl = None
elif isinstance(poll, str):
shortpl = sorted(self._pollutants_ids[poll])
else:
shortpl = sorted(
chain.from_iterable(self._pollutants_ids[p] for p in poll)
)
except KeyError as e:
raise ValueError(
f"'{e.args[0]}' is not a valid pollutant name"
) from e
except KeyError as e:
raise ValueError(
f"'{e.args[0]}' is not a valid pollutant name"
) from e

return AirbaseRequest(
country,
Expand All @@ -168,21 +153,21 @@ def search_pollutant(
self, query: str, limit: int | None = None
) -> list[PollutantDict]:
"""
Search for a pollutant's `shortpl` number based on its name.
Search for a pollutant's `id` number based on its name.
:param query: The pollutant to search for.
:param limit: (optional) Max number of results.
:return: The best pollutant matches. Pollutants
are dicts with keys "pl" and "shortpl".
are dicts with keys "poll" and "id".
:example:
>>> AirbaseClient().search_pollutant("o3", limit=2)
>>> [{"pl": "O3", "shortpl": 7}, {"pl": "NO3", "shortpl": 46}]
>>> [{"poll": "O3", "id": 7}, {"poll": "NO3", "id": 46}]
"""
results = DB.search_pollutant(query, limit=limit)
return [dict(pl=poll.notation, shortpl=poll.id) for poll in results]
return [dict(poll=poll.notation, id=poll.id) for poll in results]

@staticmethod
def download_metadata(filepath: str | Path, verbose: bool = True) -> None:
Expand All @@ -196,28 +181,12 @@ def download_metadata(filepath: str | Path, verbose: bool = True) -> None:
"""
AirbaseRequest(verbose=verbose).download_metadata(filepath)

def _validate_country(self, country: str | list[str]) -> None:
"""
Ensure that a country or list of countries exists on the server.
Must first download the country list using `.connect()`. Raises
value error if a country does not exist.
:param country: The 2-letter country code to validate.
"""
country_list = string_safe_list(country)
for c in country_list:
if c not in self.countries:
raise ValueError(
f"'{c}' is not an available 2-letter country code."
)


class AirbaseRequest:
def __init__(
self,
country: str | list[str] | None = None,
shortpl: str | list[str] | None = None,
shortpl: int | list[int] | None = None,
year_from: str = "2013",
year_to: str = CURRENT_YEAR,
source: str = "All",
Expand Down Expand Up @@ -258,23 +227,20 @@ def __init__(
download links from the Airbase server at object
initialization. Default False.
"""
self.country = country
self.shortpl = shortpl
self.counties = string_safe_list(country)
self.pollutants = string_safe_list(shortpl)
self.year_from = year_from
self.year_to = year_to
self.source = source
self.update_date = update_date
self.verbose = verbose

self._country_list = string_safe_list(country)
self._shortpl_list = string_safe_list(shortpl)
self._download_links = []

for c in self._country_list:
for p in self._shortpl_list:
self._download_links.append(
link_list_url(c, p, year_from, year_to, source, update_date)
)
for c, p in product(self.counties, self.pollutants):
self._download_links.append(
link_list_url(c, p, year_from, year_to, source, update_date)
)

self._csv_links: list[str] = []

Expand Down
11 changes: 8 additions & 3 deletions airbase/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@ def string_safe_list(obj: str | Iterable[str]) -> list[str]: # pragma: no cover
...


def string_safe_list(obj: str | Iterable[str] | None) -> list[str] | list[None]:
@overload
def string_safe_list(obj: int | Iterable[int]) -> list[int]: # pragma: no cover
...


def string_safe_list(obj):
"""
Turn an (iterable) object into a list. If it is a string or not
iterable, put the whole object into a list of length 1.
:param obj:
:return list:
"""
if isinstance(obj, str):
if isinstance(obj, (str, int)):
return [obj]
if obj is None:
return [obj]
Expand All @@ -40,7 +45,7 @@ def string_safe_list(obj: str | Iterable[str] | None) -> list[str] | list[None]:

def link_list_url(
country: str | None,
shortpl: str | None = None,
shortpl: int | None = None,
year_from: str = "2013",
year_to: str = CURRENT_YEAR,
source: str = "All",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_airbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def client():
)
def test_download_to_directory(client: airbase.AirbaseClient, tmp_path: Path):
r = client.request(
country=["AD", "BE"], pl="CO", year_from="2017", year_to="2017"
country=["AD", "BE"], poll="CO", year_from="2017", year_to="2017"
)

r.download_to_directory(dir=str(tmp_path), skip_existing=True)
Expand All @@ -29,7 +29,7 @@ def test_download_to_directory(client: airbase.AirbaseClient, tmp_path: Path):
)
def test_download_to_file(client: airbase.AirbaseClient, tmp_path: Path):
r = client.request(
country="CY", pl=["As", "NO2"], year_from="2014", year_to="2014"
country="CY", poll=["As", "NO2"], year_from="2014", year_to="2014"
)

path = tmp_path / "raw.csv"
Expand Down
22 changes: 8 additions & 14 deletions tests/test_airbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,26 @@ def test_request_raises_bad_year(self, client: airbase.AirbaseClient):
client.request(year_to="9999")

def test_request_pl(self, client: airbase.AirbaseClient):
r = client.request(pl="NO")
assert r.shortpl is not None
assert len(r.shortpl) == 1
r = client.request(poll="NO")
assert r.pollutants == [38]

r = client.request(pl=["NO", "NO3"])
assert r.shortpl is not None
assert len(r.shortpl) == 2
r = client.request(poll=["NO", "NO3"])
assert r.pollutants == [38, 46]

with pytest.raises(ValueError):
r = client.request(pl=["NO", "NO3", "Not a pl"])
r = client.request(poll=["NO", "NO3", "Not a pl"])

def test_request_response_generated(self, client: airbase.AirbaseClient):
r = client.request()
assert isinstance(r, airbase.AirbaseRequest)

def test_request_not_pl_and_shortpl(self, client: airbase.AirbaseClient):
with pytest.raises(ValueError), pytest.warns(DeprecationWarning):
client.request(pl="O3", shortpl="123")

def test_search_pl_exact(self, client: airbase.AirbaseClient):
result = client.search_pollutant("NO3")
assert result[0]["pl"] == "NO3"
assert result[0]["poll"] == "NO3"

def test_search_pl_shortest_first(self, client: airbase.AirbaseClient):
result = client.search_pollutant("N")
names: list[str] = [r["pl"] for r in result]
names: list[str] = [r["poll"] for r in result]
assert len(names[0]) <= len(names[1])
assert len(names[0]) <= len(names[-1])

Expand All @@ -90,7 +84,7 @@ def test_search_pl_no_result(self, client: airbase.AirbaseClient):

def test_search_pl_case_insensitive(self, client: airbase.AirbaseClient):
result = client.search_pollutant("no3")
assert result[0]["pl"] == "NO3"
assert result[0]["poll"] == "NO3"


@pytest.mark.usefixtures("all_responses")
Expand Down

0 comments on commit 992da91

Please sign in to comment.