Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @classproperty to fetcher for nested return types #5371

Merged
merged 11 commits into from
Aug 24, 2023
10 changes: 7 additions & 3 deletions openbb_sdk/providers/fred/openbb_fred/models/ameribor_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,13 @@ def value_validate(cls, v): # pylint: disable=E0213
return float("nan")


class FREDAMERIBORFetcher(Fetcher[FREDAMERIBORQueryParams, FREDAMERIBORData]):
class FREDAMERIBORFetcher(
Fetcher[FREDAMERIBORQueryParams, List[Dict[str, List[FREDAMERIBORData]]]]
):
"""FRED AMERIBOR Fetcher."""

data_type = FREDAMERIBORData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDAMERIBORQueryParams:
return FREDAMERIBORQueryParams(**params)
Expand All @@ -74,14 +78,14 @@ def extract_data(
query: FREDAMERIBORQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any
) -> list:
) -> dict:
key = credentials.get("fred_api_key") if credentials else ""
fred_series = AMERIBOR_PARAMETER_TO_FRED_ID[query.parameter]
fred = Fred(key)
data = fred.get_series(fred_series, query.start_date, query.end_date, **kwargs)
return data

@staticmethod
def transform_data(data: list) -> List[Dict[str, List[FREDAMERIBORData]]]:
def transform_data(data: dict) -> List[Dict[str, List[FREDAMERIBORData]]]:
keys = ["date", "value"]
return [FREDAMERIBORData(**{k: x[k] for k in keys}) for x in data]
12 changes: 4 additions & 8 deletions openbb_sdk/providers/fred/openbb_fred/models/cpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,23 @@

from openbb_fred.utils.fred_base import Fred
from openbb_fred.utils.fred_helpers import all_cpi_options
from openbb_provider.abstract.data import Data
from openbb_provider.abstract.fetcher import Fetcher
from openbb_provider.standard_models.cpi import CPIData, CPIQueryParams
from pydantic import Field


class FREDCPIQueryParams(CPIQueryParams):
"""CPI query."""


class FREDCPIData(Data):
class FREDCPIData(CPIData):
"""CPI data."""

country_unit_freq: Optional[List[CPIData]] = Field(
description="CPI data for a country, units, and frequency combination."
)


class FREDCPIFetcher(Fetcher[FREDCPIQueryParams, List[FREDCPIData]]):
class FREDCPIFetcher(Fetcher[FREDCPIQueryParams, List[Dict[str, List[FREDCPIData]]]]):
"""FRED CPI Fetcher."""

data_type = FREDCPIData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDCPIQueryParams:
return FREDCPIQueryParams(**params)
Expand Down
10 changes: 7 additions & 3 deletions openbb_sdk/providers/fred/openbb_fred/models/estr_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,28 @@ def value_validate(cls, v): # pylint: disable=E0213
return float("nan")


class FREDESTRFetcher(Fetcher[FREDESTRQueryParams, FREDESTRData]):
class FREDESTRFetcher(
Fetcher[FREDESTRQueryParams, List[Dict[str, List[FREDESTRData]]]]
):
"""FRED ESTR Fetcher."""

data_type = FREDESTRData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDESTRQueryParams:
return FREDESTRQueryParams(**params)

@staticmethod
def extract_data(
query: FREDESTRQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any
) -> list:
) -> dict:
key = credentials.get("fred_api_key") if credentials else ""
fred_series = ESTR_PARAMETER_TO_ID[query.parameter]
fred = Fred(key)
data = fred.get_series(fred_series, query.start_date, query.end_date, **kwargs)
return data

@staticmethod
def transform_data(data: list) -> List[Dict[str, List[FREDESTRData]]]:
def transform_data(data: dict) -> List[Dict[str, List[FREDESTRData]]]:
keys = ["date", "value"]
return [FREDESTRData(**{k: x[k] for k in keys}) for x in data]
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ class FREDPROJECTIONData(PROJECTIONData):
"""PROJECTION data."""


class FREDPROJECTIONFetcher(Fetcher[FREDPROJECTIONQueryParams, FREDPROJECTIONData]):
class FREDPROJECTIONFetcher(
Fetcher[FREDPROJECTIONQueryParams, List[Dict[str, List[FREDPROJECTIONData]]]]
):
"""FRED PROJECTION Fetcher."""

data_type = FREDPROJECTIONData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDPROJECTIONQueryParams:
return FREDPROJECTIONQueryParams(**params)
Expand Down
8 changes: 5 additions & 3 deletions openbb_sdk/providers/fred/openbb_fred/models/fed_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,26 @@ def value_validate(cls, v): # pylint: disable=E0213
return float("nan")


class FREDFEDFetcher(Fetcher[FREDFEDQueryParams, FREDFEDData]):
class FREDFEDFetcher(Fetcher[FREDFEDQueryParams, List[Dict[str, List[FREDFEDData]]]]):
"""FRED FED Fetcher."""

data_type = FREDFEDData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDFEDQueryParams:
return FREDFEDQueryParams(**params)

@staticmethod
def extract_data(
query: FREDFEDQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any
) -> list:
) -> dict:
key = credentials.get("fred_api_key") if credentials else ""
fred_series = FED_PARAMETER_TO_FRED_ID[query.parameter]
fred = Fred(key)
data = fred.get_series(fred_series, query.start_date, query.end_date, **kwargs)
return data

@staticmethod
def transform_data(data: list) -> List[Dict[str, List[FREDFEDData]]]:
def transform_data(data: dict) -> List[Dict[str, List[FREDFEDData]]]:
keys = ["date", "value"]
return [FREDFEDData(**{k: x[k] for k in keys}) for x in data]
10 changes: 7 additions & 3 deletions openbb_sdk/providers/fred/openbb_fred/models/iorb_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,28 @@ def value_validate(cls, v): # pylint: disable=E0213
return float("nan")


class FREDIORBFetcher(Fetcher[FREDIORBQueryParams, FREDIORBData]):
class FREDIORBFetcher(
Fetcher[FREDIORBQueryParams, List[Dict[str, List[FREDIORBData]]]]
):
"""FRED IORB Fetcher."""

data_type = FREDIORBData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDIORBQueryParams:
return FREDIORBQueryParams(**params)

@staticmethod
def extract_data(
query: FREDIORBQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any
) -> list:
) -> dict:
key = credentials.get("fred_api_key") if credentials else ""
fred_series = "IORB"
fred = Fred(key)
data = fred.get_series(fred_series, query.start_date, query.end_date, **kwargs)
return data

@staticmethod
def transform_data(data: list) -> List[Dict[str, List[FREDIORBData]]]:
def transform_data(data: dict) -> List[Dict[str, List[FREDIORBData]]]:
keys = ["date", "value"]
return [FREDIORBData(**{k: x[k] for k in keys}) for x in data]
10 changes: 7 additions & 3 deletions openbb_sdk/providers/fred/openbb_fred/models/sofr_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,28 @@ def value_validate(cls, v): # pylint: disable=E0213
return float("nan")


class FREDSOFRFetcher(Fetcher[FREDSOFRQueryParams, FREDSOFRData]):
class FREDSOFRFetcher(
Fetcher[FREDSOFRQueryParams, List[Dict[str, List[FREDSOFRData]]]]
):
"""FRED SOFR Fetcher."""

data_type = FREDSOFRData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDSOFRQueryParams:
return FREDSOFRQueryParams(**params)

@staticmethod
def extract_data(
query: FREDSOFRQueryParams, credentials: Optional[Dict[str, str]], **kwargs: Any
) -> list:
) -> dict:
key = credentials.get("fred_api_key") if credentials else ""
fred_series = SOFR_PARAMETER_TO_FRED_ID[query.period]
fred = Fred(key)
data = fred.get_series(fred_series, query.start_date, query.end_date, **kwargs)
return data

@staticmethod
def transform_data(data: list) -> List[Dict[str, List[FREDSOFRData]]]:
def transform_data(data: dict) -> List[Dict[str, List[FREDSOFRData]]]:
keys = ["date", "value"]
return [FREDSOFRData(**{k: x[k] for k in keys}) for x in data]
10 changes: 7 additions & 3 deletions openbb_sdk/providers/fred/openbb_fred/models/sonia_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,13 @@ def value_validate(cls, v): # pylint: disable=E0213
return float("nan")


class FREDSONIAFetcher(Fetcher[FREDSONIAQueryParams, FREDSONIAData]):
class FREDSONIAFetcher(
Fetcher[FREDSONIAQueryParams, List[Dict[str, List[FREDSONIAData]]]]
):
"""FRED SONIA Fetcher."""

data_type = FREDSONIAData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDSONIAQueryParams:
return FREDSONIAQueryParams(**params)
Expand All @@ -63,14 +67,14 @@ def extract_data(
query: FREDSONIAQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any
) -> list:
) -> dict:
key = credentials.get("fred_api_key") if credentials else ""
fred_series = SONIA_PARAMETER_TO_FRED_ID[query.parameter]
fred = Fred(key)
data = fred.get_series(fred_series, query.start_date, query.end_date, **kwargs)
return data

@staticmethod
def transform_data(data: list) -> List[Dict[str, List[FREDSONIAData]]]:
def transform_data(data: dict) -> List[Dict[str, List[FREDSONIAData]]]:
keys = ["date", "value"]
return [FREDSONIAData(**{k: x[k] for k in keys}) for x in data]
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ class FREDYieldCurveData(USYieldCurveData):
"""Fred Yield Curve data."""


class FREDYieldCurveFetcher(Fetcher[FREDYieldCurveQueryParams, FREDYieldCurveData]):
class FREDYieldCurveFetcher(
Fetcher[FREDYieldCurveQueryParams, List[Dict[str, List[FREDYieldCurveData]]]]
):
"""FRED Yield Curve Fetcher."""

data_type = FREDYieldCurveData

@staticmethod
def transform_query(params: Dict[str, Any]) -> FREDYieldCurveQueryParams:
return FREDYieldCurveQueryParams(**params)
Expand Down
4 changes: 2 additions & 2 deletions openbb_sdk/providers/fred/openbb_fred/utils/fred_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

YIELD_CURVE_NOMINAL_RATES = [round(1 / 12, 3), 0.25, 0.5, 1, 2, 3, 5, 7, 10, 20, 30]
YIELD_CURVE_SPOT_RATES = [0.5, 1, 2, 3, 5, 7, 10, 20, 30, 50, 75, 100]
YIELD_CURVE_REAL_RATES = [5, 7, 10, 20, 30]
YIELD_CURVE_PAR_RATES = [2, 5, 10, 30]
YIELD_CURVE_REAL_RATES = [5.0, 7, 10, 20, 30]
YIELD_CURVE_PAR_RATES = [2.0, 5, 10, 30]
YIELD_CURVE_SERIES_NOMINAL = {
"1Month": "DGS1MO",
"3Month": "DGS3MO",
Expand Down
11 changes: 8 additions & 3 deletions openbb_sdk/sdk/core/openbb_core/app/static/package/economy.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,14 @@ def cpi(

CPI
---
country_unit_freq : Optional[List[CPIData]]
CPI data for a country, units, and frequency combination. (provider: fred)
"""
date : Optional[date]
The date of the data.
realtime_start : Optional[date]
Date the data was updated.
realtime_end : Optional[date]
Date the data was updated.
value : Optional[float]
Value of the data."""

inputs = filter_inputs(
provider_choices={
Expand Down
23 changes: 17 additions & 6 deletions openbb_sdk/sdk/provider/openbb_provider/abstract/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@

from typing import Any, Dict, Generic, Optional, TypeVar, get_args, get_origin

from openbb_provider.abstract.data import Data
from openbb_provider.abstract.query_params import QueryParams

Q = TypeVar("Q", bound=QueryParams)
D = TypeVar("D") # Data
D = TypeVar("D", bound=Data)
R = TypeVar("R") # Return, usually List[D], but can be just D for example


class classproperty:
"""Class property decorator."""

def __init__(self, f):
self.f = f

def __get__(self, obj, owner):
return self.f(owner)


class Fetcher(Generic[Q, R]):
"""Abstract class for the fetcher."""

Expand Down Expand Up @@ -40,20 +51,20 @@ def fetch_data(
data = cls.extract_data(query=query, credentials=credentials, **kwargs)
return cls.transform_data(data=data)

@property
def query_params(self) -> Q:
@classproperty
def query_params_type(self) -> Q:
"""Get the type of query."""
# pylint: disable=E1101
return self.__orig_bases__[0].__args__[0] # type: ignore

@property
@classproperty
def return_type(self) -> R:
"""Get the type of return."""
# pylint: disable=E1101
return self.__orig_bases__[0].__args__[1] # type: ignore

@property
def data(self) -> D: # type: ignore
@classproperty
def data_type(self) -> D: # type: ignore
"""Get the type data."""
# pylint: disable=E1101
return self._get_data_type(self.__orig_bases__[0].__args__[1]) # type: ignore
Expand Down
28 changes: 21 additions & 7 deletions openbb_sdk/sdk/provider/openbb_provider/registry_map.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Provider registry map."""

import inspect
import os
from inspect import getfile, isclass
from typing import Any, Dict, List, Literal, Optional, Tuple

from openbb_provider.abstract.data import Data
from openbb_provider.abstract.fetcher import Fetcher
from openbb_provider.abstract.query_params import QueryParams
from openbb_provider.registry import Registry, RegistryLoader

MapType = Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
Expand Down Expand Up @@ -70,10 +72,9 @@ def _get_map(self, registry: Registry) -> Tuple[MapType, MapType]:

for p in registry.providers:
for model_name, fetcher in registry.providers[p].fetcher_dict.items():
f = fetcher()
standard_query, extra_query = self.extract_info(f, "query_params")
standard_data, extra_data = self.extract_info(f, "data")
return_type = self.extract_return_type(f)
standard_query, extra_query = self.extract_info(fetcher, "query_params")
standard_data, extra_data = self.extract_info(fetcher, "data")
return_type = self.extract_return_type(fetcher)

if model_name not in map_:
map_[model_name] = {}
Expand All @@ -97,7 +98,9 @@ def _get_models(self, map_: MapType) -> List[str]:
@staticmethod
def extract_info(fetcher: Fetcher, type_: Literal["query_params", "data"]) -> tuple:
"""Extract info (fields and docstring) from fetcher query params or data."""
super_model = getattr(fetcher, type_)
super_model = getattr(fetcher, f"{type_}_type")

RegistryMap.validate_model(super_model, type_)

skip_classes = {"object", "Representation", "BaseModel", "QueryParams", "Data"}
inheritance_list = [
Expand All @@ -109,7 +112,7 @@ def extract_info(fetcher: Fetcher, type_: Literal["query_params", "data"]) -> tu
found_standard = False

for model in inheritance_list:
model_file_dir = os.path.dirname(inspect.getfile(model))
model_file_dir = os.path.dirname(getfile(model))
model_name = os.path.basename(model_file_dir)

if (model_name == "standard_models") or found_standard:
Expand All @@ -132,3 +135,14 @@ def extract_info(fetcher: Fetcher, type_: Literal["query_params", "data"]) -> tu
def extract_return_type(fetcher: Fetcher):
"""Extract return info from fetcher."""
return getattr(fetcher, "return_type", None)

@staticmethod
def validate_model(model: Any, type_: Literal["query_params", "data"]):
parent_model = QueryParams if type_ == "query_params" else Data
if not isclass(model) or not issubclass(model, parent_model):
model_str = str(model).replace("<", "<'").replace(">", "'>")
raise ValueError(
f"'{model_str}' must be a subclass of '{parent_model.__name__}'.\n"
"If you are returning a nested type, try specifying"
f" `{type_}_type = <'your_{type_}_type'>` in the fetcher."
)