Skip to content

Commit

Permalink
[Feature] Add Chart To econometrics.correlation_matrix (#6750)
Browse files Browse the repository at this point in the history
* add chart to correlation matrix

* lint

* handling for data.results without chart

---------

Co-authored-by: Igor Radovanovic <[email protected]>
  • Loading branch information
deeleeramone and IgorWounds authored Oct 11, 2024
1 parent 737e5d4 commit c58b567
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def get_data(menu: Literal["equity", "crypto"]):
@pytest.mark.parametrize(
"params, data_type",
[
({"data": ""}, "equity"),
({"data": ""}, "crypto"),
({"data": "", "method": "pearson"}, "equity"),
({"data": "", "method": "pearson"}, "crypto"),
],
)
@pytest.mark.integration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def get_data(menu: Literal["equity", "crypto"]):
@parametrize(
"params, data_type",
[
({"data": ""}, "equity"),
({"data": ""}, "crypto"),
({"data": "", "method": "pearson"}, "equity"),
({"data": "", "method": "pearson"}, "crypto"),
],
)
@pytest.mark.integration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
APIEx(parameters={"data": APIEx.mock_data("timeseries")}),
],
)
def correlation_matrix(data: List[Data]) -> OBBject[List[Data]]:
def correlation_matrix(
data: List[Data], method: Literal["pearson", "kendall", "spearman"] = "pearson"
) -> OBBject[List[Data]]:
"""Get the correlation matrix of an input dataset.
The correlation matrix provides a view of how different variables in your dataset relate to one another.
Expand All @@ -37,6 +39,11 @@ def correlation_matrix(data: List[Data]) -> OBBject[List[Data]]:
----------
data : List[Data]
Input dataset.
method : Literal["pearson", "kendall", "spearman"]
Method to use for correlation calculation. Default is "pearson".
pearson : standard correlation coefficient
kendall : Kendall Tau correlation coefficient
spearman : Spearman rank correlation
Returns
-------
Expand All @@ -49,9 +56,14 @@ def correlation_matrix(data: List[Data]) -> OBBject[List[Data]]:

df = basemodel_to_df(data)
# remove non float columns from the dataframe to perform the correlation
df = df.select_dtypes(include=["float64"])

corr = df.corr()
if "symbol" in df.columns and len(df.symbol.unique()) > 1 and "close" in df.columns:
df = df.pivot(
columns="symbol",
values="close",
)

corr = df.corr(method=method, numeric_only=True)

# replace nan values with None to allow for json serialization
corr = corr.replace(np.NaN, None)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Views for the Econometrics Extension."""

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from openbb_charting.core.openbb_figure import (
OpenBBFigure,
)


class EconometricsViews:
"""Econometrics Views."""

@staticmethod
def econometrics_correlation_matrix( # noqa: PLR0912
**kwargs,
) -> tuple["OpenBBFigure", dict[str, Any]]:
"""Correlation Matrix Chart.
Parameters
----------
data : Union[list[Data], DataFrame]
Input dataset.
method : Literal["pearson", "kendall", "spearman"]
Method to use for correlation calculation. Default is "pearson".
pearson : standard correlation coefficient
kendall : Kendall Tau correlation coefficient
spearman : Spearman rank correlation
colorscale : str
Plotly colorscale to use for the heatmap. Default is "RdBu".
title : str
Title of the chart. Default is "Asset Correlation Matrix".
layout_kwargs : Dict[str, Any]
Additional keyword arguments to apply with figure.update_layout(), by default None.
"""
# pylint: disable=import-outside-toplevel
from openbb_charting.charts.correlation_matrix import correlation_matrix

return correlation_matrix(**kwargs)
3 changes: 3 additions & 0 deletions openbb_platform/extensions/econometrics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.plugins."openbb_core_extension"]
econometrics = "openbb_econometrics.econometrics_router:router"

[tool.poetry.plugins."openbb_charting_extension"]
econometrics = "openbb_econometrics.econometrics_views:EconometricsViews"
Original file line number Diff line number Diff line change
Expand Up @@ -897,3 +897,44 @@ def test_charting_economy_survey_bls_series(params, headers):
assert chart
assert not fig
assert list(chart.keys()) == ["content", "format"]


@parametrize(
"params",
[
(
{
"data": "",
"method": "pearson",
"chart": True,
}
)
],
)
@pytest.mark.integration
def test_charting_econometrics_correlation_matrix(params, headers):
"""Test chart econometrics correlation matrix."""
# pylint:disable=import-outside-toplevel
from pandas import DataFrame

url = "http://0.0.0.0:8000/api/v1/equity/price/historical?symbol=AAPL,MSFT,GOOG&provider=yfinance"
result = requests.get(url, headers=headers, timeout=10)
df = DataFrame(result.json()["results"])
df = df.pivot(index="date", columns="symbol", values="close").reset_index()
body = df.to_dict(orient="records")

params = {p: v for p, v in params.items() if v}

query_str = get_querystring(params, [])
url = f"http://0.0.0.0:8000/api/v1/econometrics/correlation_matrix?{query_str}"
result = requests.post(url, headers=headers, timeout=10, data=json.dumps(body))

assert isinstance(result, requests.Response)
assert result.status_code == 200

chart = result.json()["chart"]
fig = chart.pop("fig", {})

assert chart
assert not fig
assert list(chart.keys()) == ["content", "format"]
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,34 @@ def test_charting_economy_survey_bls_series(params, obb):
assert len(result.results) > 0
assert result.chart.content
assert isinstance(result.chart.fig, OpenBBFigure)


@parametrize(
"params",
[
(
{
"data": "",
"method": "pearson",
"chart": True,
}
)
],
)
@pytest.mark.integration
def test_charting_econometrics_correlation_matrix(params, obb):
"""Test chart econometrics correlation matrix."""

symbols = "XRT,XLB,XLI,XLH,XLC,XLY,XLU,XLK".split(",")
params["data"] = (
obb.equity.price.historical(symbol=symbols, provider="yfinance")
.to_df()
.pivot(columns="symbol", values="close")
.filter(items=symbols, axis=1)
)
result = obb.econometrics.correlation_matrix(**params)
assert result
assert isinstance(result, OBBject)
assert len(result.results) > 0
assert result.chart.content
assert isinstance(result.chart.fig, OpenBBFigure)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Charting Class implementation."""

# pylint: disable=too-many-arguments,unused-argument
# pylint: disable=too-many-arguments,unused-argument,too-many-positional-arguments

from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -57,6 +57,8 @@ class Charting:
Create a line chart from external data.
create_bar_chart
Create a bar chart, on a single x-axis with one or more values for the y-axis, from external data.
create_correlation_matrix
Create a correlation matrix from external data.
toggle_chart_style
Toggle the chart style, of an existing chart, between light and dark mode.
"""
Expand Down Expand Up @@ -367,6 +369,54 @@ def create_bar_chart(

return fig

def create_correlation_matrix(
self,
data: Union[
list[Data],
"DataFrame",
],
method: Literal["pearson", "kendall", "spearman"] = "pearson",
colorscale: str = "RdBu",
title: str = "Asset Correlation Matrix",
layout_kwargs: Optional[Dict[str, Any]] = None,
):
"""Create a correlation matrix from external data.
Parameters
----------
data : Union[list[Data], DataFrame]
Input dataset.
method : Literal["pearson", "kendall", "spearman"]
Method to use for correlation calculation. Default is "pearson".
pearson : standard correlation coefficient
kendall : Kendall Tau correlation coefficient
spearman : Spearman rank correlation
colorscale : str
Plotly colorscale to use for the heatmap. Default is "RdBu".
title : str
Title of the chart. Default is "Asset Correlation Matrix".
layout_kwargs : Dict[str, Any]
Additional keyword arguments to apply with figure.update_layout(), by default None.
Returns
-------
OpenBBFigure
The OpenBBFigure object.
"""
# pylint: disable=import-outside-toplevel
from openbb_charting.charts.correlation_matrix import correlation_matrix

kwargs = {
"data": data,
"method": method,
"colorscale": colorscale,
"title": title,
"layout_kwargs": layout_kwargs,
}
fig, _ = correlation_matrix(**kwargs)
fig = self._set_chart_style(fig)
return fig

# pylint: disable=inconsistent-return-statements
def show(self, render: bool = True, **kwargs):
"""Display chart and save it to the OBBject."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Correlation Matrix Chart."""

from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
from plotly.graph_objs import Figure # noqa
from openbb_charting.core.openbb_figure import OpenBBFigure # noqa


def correlation_matrix( # noqa: PLR0912
**kwargs,
) -> tuple[Union["OpenBBFigure", "Figure"], dict[str, Any]]:
"""Correlation Matrix Chart."""
# pylint: disable=import-outside-toplevel
from numpy import ones_like, triu # noqa
from openbb_core.app.utils import basemodel_to_df # noqa
from openbb_charting.core.openbb_figure import OpenBBFigure
from openbb_charting.core.chart_style import ChartStyle
from plotly.graph_objs import Figure, Heatmap, Layout
from pandas import DataFrame

if "data" in kwargs and isinstance(kwargs["data"], DataFrame):
corr = kwargs["data"]
elif "data" in kwargs and isinstance(kwargs["data"], list):
corr = basemodel_to_df(kwargs["data"], index=kwargs.get("index", "date")) # type: ignore
else:
corr = basemodel_to_df(
kwargs["obbject_item"], index=kwargs.get("index", "date") # type: ignore
)
if (
"symbol" in corr.columns
and len(corr.symbol.unique()) > 1
and "close" in corr.columns
):
corr = corr.pivot(
columns="symbol",
values="close",
)

method = kwargs.get("method") or "pearson"
corr = corr.corr(method=method, numeric_only=True)

X = corr.columns.to_list()
x_replace = X[-1]
Y = X.copy()
y_replace = Y[0]
X = [x if x != x_replace else "" for x in X]
Y = [y if y != y_replace else "" for y in Y]
mask = triu(ones_like(corr, dtype=bool))
df = corr.mask(mask)
title = kwargs.get("title") or "Asset Correlation Matrix"
text_color = "white" if ChartStyle().plt_style == "dark" else "black"
colorscale = kwargs.get("colorscale") or "RdBu"
heatmap = Heatmap(
z=df,
x=X,
y=Y,
xgap=1,
ygap=1,
colorscale=colorscale,
colorbar=dict(
orientation="v",
x=0.9,
y=0.45,
xanchor="left",
yanchor="middle",
len=0.75,
bgcolor="rgba(0,0,0,0)" if text_color == "white" else "rgba(255,255,255,0)",
),
text=df.fillna(""),
texttemplate="%{text:.4f}",
hoverinfo="skip",
)
layout = Layout(
title=title,
title_x=0.5,
title_y=0.95,
xaxis=dict(
showgrid=False,
showline=False,
ticklen=0,
tickfont=dict(size=16),
ticklabelstandoff=10,
domain=[0.05, 1],
),
yaxis=dict(
showgrid=False,
side="left",
autorange="reversed",
showline=False,
ticklen=0,
tickfont=dict(size=16),
ticklabelstandoff=15,
domain=[0.05, 1],
),
margin=dict(r=20, t=0, b=50),
dragmode=False,
)
fig = Figure(data=[heatmap], layout=layout)
figure = OpenBBFigure(fig=fig)
figure.update_layout(
font=dict(color=text_color),
paper_bgcolor=(
"rgba(0,0,0,0)" if text_color == "white" else "rgba(255,255,255,0)"
),
plot_bgcolor=(
"rgba(0,0,0,0)" if text_color == "white" else "rgba(255,255,255,0)"
),
)
layout_kwargs = kwargs.get("layout_kwargs", {})

if layout_kwargs:
figure.update_layout(**layout_kwargs)

content = figure.show(external=True).to_plotly_json() # type: ignore

return figure, content

0 comments on commit c58b567

Please sign in to comment.