Skip to content

Commit

Permalink
refactor: type hints fixes and mypy in tox (#1269)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunato authored Jul 23, 2024
1 parent feef2bf commit 6055d7b
Show file tree
Hide file tree
Showing 26 changed files with 146 additions and 82 deletions.
10 changes: 9 additions & 1 deletion eodag/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,14 @@ class Sort(TypedDict):
sort_order_mapping: Dict[Literal["ascending", "descending"], str]
max_sort_params: Annotated[int, Gt(0)]

class DiscoverMetadata(TypedDict):
"""Configuration for metadata discovery"""

auto_discovery: bool
metadata_pattern: str
search_param: str
metadata_path: str

class OrderOnResponse(TypedDict):
"""Configuration for order on-response during download"""

Expand Down Expand Up @@ -319,7 +327,7 @@ class OrderStatus(TypedDict):
pagination: PluginConfig.Pagination
sort: PluginConfig.Sort
query_params_key: str
discover_metadata: Dict[str, Union[str, bool]]
discover_metadata: PluginConfig.DiscoverMetadata
discover_product_types: Dict[str, Any]
discover_queryables: Dict[str, Any]
metadata_mapping: Dict[str, Union[str, List[str]]]
Expand Down
7 changes: 4 additions & 3 deletions eodag/plugins/authentication/openid_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def _constant_or_xpath_extracted(
if not match:
return value
value_from_xpath = form_element.xpath(
self.CONFIG_XPATH_REGEX.match(value).groupdict("xpath_value")["xpath_value"]
match.groupdict("xpath_value")["xpath_value"]
)
if len(value_from_xpath) == 1:
return value_from_xpath[0]
Expand Down Expand Up @@ -512,9 +512,10 @@ def __init__(self, token: str, where: str, key: Optional[str] = None) -> None:
def __call__(self, request: PreparedRequest) -> PreparedRequest:
"""Perform the actual authentication"""
if self.where == "qs":
parts = urlparse(request.url)
parts = urlparse(str(request.url))
query_dict = parse_qs(parts.query)
query_dict.update({self.key: self.token})
if self.key is not None:
query_dict.update({self.key: [self.token]})
url_without_args = parts._replace(query="").geturl()

request.prepare_url(url_without_args, query_dict)
Expand Down
3 changes: 2 additions & 1 deletion eodag/plugins/authentication/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
if self.where == "qs":
parts = urlparse(str(request.url))
qs = parse_qs(parts.query)
qs[self.qs_key] = self.token # type: ignore
if self.qs_key is not None:
qs[self.qs_key] = [self.token]
request.url = urlunparse(
(
parts.scheme,
Expand Down
2 changes: 1 addition & 1 deletion eodag/plugins/download/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def _stream_download(
build_safe: bool,
progress_callback: ProgressCallback,
assets_values: List[Dict[str, Any]],
) -> Iterator[Tuple[str, datetime, int, Any, Iterator[Any]]]:
) -> Iterator[Any]:
"""Yield product data chunks"""

chunk_size = 4096 * 1024
Expand Down
2 changes: 1 addition & 1 deletion eodag/plugins/download/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def download_and_retry(*args: Any, **kwargs: Unpack[DownloadConf]) -> T:
not_available_info = str(e)

if datetime_now >= product.next_try and datetime_now < stop_time:
wait_seconds = (
wait_seconds: Union[float, int] = (
datetime_now - product.next_try + timedelta(minutes=wait)
).seconds
retry_count += 1
Expand Down
14 changes: 9 additions & 5 deletions eodag/plugins/download/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
Iterator,
List,
Optional,
Tuple,
TypedDict,
Union,
cast,
Expand Down Expand Up @@ -1011,8 +1010,11 @@ def _stream_download(
"content-disposition"
] = f"attachment; filename={filename}"
content_type = product.headers.get("Content-Type")
if filename and not content_type:
product.headers["Content-Type"] = guess_file_type(filename)
guessed_content_type = (
guess_file_type(filename) if filename and not content_type else None
)
if guessed_content_type is not None:
product.headers["Content-Type"] = guessed_content_type

progress_callback.reset(total=stream_size)
for chunk in self.stream.iter_content(chunk_size=64 * 1024):
Expand All @@ -1027,7 +1029,7 @@ def _stream_download_assets(
progress_callback: Optional[ProgressCallback] = None,
assets_values: List[Asset] = [],
**kwargs: Unpack[DownloadConf],
) -> Iterator[Tuple[str, datetime, int, Any, Iterator[Any]]]:
) -> Iterator[Any]:
if progress_callback is None:
logger.info("Progress bar unavailable, please call product.download()")
progress_callback = ProgressCallback(disable=True)
Expand Down Expand Up @@ -1201,7 +1203,9 @@ def _download_assets(
# start reading chunks to set asset.rel_path
first_chunks_tuple = next(chunks_tuples)
chunks = chain(iter([first_chunks_tuple]), chunks_tuples)
chunks_tuples = [(assets_values[0].rel_path, None, None, None, chunks)]
chunks_tuples = iter(
[(assets_values[0].rel_path, None, None, None, chunks)]
)

for chunk_tuple in chunks_tuples:
asset_path = chunk_tuple[0]
Expand Down
2 changes: 2 additions & 0 deletions eodag/plugins/download/s3rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def download_request(
bucket_name, prefix = get_bucket_name_and_prefix(
url=product.location, bucket_path_level=self.config.bucket_path_level
)
if prefix is None:
raise DownloadError(f"Could not extract prefix from {product.location}")

if (
bucket_name is None
Expand Down
6 changes: 5 additions & 1 deletion eodag/plugins/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ def __init__(self, providers_config: Dict[str, ProviderConfig]) -> None:
"Check that the plugin module (%s) is importable",
entry_point.module_name,
)
if entry_point.dist and entry_point.dist.key != "eodag":
if (
entry_point.dist
and entry_point.dist.key != "eodag"
and entry_point.dist.location is not None
):
# use plugin providers if any
plugin_providers_config_path = [
str(x)
Expand Down
4 changes: 2 additions & 2 deletions eodag/plugins/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class PreparedSearch:
"""An object collecting needed information for search."""

product_type: Optional[str] = None
page: int = DEFAULT_PAGE
items_per_page: int = DEFAULT_ITEMS_PER_PAGE
page: Optional[int] = DEFAULT_PAGE
items_per_page: Optional[int] = DEFAULT_ITEMS_PER_PAGE
auth: Optional[Union[AuthBase, Dict[str, str]]] = None
auth_plugin: Optional[Authentication] = None
count: bool = True
Expand Down
2 changes: 1 addition & 1 deletion eodag/plugins/search/cop_marine.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def query(
items_per_page = prep.items_per_page

# only return 1 page if pagination is disabled
if page > 1 and items_per_page <= 0:
if page is None or items_per_page is None or page > 1 and items_per_page <= 0:
return ([], 0) if prep.count else ([], None)

product_type = kwargs.get("productType", prep.product_type)
Expand Down
53 changes: 38 additions & 15 deletions eodag/plugins/search/qssearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Expand Down Expand Up @@ -203,7 +204,10 @@ class QueryStringSearch(Search):
:type config: str
"""

extract_properties = {"xml": properties_from_xml, "json": properties_from_json}
extract_properties: Dict[str, Callable[..., Dict[str, Any]]] = {
"xml": properties_from_xml,
"json": properties_from_json,
}

def __init__(self, provider: str, config: PluginConfig) -> None:
super(QueryStringSearch, self).__init__(provider, config)
Expand Down Expand Up @@ -631,7 +635,7 @@ def discover_queryables(
)
)

field_definitions = dict()
field_definitions: Dict[str, Any] = dict()
for json_param, json_mtd in constraint_params.items():
param = (
get_queryable_from_provider(
Expand Down Expand Up @@ -786,7 +790,7 @@ def collect_search_urls(
prep.need_count = True
prep.total_items_nb = None

for collection in self.get_collections(prep, **kwargs):
for collection in self.get_collections(prep, **kwargs) or (None,):
# skip empty collection if one is required in api_endpoint
if "{collection}" in self.config.api_endpoint and not collection:
continue
Expand Down Expand Up @@ -1059,20 +1063,19 @@ def count_hits(self, count_url: str, result_type: Optional[str] = "json") -> int
total_results = int(count_results)
return total_results

def get_collections(
self, prep: PreparedSearch, **kwargs: Any
) -> Tuple[Set[Dict[str, Any]], ...]:
def get_collections(self, prep: PreparedSearch, **kwargs: Any) -> Tuple[str, ...]:
"""Get the collection to which the product belongs"""
# See https://earth.esa.int/web/sentinel/missions/sentinel-2/news/-
# /asset_publisher/Ac0d/content/change-of
# -format-for-new-sentinel-2-level-1c-products-starting-on-6-december
product_type: Optional[str] = kwargs.get("productType")
collection: Optional[str] = None
if product_type is None and (
not hasattr(prep, "product_type_def_params")
or not prep.product_type_def_params
):
collections: Set[Dict[str, Any]] = set()
collection: Optional[str] = getattr(self.config, "collection", None)
collections: Set[str] = set()
collection = getattr(self.config, "collection", None)
if collection is None:
try:
for product_type, product_config in self.config.products.items():
Expand All @@ -1090,18 +1093,26 @@ def get_collections(
collections.add(collection)
return tuple(collections)

collection: Optional[str] = getattr(self.config, "collection", None)
collection = getattr(self.config, "collection", None)
if collection is None:
collection = (
prep.product_type_def_params.get("collection", None) or product_type
)
return (collection,) if not isinstance(collection, list) else tuple(collection)

if collection is None:
return ()
elif not isinstance(collection, list):
return (collection,)
else:
return tuple(collection)

def _request(
self,
prep: PreparedSearch,
) -> Response:
url = prep.url
if url is None:
raise ValidationError("Cannot request empty URL")
info_message = prep.info_message
exception_message = prep.exception_message
try:
Expand Down Expand Up @@ -1347,8 +1358,11 @@ def query(
"specific_qssearch"
].get("merge_responses", None)

self.count_hits = lambda *x, **y: 1
self._request = super(PostJsonSearch, self)._request
def count_hits(self, *x, **y):
return 1

def _request(self, *x, **y):
return super(PostJsonSearch, self)._request(*x, **y)

try:
eo_products, total_items = super(PostJsonSearch, self).query(
Expand Down Expand Up @@ -1449,7 +1463,7 @@ def collect_search_urls(
auth_conf_dict = getattr(prep.auth_plugin.config, "credentials", {})
else:
auth_conf_dict = {}
for collection in self.get_collections(prep, **kwargs):
for collection in self.get_collections(prep, **kwargs) or (None,):
try:
search_endpoint: str = self.config.api_endpoint.rstrip("/").format(
**dict(collection=collection, **auth_conf_dict)
Expand All @@ -1472,7 +1486,11 @@ def collect_search_urls(
if getattr(self.config, "merge_responses", False):
total_results = _total_results or 0
else:
total_results += _total_results or 0
total_results = (
(_total_results or 0)
if total_results is None
else total_results + (_total_results or 0)
)
if "next_page_query_obj" in self.config.pagination and isinstance(
self.config.pagination["next_page_query_obj"], str
):
Expand All @@ -1497,6 +1515,8 @@ def _request(
prep: PreparedSearch,
) -> Response:
url = prep.url
if url is None:
raise ValidationError("Cannot request empty URL")
info_message = prep.info_message
exception_message = prep.exception_message
timeout = getattr(self.config, "timeout", HTTP_REQ_TIMEOUT)
Expand All @@ -1515,7 +1535,10 @@ def _request(
kwargs["auth"] = prep.auth

# perform the request using the next page arguments if they are defined
if getattr(self, "next_page_query_obj", None):
if (
hasattr(self, "next_page_query_obj")
and self.next_page_query_obj is not None
):
prep.query_params = self.next_page_query_obj
if info_message:
logger.info(info_message)
Expand Down
10 changes: 5 additions & 5 deletions eodag/rest/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
)

if TYPE_CHECKING:
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from fastapi import Request
from requests.auth import AuthBase
Expand Down Expand Up @@ -215,7 +215,7 @@ def search_stac_items(
}

search_results = eodag_api.search(count=True, **criteria)
total = search_results.number_matched
total = search_results.number_matched or 0
if search_request.crunch:
search_results = crunch_products(
search_results, search_request.crunch, **criteria
Expand Down Expand Up @@ -588,9 +588,9 @@ def get_stac_extension_oseo(url: str) -> Dict[str, str]:
:rtype: dict
"""

apply_method: Callable[[str, str], str] = lambda _, x: str(x).replace(
"$.product.", "$."
)
def apply_method(_: str, x: str) -> str:
return str(x).replace("$.product.", "$.")

item_mapping = dict_items_recursive_apply(stac_config["item"], apply_method)

# all properties as string type by default
Expand Down
24 changes: 13 additions & 11 deletions eodag/rest/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import dateutil.parser
Expand Down Expand Up @@ -149,16 +149,18 @@ def update_data(self, data: Dict[str, Any]) -> None:
):
for i, bbox in enumerate(self.data["extent"]["spatial"]["bbox"]):
self.data["extent"]["spatial"]["bbox"][i] = [float(x) for x in bbox]
# "None" values to None
apply_method: Callable[[str, str], Optional[str]] = lambda _, v: (
None if v == "None" else v
)
self.data = dict_items_recursive_apply(self.data, apply_method)
# ids and titles as str
apply_method: Callable[[str, str], Optional[str]] = lambda k, v: (
str(v) if k in ["title", "id"] else v
)
self.data = dict_items_recursive_apply(self.data, apply_method)

def apply_method_none(_: str, v: str) -> Optional[str]:
""" "None" values to None"""
return None if v == "None" else v

self.data = dict_items_recursive_apply(self.data, apply_method_none)

def apply_method_ids(k, v):
"""ids and titles as str"""
return str(v) if k in ["title", "id"] else v

self.data = dict_items_recursive_apply(self.data, apply_method_ids)

# empty stac_extensions: "" to []
if not self.data.get("stac_extensions", True):
Expand Down
5 changes: 1 addition & 4 deletions eodag/rest/types/eodag_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@
from eodag.utils import DEFAULT_ITEMS_PER_PAGE

if TYPE_CHECKING:
try:
from typing import Self
except ImportError:
from _typeshed import Self
from typing_extensions import Self

Geometry = Union[
Dict[str, Any],
Expand Down
Loading

0 comments on commit 6055d7b

Please sign in to comment.