Skip to content

Commit

Permalink
[Feature] Add EIA Provider & Weekly Petroleum Status Report (#6693)
Browse files Browse the repository at this point in the history
* add eia provider and weekly petroleum status report

* codespell

* sort imports

* lint and update lock

* move raise up in try block

* add STEO

* steo tests

* linting

* mypy

* readme and docstring grammar police

* fix test..?

* add symbol field to steo

* static files

* update integration test

* mypy

* pylint

* fix test..?

* undo attempt to fix test

* add empty init file

* add empty init file

* change package name so the tests don't fail

* some updates

* lint

---------

Co-authored-by: Theodore Aptekarev <[email protected]>
  • Loading branch information
deeleeramone and piiq authored Oct 31, 2024
1 parent 246f3b0 commit dd42f7a
Show file tree
Hide file tree
Showing 30 changed files with 21,897 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Petroleum Status Report Standard Model."""

from datetime import date as dateType
from typing import Optional, Union

from pydantic import Field

from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
from openbb_core.provider.utils.descriptions import (
DATA_DESCRIPTIONS,
QUERY_DESCRIPTIONS,
)


class PetroleumStatusReportQueryParams(QueryParams):
"""Petroleum Status Report Query."""

start_date: Optional[dateType] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("start_date", ""),
)
end_date: Optional[dateType] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("end_date", ""),
)


class PetroleumStatusReportData(Data):
"""Petroleum Status Report Data."""

date: dateType = Field(description=DATA_DESCRIPTIONS.get("date", ""))
table: Optional[str] = Field(description="Table name for the data.")
symbol: str = Field(description=DATA_DESCRIPTIONS.get("symbol", ""))
order: Optional[int] = Field(
default=None, description="Presented order of the data, relative to the table."
)
title: Optional[str] = Field(default=None, description="Title of the data.")
value: Union[int, float] = Field(description="Value of the data.")
unit: Optional[str] = Field(default=None, description="Unit or scale of the data.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Short Term Energy Outlook Standard Model."""

from datetime import date as dateType
from typing import Optional, Union

from pydantic import Field

from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
from openbb_core.provider.utils.descriptions import (
DATA_DESCRIPTIONS,
QUERY_DESCRIPTIONS,
)


class ShortTermEnergyOutlookQueryParams(QueryParams):
"""Short Term Energy Outlook Query."""

start_date: Optional[dateType] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("start_date", ""),
)
end_date: Optional[dateType] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("end_date", ""),
)


class ShortTermEnergyOutlookData(Data):
"""Short Term Energy Outlook Data."""

date: dateType = Field(description=DATA_DESCRIPTIONS.get("date", ""))
table: Optional[str] = Field(default=None, description="Table name for the data.")
symbol: str = Field(description=DATA_DESCRIPTIONS.get("symbol", ""))
order: Optional[int] = Field(
default=None, description="Presented order of the data, relative to the table."
)
title: Optional[str] = Field(default=None, description="Title of the data.")
value: Union[int, float] = Field(description="Value of the data.")
unit: Optional[str] = Field(default=None, description="Unit or scale of the data.")
1 change: 1 addition & 0 deletions openbb_platform/dev_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
openbb-sec = { path = "./providers/sec", develop = true }
openbb-tiingo = { path = "./providers/tiingo", develop = true }
openbb-tradingeconomics = { path = "./providers/tradingeconomics", develop = true }
openbb-us-eia = { path = "./providers/eia", develop = true }
openbb-yfinance = { path = "./providers/yfinance", develop = true }
openbb-commodity = { path = "./extensions/commodity", develop = true }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,67 @@ def test_commodity_price_spot(params, headers):
result = requests.get(url, headers=headers, timeout=10)
assert isinstance(result, requests.Response)
assert result.status_code == 200


@pytest.mark.parametrize(
"params",
[
(
{
"category": "balance_sheet",
"table": "stocks",
"start_date": None,
"end_date": None,
"provider": "eia",
"use_cache": True,
}
),
(
{
"category": "weekly_estimates",
"table": "crude_production",
"start_date": "2020-01-01",
"end_date": "2023-12-31",
"provider": "eia",
"use_cache": True,
}
),
],
)
@pytest.mark.integration
def test_commodity_petroleum_status_report(params, headers):
"""Test the Petroleum Status Report endpoint."""
params = {p: v for p, v in params.items() if v}

query_str = get_querystring(params, [])
url = f"http://0.0.0.0:8000/api/v1/commodity/petroleum_status_report?{query_str}"
result = requests.get(url, headers=headers, timeout=10)
assert isinstance(result, requests.Response)
assert result.status_code == 200


@pytest.mark.parametrize(
"params",
[
(
{
"table": "01",
"symbol": None,
"start_date": "2024-09-01",
"end_date": "2024-10-01",
"provider": "eia",
"frequency": "month",
}
),
],
)
@pytest.mark.integration
def test_commodity_short_term_energy_outlook(params, headers):
"""Test the Short Term Energy Outlook endpoint."""
params = {p: v for p, v in params.items() if v}

query_str = get_querystring(params, [])
url = f"http://0.0.0.0:8000/api/v1/commodity/short_term_energy_outlook?{query_str}"
result = requests.get(url, headers=headers, timeout=10)
assert isinstance(result, requests.Response)
assert result.status_code == 200
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,61 @@ def test_commodity_price_spot(params, obb):
assert result
assert isinstance(result, OBBject)
assert len(result.results) > 0


@pytest.mark.parametrize(
"params",
[
(
{
"category": "balance_sheet",
"table": "stocks",
"start_date": None,
"end_date": None,
"provider": "eia",
"use_cache": True,
}
),
(
{
"category": "weekly_estimates",
"table": "crude_production",
"start_date": "2020-01-01",
"end_date": "2023-12-31",
"provider": "eia",
"use_cache": True,
}
),
],
)
@pytest.mark.integration
def test_commodity_petroleum_status_report(params, obb):
"""Test Commodity Petroleum Status Report endpoint."""
result = obb.commodity.petroleum_status_report(**params)
assert result
assert isinstance(result, OBBject)
assert len(result.results) > 0


@pytest.mark.parametrize(
"params",
[
(
{
"table": "01",
"symbol": None,
"start_date": "2024-09-01",
"end_date": "2024-10-01",
"provider": "eia",
"frequency": "month",
}
),
],
)
@pytest.mark.integration
def test_commodity_short_term_energy_outlook(params, obb):
"""Test Commodity Short Term Energy Outlook endpoint."""
result = obb.commodity.short_term_energy_outlook(**params)
assert result
assert isinstance(result, OBBject)
assert len(result.results) > 0
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# pylint: disable=unused-argument,unused-import
# flake8: noqa: F401

# pylint: disable=unused-argument

from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.example import APIEx
from openbb_core.app.model.obbject import OBBject
Expand All @@ -20,3 +22,59 @@


router.include_router(price_router)


@router.command(
model="PetroleumStatusReport",
examples=[
APIEx(
description="Get the EIA's Weekly Petroleum Status Report.",
parameters={"provider": "eia"},
),
APIEx(
description="Select the category of data, and filter for a specific table within the report.",
parameters={
"category": "weekly_estimates",
"table": "imports",
"provider": "eia",
},
),
],
)
async def petroleum_status_report(
cc: CommandContext,
provider_choices: ProviderChoices,
standard_params: StandardParams,
extra_params: ExtraParams,
) -> OBBject:
"""EIA Weekly Petroleum Status Report."""
return await OBBject.from_query(Query(**locals()))


@router.command(
model="ShortTermEnergyOutlook",
examples=[
APIEx(
description="Get the EIA's Short Term Energy Outlook.",
parameters={"provider": "eia"},
),
APIEx(
description="Select the specific table of data from the STEO. Table 03d is World Crude Oil Production.",
parameters={
"table": "03d",
"provider": "eia",
},
),
],
)
async def short_term_energy_outlook(
cc: CommandContext,
provider_choices: ProviderChoices,
standard_params: StandardParams,
extra_params: ExtraParams,
) -> OBBject:
"""Monthly short term (18 month) projections using EIA's STEO model.
Source: www.eia.gov/steo/
"""
return await OBBject.from_query(Query(**locals()))
Loading

0 comments on commit dd42f7a

Please sign in to comment.