From 2928da1368f067bcbc01a7e56b8eca2d7eca265f Mon Sep 17 00:00:00 2001 From: James Maslek Date: Fri, 1 Mar 2024 16:30:40 -0500 Subject: [PATCH] Add Share Price endpoint --- .../provider/standard_models/share_price.py | 40 +++++ .../economy/openbb_economy/economy_router.py | 24 +++ .../providers/oecd/openbb_oecd/__init__.py | 5 + .../models/long_term_interest_rate.py | 4 +- .../oecd/openbb_oecd/models/share_price.py | 138 ++++++++++++++++++ .../models/short_term_interest_rate.py | 4 +- .../oecd/openbb_oecd/utils/constants.py | 49 +++++++ 7 files changed, 260 insertions(+), 4 deletions(-) create mode 100644 openbb_platform/core/openbb_core/provider/standard_models/share_price.py create mode 100644 openbb_platform/providers/oecd/openbb_oecd/models/share_price.py diff --git a/openbb_platform/core/openbb_core/provider/standard_models/share_price.py b/openbb_platform/core/openbb_core/provider/standard_models/share_price.py new file mode 100644 index 000000000000..7f279509d023 --- /dev/null +++ b/openbb_platform/core/openbb_core/provider/standard_models/share_price.py @@ -0,0 +1,40 @@ +"""Share Prices Standard Model.""" + +from datetime import date as dateType +from typing import Optional + +from pydantic import Field + +from openbb_core.provider.abstract.data import Data +from openbb_core.provider.abstract.query_params import QueryParams +from openbb_core.provider.utils.descriptions import ( + DATA_DESCRIPTIONS, + QUERY_DESCRIPTIONS, +) + + +class SharePriceQueryParams(QueryParams): + """Share Price Query.""" + + start_date: Optional[dateType] = Field( + default=None, description=QUERY_DESCRIPTIONS.get("start_date") + ) + end_date: Optional[dateType] = Field( + default=None, description=QUERY_DESCRIPTIONS.get("end_date") + ) + + +class SharePriceData(Data): + """Share Price Data.""" + + date: Optional[dateType] = Field( + default=None, description=DATA_DESCRIPTIONS.get("date") + ) + value: Optional[float] = Field( + default=None, + description="Interest rate (given as a whole number, i.e 10=10%)", + ) + country: Optional[str] = Field( + default=None, + description="Country for which interest rate is given", + ) diff --git a/openbb_platform/extensions/economy/openbb_economy/economy_router.py b/openbb_platform/extensions/economy/openbb_economy/economy_router.py index b6f0bde05d54..3e2cf3fe0f7d 100644 --- a/openbb_platform/extensions/economy/openbb_economy/economy_router.py +++ b/openbb_platform/extensions/economy/openbb_economy/economy_router.py @@ -184,6 +184,30 @@ async def composite_leading_indicator( return await OBBject.from_query(Query(**locals())) +@router.command( + model="SharePrice", + exclude_auto_examples=True, + examples=[ + 'obb.economy.share_price(country="all").to_df()', + ], +) +async def share_price( + cc: CommandContext, + provider_choices: ProviderChoices, + standard_params: StandardParams, + extra_params: ExtraParams, +) -> OBBject: + """Share price indices are calculated from the prices of common shares of companies + traded on national or foreign stock exchanges. They are usually determined by the + stock exchange, using the closing daily values for the monthly data, and normally + expressed as simple arithmetic averages of the daily data. A share price index + measures how the value of the stocks in the index is changing, a share return index + tells the investor what their “return” is, meaning how much money they would make as + a result of investing in that basket of shares. + """ + return await OBBject.from_query(Query(**locals())) + + @router.command( model="STIR", exclude_auto_examples=True, diff --git a/openbb_platform/providers/oecd/openbb_oecd/__init__.py b/openbb_platform/providers/oecd/openbb_oecd/__init__.py index bf7063684274..4b53a5159c03 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/__init__.py +++ b/openbb_platform/providers/oecd/openbb_oecd/__init__.py @@ -10,6 +10,10 @@ from openbb_oecd.models.short_term_interest_rate import OECDSTIRFetcher from openbb_oecd.models.unemployment import OECDUnemploymentFetcher +from openbb_platform.providers.oecd.openbb_oecd.models.share_price import ( + OECDSharePriceFetcher, +) + oecd_provider = Provider( name="oecd", website="https://stats.oecd.org/", @@ -23,5 +27,6 @@ "STIR": OECDSTIRFetcher, "LTIR": OECDLTIRFetcher, "ConsumerPriceIndex": OECDCPIFetcher, + "SharePrice": OECDSharePriceFetcher, }, ) diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py b/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py index 9ac5c306eacd..e46e5af8e776 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/long_term_interest_rate.py @@ -1,4 +1,4 @@ -"""OECD Long Term Interest Rate Rate Data.""" +"""OECD Long Term Interest Rate Data.""" # pylint: disable=unused-argument @@ -23,7 +23,7 @@ class OECDLTIRQueryParams(LTIRQueryParams): """OECD Short Term Interest Rate Query.""" country: CountriesLiteral = Field( - description="Country to get GDP for.", default="united_states" + description="Country to get interest rate for.", default="united_states" ) frequency: Literal["monthly", "quarterly", "annual"] = Field( diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/share_price.py b/openbb_platform/providers/oecd/openbb_oecd/models/share_price.py new file mode 100644 index 000000000000..70ce9bef07ce --- /dev/null +++ b/openbb_platform/providers/oecd/openbb_oecd/models/share_price.py @@ -0,0 +1,138 @@ +"""OECD Short Term Interest Rate Data.""" + +# pylint: disable=unused-argument + +import re +from datetime import date, timedelta +from typing import Any, Dict, List, Literal, Optional, Union + +from openbb_core.provider.abstract.fetcher import Fetcher +from openbb_core.provider.standard_models.share_price import ( + SharePriceData, + SharePriceQueryParams, +) +from openbb_oecd.utils import helpers +from openbb_oecd.utils.constants import CODE_TO_COUNTRY_SHARES, COUNTRY_TO_CODE_SHARES +from pydantic import Field, field_validator + +countries = tuple(CODE_TO_COUNTRY_SHARES.values()) + ("all",) +CountriesLiteral = Literal[countries] # type: ignore + + +class OECDSharePriceQueryParams(SharePriceQueryParams): + """OECD Share Price Rate Query.""" + + country: CountriesLiteral = Field( + description="Country to get share price for.", default="united_states" + ) + + frequency: Literal["monthly", "quarterly", "annual"] = Field( + description="Frequency to get share price for for.", default="monthly" + ) + + units: Literal["yoy", "pop"] = Field( + description="Units to get share price for. Either change over period (pop) or change over year (yoy)", + default="yoy", + ) + + +class OECDSharePriceData(SharePriceData): + """OECD Share Price Rate Data.""" + + @field_validator("date", mode="before") + @classmethod + def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213 + """Validate value.""" + if isinstance(in_date, str): + # i.e 2022-Q1 + if re.match(r"\d{4}-Q[1-4]$", in_date): + year, quarter = in_date.split("-") + _year = int(year) + if quarter == "Q1": + return date(_year, 3, 31) + if quarter == "Q2": + return date(_year, 6, 30) + if quarter == "Q3": + return date(_year, 9, 30) + if quarter == "Q4": + return date(_year, 12, 31) + # Now match if it is monthly, i.e 2022-01 + elif re.match(r"\d{4}-\d{2}$", in_date): + year, month = map(int, in_date.split("-")) # type: ignore + if month == 12: + return date(year, month, 31) # type: ignore + next_month = date(year, month + 1, 1) # type: ignore + return date(next_month.year, next_month.month, 1) - timedelta(days=1) + # Now match if it is yearly, i.e 2022 + elif re.match(r"\d{4}$", in_date): + return date(int(in_date), 12, 31) + # If the input date is a year + if isinstance(in_date, int): + return date(in_date, 12, 31) + + return in_date + + +class OECDSharePriceFetcher( + Fetcher[OECDSharePriceQueryParams, List[OECDSharePriceData]] +): + """Transform the query, extract and transform the data from the OECD endpoints.""" + + @staticmethod + def transform_query(params: Dict[str, Any]) -> OECDSharePriceQueryParams: + """Transform the query.""" + transformed_params = params.copy() + if transformed_params["start_date"] is None: + transformed_params["start_date"] = date(1950, 1, 1) + if transformed_params["end_date"] is None: + transformed_params["end_date"] = date(date.today().year, 12, 31) + + return OECDSharePriceQueryParams(**transformed_params) + + @staticmethod + def extract_data( + query: OECDSharePriceQueryParams, + credentials: Optional[Dict[str, str]], + **kwargs: Any, + ) -> List[Dict]: + """Return the raw data from the OECD endpoint.""" + frequency = query.frequency[0].upper() + country = ( + "" if query.country == "all" else COUNTRY_TO_CODE_SHARES[query.country] + ) + transform = {"pop": "G1", "yoy": "GY"}[query.units] + query_dict = { + k: v + for k, v in query.__dict__.items() + if k not in ["start_date", "end_date"] + } + + url = f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/{country}.{frequency}.SHARE....{transform}" + data = helpers.get_possibly_cached_data( + url, function="economy_share_price", query_dict=query_dict + ) + url_query = f"FREQ=='{frequency}' & TRANSFORMATION=='{transform}'" + url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query + # Filter down + data = ( + data.query(url_query) + .reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]] + .rename( + columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"} + ) + ) + data["country"] = data["country"].map(CODE_TO_COUNTRY_SHARES) + data = data.fillna("N/A").replace("N/A", None) + data["date"] = data["date"].apply(helpers.oecd_date_to_python_date) + data = data[ + (data["date"] <= query.end_date) & (data["date"] >= query.start_date) + ] + + return data.to_dict(orient="records") + + @staticmethod + def transform_data( + query: OECDSharePriceQueryParams, data: List[Dict], **kwargs: Any + ) -> List[OECDSharePriceData]: + """Transform the data from the OECD endpoint.""" + return [OECDSharePriceData.model_validate(d) for d in data] diff --git a/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py b/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py index aee3e11aef55..dfe8800768a2 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py +++ b/openbb_platform/providers/oecd/openbb_oecd/models/short_term_interest_rate.py @@ -1,4 +1,4 @@ -"""OECD Short Term Interest Rate Rate Data.""" +"""OECD Short Term Interest Rate Data.""" # pylint: disable=unused-argument @@ -23,7 +23,7 @@ class OECDSTIRQueryParams(STIRQueryParams): """OECD Short Term Interest Rate Query.""" country: CountriesLiteral = Field( - description="Country to get GDP for.", default="united_states" + description="Country to get interest rate for.", default="united_states" ) frequency: Literal["monthly", "quarterly", "annual"] = Field( diff --git a/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py b/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py index 60d8f5ab1809..67daa4b05313 100644 --- a/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py +++ b/openbb_platform/providers/oecd/openbb_oecd/utils/constants.py @@ -582,3 +582,52 @@ } CODE_TO_COUNTRY_IR = {v: k for k, v in COUNTRY_TO_CODE_IR.items()} + +COUNTRY_TO_CODE_SHARES = { + "slovenia": "SVN", + "russia": "RUS", + "latvia": "LVA", + "korea": "KOR", + "brazil": "BRA", + "france": "FRA", + "sweden": "SWE", + "luxembourg": "LUX", + "belgium": "BEL", + "china": "CHN", + "finland": "FIN", + "euro_area19": "EA19", + "japan": "JPN", + "hungary": "HUN", + "australia": "AUS", + "switzerland": "CHE", + "portugal": "PRT", + "estonia": "EST", + "canada": "CAN", + "slovak_republic": "SVK", + "turkey": "TUR", + "croatia": "HRV", + "denmark": "DNK", + "italy": "ITA", + "india": "IND", + "south_africa": "ZAF", + "czech_republic": "CZE", + "new_zealand": "NZL", + "netherlands": "NLD", + "iceland": "ISL", + "germany": "DEU", + "indonesia": "IDN", + "ireland": "IRL", + "united_states": "USA", + "chile": "CHL", + "lithuania": "LTU", + "greece": "GRC", + "united_kingdom": "GBR", + "colombia": "COL", + "norway": "NOR", + "spain": "ESP", + "israel": "ISR", + "poland": "POL", + "austria": "AUT", + "mexico": "MEX", +} +CODE_TO_COUNTRY_SHARES = {v: k for k, v in COUNTRY_TO_CODE_SHARES.items()}