Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Add Validators For date Fields For Multiple Items Allowed. #6671

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions openbb_platform/core/openbb_core/api/exception_handlers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Exception handlers module."""

import logging
from collections.abc import Iterable
from typing import Any

from fastapi import Request
Expand Down Expand Up @@ -31,16 +32,23 @@ async def _handle(exception: Exception, status_code: int, detail: Any):
@staticmethod
async def exception(_: Request, error: Exception) -> JSONResponse:
"""Exception handler for Base Exception."""
# Required parameters are missing and is not handled by ValidationError.
errors = error.errors(include_url=False) if hasattr(error, "errors") else error
if errors:
for err in errors:
if err.get("type") == "missing":
return await ExceptionHandlers._handle(
exception=error,
status_code=422,
detail={**err},
)
if isinstance(errors, ValueError):
return await ExceptionHandlers._handle(
exception=errors,
status_code=422,
detail=errors.args,
)
# Required parameters are missing and is not handled by ValidationError.
if isinstance(errors, Iterable):
for err in errors:
if err.get("type") == "missing":
return await ExceptionHandlers._handle(
exception=error,
status_code=422,
detail={**err},
)
return await ExceptionHandlers._handle(
exception=error,
status_code=500,
Expand All @@ -61,7 +69,13 @@ async def validation(request: Request, error: ValidationError):
loc in query_params for err in errors for loc in err.get("loc", ())
)
if "QueryParams" in error.title and all_in_query:
detail = [{**err, "loc": ("query",) + err.get("loc", ())} for err in errors]
detail = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i really admire your ability to write lines like this one

{
**{k: v for k, v in err.items() if k != "ctx"},
"loc": ("query",) + err.get("loc", ()),
}
for err in errors
]
return await ExceptionHandlers._handle(
exception=error,
status_code=422,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import date as dateType
from typing import Optional, Union

from pydantic import Field
from pydantic import Field, field_validator

from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
Expand All @@ -23,11 +23,32 @@ class ReleaseTableQueryParams(QueryParams):
default=None,
description="The element ID of a specific table in the release.",
)
date: Optional[Union[dateType, str]] = Field(
date: Union[None, dateType, str] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("date", ""),
)

@field_validator("date", mode="before", check_fields=False)
@classmethod
def _validate_date(cls, v):
"""Validate the date."""
# pylint: disable=import-outside-toplevel
from pandas import to_datetime

if v is None:
return None
if isinstance(v, dateType):
return v.strftime("%Y-%m-%d")
new_dates: list = []
if isinstance(v, str):
dates = v.split(",")
if isinstance(v, list):
dates = v
for date in dates:
new_dates.append(to_datetime(date).date().strftime("%Y-%m-%d"))

return ",".join(new_dates) if new_dates else None


class ReleaseTableData(Data):
"""FRED Release Table Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ def to_upper(cls, v):
"""Convert field to uppercase."""
return v.upper()

@field_validator("date", mode="before", check_fields=False)
@classmethod
def _validate_date(cls, v):
"""Validate the date."""
# pylint: disable=import-outside-toplevel
from pandas import to_datetime

if v is None:
return None
if isinstance(v, dateType):
return v.strftime("%Y-%m-%d")
new_dates: list = []
if isinstance(v, str):
dates = v.split(",")
if isinstance(v, list):
dates = v
for date in dates:
new_dates.append(to_datetime(date).date().strftime("%Y-%m-%d"))

return ",".join(new_dates) if new_dates else None


class FuturesCurveData(Data):
"""Futures Curve Data."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Yield Curve Standard Model."""

from datetime import date as dateType
from typing import Optional
from typing import Optional, Union

from pydantic import Field
from pydantic import Field, field_validator

from openbb_core.provider.abstract.data import Data
from openbb_core.provider.abstract.query_params import QueryParams
Expand All @@ -16,12 +16,33 @@
class YieldCurveQueryParams(QueryParams):
"""Yield Curve Query."""

date: Optional[str] = Field(
date: Union[None, dateType, str] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("date", "")
+ " By default is the current data.",
)

@field_validator("date", mode="before", check_fields=False)
@classmethod
def _validate_date(cls, v):
"""Validate the date."""
# pylint: disable=import-outside-toplevel
from pandas import to_datetime

if v is None:
return None
if isinstance(v, dateType):
return v.strftime("%Y-%m-%d")
new_dates: list = []
if isinstance(v, str):
dates = v.split(",")
if isinstance(v, list):
dates = v
for date in dates:
new_dates.append(to_datetime(date).date().strftime("%Y-%m-%d"))

return ",".join(new_dates) if new_dates else None


class YieldCurveData(Data):
"""Yield Curve Data."""
Expand Down
Loading