Skip to content

Commit

Permalink
Add Share Price endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jmaslek committed Mar 1, 2024
1 parent f356a58 commit 2928da1
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Share Prices Standard Model."""

from datetime import date as dateType
from typing import Optional

from pydantic import Field

from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
from openbb_core.provider.utils.descriptions import (
DATA_DESCRIPTIONS,
QUERY_DESCRIPTIONS,
)


class SharePriceQueryParams(QueryParams):
"""Share Price Query."""

start_date: Optional[dateType] = Field(
default=None, description=QUERY_DESCRIPTIONS.get("start_date")
)
end_date: Optional[dateType] = Field(
default=None, description=QUERY_DESCRIPTIONS.get("end_date")
)


class SharePriceData(Data):
"""Share Price Data."""

date: Optional[dateType] = Field(
default=None, description=DATA_DESCRIPTIONS.get("date")
)
value: Optional[float] = Field(
default=None,
description="Interest rate (given as a whole number, i.e 10=10%)",
)
country: Optional[str] = Field(
default=None,
description="Country for which interest rate is given",
)
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,30 @@ async def composite_leading_indicator(
return await OBBject.from_query(Query(**locals()))


@router.command(
model="SharePrice",
exclude_auto_examples=True,
examples=[
'obb.economy.share_price(country="all").to_df()',
],
)
async def share_price(
cc: CommandContext,
provider_choices: ProviderChoices,
standard_params: StandardParams,
extra_params: ExtraParams,
) -> OBBject:
"""Share price indices are calculated from the prices of common shares of companies
traded on national or foreign stock exchanges. They are usually determined by the
stock exchange, using the closing daily values for the monthly data, and normally
expressed as simple arithmetic averages of the daily data. A share price index
measures how the value of the stocks in the index is changing, a share return index
tells the investor what their “return” is, meaning how much money they would make as
a result of investing in that basket of shares.
"""
return await OBBject.from_query(Query(**locals()))


@router.command(
model="STIR",
exclude_auto_examples=True,
Expand Down
5 changes: 5 additions & 0 deletions openbb_platform/providers/oecd/openbb_oecd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from openbb_oecd.models.short_term_interest_rate import OECDSTIRFetcher
from openbb_oecd.models.unemployment import OECDUnemploymentFetcher

from openbb_platform.providers.oecd.openbb_oecd.models.share_price import (
OECDSharePriceFetcher,
)

oecd_provider = Provider(
name="oecd",
website="https://stats.oecd.org/",
Expand All @@ -23,5 +27,6 @@
"STIR": OECDSTIRFetcher,
"LTIR": OECDLTIRFetcher,
"ConsumerPriceIndex": OECDCPIFetcher,
"SharePrice": OECDSharePriceFetcher,
},
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""OECD Long Term Interest Rate Rate Data."""
"""OECD Long Term Interest Rate Data."""

# pylint: disable=unused-argument

Expand All @@ -23,7 +23,7 @@ class OECDLTIRQueryParams(LTIRQueryParams):
"""OECD Short Term Interest Rate Query."""

country: CountriesLiteral = Field(
description="Country to get GDP for.", default="united_states"
description="Country to get interest rate for.", default="united_states"
)

frequency: Literal["monthly", "quarterly", "annual"] = Field(
Expand Down
138 changes: 138 additions & 0 deletions openbb_platform/providers/oecd/openbb_oecd/models/share_price.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""OECD Short Term Interest Rate Data."""

# pylint: disable=unused-argument

import re
from datetime import date, timedelta
from typing import Any, Dict, List, Literal, Optional, Union

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.share_price import (
SharePriceData,
SharePriceQueryParams,
)
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import CODE_TO_COUNTRY_SHARES, COUNTRY_TO_CODE_SHARES
from pydantic import Field, field_validator

countries = tuple(CODE_TO_COUNTRY_SHARES.values()) + ("all",)
CountriesLiteral = Literal[countries] # type: ignore


class OECDSharePriceQueryParams(SharePriceQueryParams):
"""OECD Share Price Rate Query."""

country: CountriesLiteral = Field(
description="Country to get share price for.", default="united_states"
)

frequency: Literal["monthly", "quarterly", "annual"] = Field(
description="Frequency to get share price for for.", default="monthly"
)

units: Literal["yoy", "pop"] = Field(
description="Units to get share price for. Either change over period (pop) or change over year (yoy)",
default="yoy",
)


class OECDSharePriceData(SharePriceData):
"""OECD Share Price Rate Data."""

@field_validator("date", mode="before")
@classmethod
def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
"""Validate value."""
if isinstance(in_date, str):
# i.e 2022-Q1
if re.match(r"\d{4}-Q[1-4]$", in_date):
year, quarter = in_date.split("-")
_year = int(year)
if quarter == "Q1":
return date(_year, 3, 31)
if quarter == "Q2":
return date(_year, 6, 30)
if quarter == "Q3":
return date(_year, 9, 30)
if quarter == "Q4":
return date(_year, 12, 31)
# Now match if it is monthly, i.e 2022-01
elif re.match(r"\d{4}-\d{2}$", in_date):
year, month = map(int, in_date.split("-")) # type: ignore
if month == 12:
return date(year, month, 31) # type: ignore
next_month = date(year, month + 1, 1) # type: ignore
return date(next_month.year, next_month.month, 1) - timedelta(days=1)
# Now match if it is yearly, i.e 2022
elif re.match(r"\d{4}$", in_date):
return date(int(in_date), 12, 31)
# If the input date is a year
if isinstance(in_date, int):
return date(in_date, 12, 31)

return in_date


class OECDSharePriceFetcher(
Fetcher[OECDSharePriceQueryParams, List[OECDSharePriceData]]
):
"""Transform the query, extract and transform the data from the OECD endpoints."""

@staticmethod
def transform_query(params: Dict[str, Any]) -> OECDSharePriceQueryParams:
"""Transform the query."""
transformed_params = params.copy()
if transformed_params["start_date"] is None:
transformed_params["start_date"] = date(1950, 1, 1)
if transformed_params["end_date"] is None:
transformed_params["end_date"] = date(date.today().year, 12, 31)

return OECDSharePriceQueryParams(**transformed_params)

@staticmethod
def extract_data(
query: OECDSharePriceQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> List[Dict]:
"""Return the raw data from the OECD endpoint."""
frequency = query.frequency[0].upper()
country = (
"" if query.country == "all" else COUNTRY_TO_CODE_SHARES[query.country]
)
transform = {"pop": "G1", "yoy": "GY"}[query.units]
query_dict = {
k: v
for k, v in query.__dict__.items()
if k not in ["start_date", "end_date"]
}

url = f"https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/{country}.{frequency}.SHARE....{transform}"
data = helpers.get_possibly_cached_data(
url, function="economy_share_price", query_dict=query_dict
)
url_query = f"FREQ=='{frequency}' & TRANSFORMATION=='{transform}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
# Filter down
data = (
data.query(url_query)
.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
.rename(
columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
)
)
data["country"] = data["country"].map(CODE_TO_COUNTRY_SHARES)
data = data.fillna("N/A").replace("N/A", None)
data["date"] = data["date"].apply(helpers.oecd_date_to_python_date)
data = data[
(data["date"] <= query.end_date) & (data["date"] >= query.start_date)
]

return data.to_dict(orient="records")

@staticmethod
def transform_data(
query: OECDSharePriceQueryParams, data: List[Dict], **kwargs: Any
) -> List[OECDSharePriceData]:
"""Transform the data from the OECD endpoint."""
return [OECDSharePriceData.model_validate(d) for d in data]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""OECD Short Term Interest Rate Rate Data."""
"""OECD Short Term Interest Rate Data."""

# pylint: disable=unused-argument

Expand All @@ -23,7 +23,7 @@ class OECDSTIRQueryParams(STIRQueryParams):
"""OECD Short Term Interest Rate Query."""

country: CountriesLiteral = Field(
description="Country to get GDP for.", default="united_states"
description="Country to get interest rate for.", default="united_states"
)

frequency: Literal["monthly", "quarterly", "annual"] = Field(
Expand Down
49 changes: 49 additions & 0 deletions openbb_platform/providers/oecd/openbb_oecd/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,3 +582,52 @@
}

CODE_TO_COUNTRY_IR = {v: k for k, v in COUNTRY_TO_CODE_IR.items()}

COUNTRY_TO_CODE_SHARES = {
"slovenia": "SVN",
"russia": "RUS",
"latvia": "LVA",
"korea": "KOR",
"brazil": "BRA",
"france": "FRA",
"sweden": "SWE",
"luxembourg": "LUX",
"belgium": "BEL",
"china": "CHN",
"finland": "FIN",
"euro_area19": "EA19",
"japan": "JPN",
"hungary": "HUN",
"australia": "AUS",
"switzerland": "CHE",
"portugal": "PRT",
"estonia": "EST",
"canada": "CAN",
"slovak_republic": "SVK",
"turkey": "TUR",
"croatia": "HRV",
"denmark": "DNK",
"italy": "ITA",
"india": "IND",
"south_africa": "ZAF",
"czech_republic": "CZE",
"new_zealand": "NZL",
"netherlands": "NLD",
"iceland": "ISL",
"germany": "DEU",
"indonesia": "IDN",
"ireland": "IRL",
"united_states": "USA",
"chile": "CHL",
"lithuania": "LTU",
"greece": "GRC",
"united_kingdom": "GBR",
"colombia": "COL",
"norway": "NOR",
"spain": "ESP",
"israel": "ISR",
"poland": "POL",
"austria": "AUT",
"mexico": "MEX",
}
CODE_TO_COUNTRY_SHARES = {v: k for k, v in COUNTRY_TO_CODE_SHARES.items()}

0 comments on commit 2928da1

Please sign in to comment.