Skip to content

Commit

Permalink
[Bug fix] - Handle multiple items with arbitrary type (#6171)
Browse files Browse the repository at this point in the history
* handle multiple items with arbitrary type

* minor fix

* ruff

* inequality

* integration tests

* test

* pylint

* fix tests

* fix category

* ruff
  • Loading branch information
montezdesousa authored Mar 6, 2024
1 parent 76556df commit a8122e9
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 46 deletions.
17 changes: 14 additions & 3 deletions openbb_platform/core/openbb_core/app/static/package_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@ def format_params(
type_ = MethodDefinition.get_type(field)
default = MethodDefinition.get_default(field)
extra = MethodDefinition.get_extra(field)
new_type = MethodDefinition.get_expanded_type(field_name, extra)
new_type = MethodDefinition.get_expanded_type(
field_name, extra, type_
)
updated_type = type_ if new_type is ... else Union[type_, new_type]

formatted[field_name] = Parameter(
Expand Down Expand Up @@ -782,10 +784,19 @@ def build_command_method_body(path: str, func: Callable):
return code

@classmethod
def get_expanded_type(cls, field_name: str, extra: Optional[dict] = None) -> object:
def get_expanded_type(
cls,
field_name: str,
extra: Optional[dict] = None,
original_type: Optional[type] = None,
) -> object:
"""Expand the original field type."""
if extra and "multiple_items_allowed" in extra:
return List[str]
if original_type is None:
raise ValueError(
"multiple_items_allowed requires the original type to be specified."
)
return List[original_type]
return cls.TYPE_EXPANSION.get(field_name, ...)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ def wrapper(*f_args, **f_kwargs):
).with_traceback(tb) from None

# If the error is not a ValidationError, then it is a generic exception
error_type = getattr(e, "original", e).__class__.__name__
raise OpenBBError(
f"\nType -> {e.original.__class__.__name__}\n\nDetail -> {str(e)}"
f"\nType -> {error_type}\n\nDetail -> {str(e)}"
).with_traceback(tb) from None

return wrapper
4 changes: 3 additions & 1 deletion openbb_platform/core/openbb_core/app/static/utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def filter_inputs(
if field in kwargs.get(p, {}):
current = kwargs[p][field]
new = (
",".join(current) if isinstance(current, list) else current
",".join(map(str, current))
if isinstance(current, list)
else current
)

if provider and provider not in props[PROPERTY]:
Expand Down
26 changes: 8 additions & 18 deletions openbb_platform/core/openbb_core/provider/standard_models/spot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from datetime import (
date as dateType,
)
from typing import List, Literal, Optional
from typing import Optional, Union

from pydantic import Field, field_validator
from pydantic import Field

from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
Expand All @@ -26,25 +26,15 @@ class SpotRateQueryParams(QueryParams):
default=None,
description=QUERY_DESCRIPTIONS.get("end_date", ""),
)
maturity: List[float] = Field(
default=[10.0], description="The maturities in years."
maturity: Union[float, str] = Field(
default=10.0, description="Maturities in years."
)
category: List[Literal["par_yield", "spot_rate"]] = Field(
default=["spot_rate"],
description="The category.",
category: str = Field(
default="spot_rate",
description="Rate category. Options: spot_rate, par_yield.",
choices=["par_yield", "spot_rate"],
)

@field_validator("maturity")
@classmethod
def maturity_validate(cls, v):
"""Validate maturity."""
for i in v:
if not isinstance(i, float):
raise ValueError("`maturity` must be a float")
if not 1 <= i <= 100:
raise ValueError("`maturity` must be between 1 and 100")
return v


class SpotRateData(Data):
"""Spot Rate Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,26 @@ def test_fixedincome_corporate_commercial_paper(params, headers):
"start_date": "2023-01-01",
"end_date": "2023-06-06",
"maturity": [10.0],
"category": ["spot_rate"],
"category": "spot_rate",
"provider": "fred",
}
)
),
(
{
"start_date": None,
"end_date": None,
"maturity": 5.5,
"category": ["spot_rate"],
}
),
(
{
"start_date": None,
"end_date": None,
"maturity": "1,5.5,10",
"category": "spot_rate,par_yield",
}
),
],
)
@pytest.mark.integration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,26 @@ def test_fixedincome_corporate_commercial_paper(params, obb):
"start_date": "2023-01-01",
"end_date": "2023-06-06",
"maturity": [10.0],
"category": ["spot_rate"],
"category": "spot_rate",
"provider": "fred",
}
)
),
(
{
"start_date": None,
"end_date": None,
"maturity": 5.5,
"category": ["spot_rate"],
}
),
(
{
"start_date": None,
"end_date": None,
"maturity": "1,5.5,10",
"category": "spot_rate,par_yield",
}
),
],
)
@pytest.mark.integration
Expand Down
6 changes: 3 additions & 3 deletions openbb_platform/extensions/tests/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ def list_openbb_extensions() -> Tuple[Set[str], Set[str], Set[str]]:
obbject_extensions = set()
entry_points_dict = entry_points()

for entry_point in entry_points_dict["openbb_core_extension"]:
for entry_point in entry_points_dict.get("openbb_core_extension", []):
core_extensions.add(f"{entry_point.name}")

for entry_point in entry_points_dict["openbb_provider_extension"]:
for entry_point in entry_points_dict.get("openbb_provider_extension", []):
provider_extensions.add(f"{entry_point.name}")

for entry_point in entry_points_dict["openbb_obbject_extension"]:
for entry_point in entry_points_dict.get("openbb_obbject_extension", []):
obbject_extensions.add(f"{entry_point.name}")

return core_extensions, provider_extensions, obbject_extensions
4 changes: 2 additions & 2 deletions openbb_platform/openbb/package/equity_estimates.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def historical(
def price_target(
self,
symbol: Annotated[
Union[str, None, List[str]],
Union[str, None, List[Optional[str]]],
OpenBBCustomParameter(
description="Symbol to get data for. Multiple items allowed for provider(s): benzinga."
),
Expand All @@ -417,7 +417,7 @@ def price_target(
Parameters
----------
symbol : Union[str, None, List[str]]
symbol : Union[str, None, List[Optional[str]]]
Symbol to get data for. Multiple items allowed for provider(s): benzinga.
limit : int
The number of data entries to return.
Expand Down
30 changes: 21 additions & 9 deletions openbb_platform/openbb/package/fixedincome_corporate.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,12 +419,17 @@ def spot_rates(
),
] = None,
maturity: Annotated[
List[float], OpenBBCustomParameter(description="The maturities in years.")
] = [10.0],
Union[float, str, List[Union[float, str]]],
OpenBBCustomParameter(
description="Maturities in years. Multiple items allowed for provider(s): fred."
),
] = 10.0,
category: Annotated[
List[Literal["par_yield", "spot_rate"]],
OpenBBCustomParameter(description="The category."),
] = ["spot_rate"],
Union[str, List[str]],
OpenBBCustomParameter(
description="Rate category. Options: spot_rate, par_yield. Multiple items allowed for provider(s): fred."
),
] = "spot_rate",
provider: Optional[Literal["fred"]] = None,
**kwargs
) -> OBBject:
Expand All @@ -442,10 +447,10 @@ def spot_rates(
Start date of the data, in YYYY-MM-DD format.
end_date : Union[datetime.date, None, str]
End date of the data, in YYYY-MM-DD format.
maturity : List[float]
The maturities in years.
category : List[Literal['par_yield', 'spot_rate']]
The category.
maturity : Union[float, str, List[Union[float, str]]]
Maturities in years. Multiple items allowed for provider(s): fred.
category : Union[str, List[str]]
Rate category. Options: spot_rate, par_yield. Multiple items allowed for provider(s): fred.
provider : Optional[Literal['fred']]
The provider to use for the query, by default None.
If None, the provider specified in defaults is selected or 'fred' if there is
Expand Down Expand Up @@ -495,5 +500,12 @@ def spot_rates(
"category": category,
},
extra_params=kwargs,
extra_info={
"maturity": {"multiple_items_allowed": ["fred"]},
"category": {
"choices": ["par_yield", "spot_rate"],
"multiple_items_allowed": ["fred"],
},
},
)
)
4 changes: 2 additions & 2 deletions openbb_platform/openbb/package/news.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __repr__(self) -> str:
def company(
self,
symbol: Annotated[
Union[str, None, List[str]],
Union[str, None, List[Optional[str]]],
OpenBBCustomParameter(
description="Symbol to get data for. This endpoint will accept multiple symbols separated by commas. Multiple items allowed for provider(s): benzinga, fmp, intrinio, polygon, tiingo, yfinance."
),
Expand Down Expand Up @@ -56,7 +56,7 @@ def company(
Parameters
----------
symbol : Union[str, None, List[str]]
symbol : Union[str, None, List[Optional[str]]]
Symbol to get data for. This endpoint will accept multiple symbols separated by commas. Multiple items allowed for provider(s): benzinga, fmp, intrinio, polygon, tiingo, yfinance.
start_date : Union[datetime.date, None, str]
Start date of the data, in YYYY-MM-DD format.
Expand Down
19 changes: 16 additions & 3 deletions openbb_platform/providers/fred/openbb_fred/models/spot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@
SpotRateQueryParams,
)
from openbb_fred.utils.fred_base import Fred
from openbb_fred.utils.fred_helpers import get_spot_series_id
from openbb_fred.utils.fred_helpers import comma_to_float_list, get_spot_series_id
from pydantic import field_validator


class FREDSpotRateQueryParams(SpotRateQueryParams):
"""FRED Spot Rate Query."""

__json_schema_extra__ = {
"maturity": ["multiple_items_allowed"],
"category": ["multiple_items_allowed"],
}


class FREDSpotRateData(SpotRateData):
"""FRED Spot Rate Data."""
Expand Down Expand Up @@ -56,9 +61,17 @@ def extract_data(
key = credentials.get("fred_api_key") if credentials else ""
fred = Fred(key)

maturity = (
comma_to_float_list(query.maturity)
if isinstance(query.maturity, str)
else [query.maturity]
)
if any(1 > m > 100 for m in maturity):
raise ValueError("Maturity must be between 1 and 100")

series = get_spot_series_id(
maturity=query.maturity,
category=query.category,
maturity=maturity,
category=query.category.split(","),
)

data = []
Expand Down
11 changes: 11 additions & 0 deletions openbb_platform/providers/fred/openbb_fred/utils/fred_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@
}


def comma_to_float_list(v: str) -> List[float]:
"""Convert comma-separated string to list of floats."""
try:
return [float(m) for m in v.split(",")]
except ValueError as e:
raise ValueError(
"'maturity' must be a float or a comma-separated string of floats"
) from e


def all_cpi_options(harmonized: bool = False) -> List[dict]:
"""Get all CPI options."""
data = []
Expand Down Expand Up @@ -136,6 +146,7 @@ def get_ice_bofa_series_id(
units = "index" if type_ == "total_return" else "percent"

for s in series:
# pylint: disable=too-many-boolean-expressions
if (
s["Type"] == type_
and s["Units"] == units
Expand Down

0 comments on commit a8122e9

Please sign in to comment.