Skip to content

Commit

Permalink
[BugFix] Make User Preferences -> Defaults Work With Any Parameter (#…
Browse files Browse the repository at this point in the history
…6687)

* make defaults work with any parameter

* clear temp print

* linters

* unused local

* fix test

* add partial API support

* pass defaults only if provider matches and param value is not None

---------

Co-authored-by: Theodore Aptekarev <[email protected]>
  • Loading branch information
deeleeramone and piiq authored Oct 3, 2024
1 parent fe066cb commit fe16bc3
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
44 changes: 44 additions & 0 deletions openbb_platform/core/openbb_core/api/router/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,50 @@ async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject:
UserService.read_from_file(),
)
)
p = path.strip("/").replace("/", ".")
defaults = (
getattr(user_settings.defaults, "__dict__", {})
.get("commands", {})
.get(p, {})
)

if defaults:
provider_choices = getattr(
kwargs.get("provider_choices", None), "__dict__", {}
)
_provider = defaults.pop("provider", None)

if (
_provider
and isinstance(_provider, list)
and _provider[0] == provider_choices.get("provider")
):
standard_params = getattr(
kwargs.pop("standard_params", None), "__dict__", {}
)
extra_params = getattr(kwargs.pop("extra_params", None), "__dict__", {})

if "chart" in defaults:
kwargs["chart"] = defaults.pop("chart", False)

if "chart_params" in defaults:
extra_params["chart_params"] = defaults.pop("chart_params", {})

for k, v in defaults.items():
if k in standard_params and standard_params[k] is None:
standard_params[k] = v
elif (k in standard_params and standard_params[k] is not None) or (
k in extra_params and extra_params[k] is not None
):
continue
elif k not in extra_params or (
k in extra_params and extra_params[k] is None
):
extra_params[k] = v

kwargs["standard_params"] = standard_params
kwargs["extra_params"] = extra_params

execute = partial(command_runner.run, path, user_settings)
output: OBBject = await execute(*args, **kwargs)

Expand Down
7 changes: 4 additions & 3 deletions openbb_platform/core/openbb_core/app/model/defaults.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Defaults model."""

from typing import Dict, List, Optional
from typing import Any
from warnings import warn

from pydantic import BaseModel, ConfigDict, Field, model_validator
Expand All @@ -13,7 +13,7 @@ class Defaults(BaseModel):

model_config = ConfigDict(validate_assignment=True, populate_by_name=True)

commands: Dict[str, Dict[str, Optional[List[str]]]] = Field(
commands: dict[str, dict[str, Any]] = Field(
default_factory=dict,
alias="routes",
)
Expand Down Expand Up @@ -41,13 +41,14 @@ def validate_before(cls, values: dict) -> dict:
)
key = "routes"

new_values: Dict[str, Dict[str, Optional[List[str]]]] = {"commands": {}}
new_values: dict = {"commands": {}}
for k, v in values.get(key, {}).items():
clean_k = k.strip("/").replace("/", ".")
provider = v.get("provider") if v else None
if isinstance(provider, str):
v["provider"] = [provider]
new_values["commands"][clean_k] = v

return new_values

def update(self, incoming: "Defaults"):
Expand Down
26 changes: 24 additions & 2 deletions openbb_platform/core/openbb_core/app/static/container.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Container class."""

from typing import TYPE_CHECKING, Any, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional

from openbb_core.app.model.abstract.error import OpenBBError

Expand All @@ -22,6 +22,26 @@ def __init__(self, command_runner: "CommandRunner") -> None:

def _run(self, *args, **kwargs) -> Any:
"""Run a command in the container."""
endpoint = args[0][1:].replace("/", ".") if args else ""
defaults = self._command_runner.user_settings.defaults.commands

if endpoint and defaults and defaults.get(endpoint):
default_params = {
k: v for k, v in defaults[endpoint].items() if k != "provider"
}
for k, v in default_params.items():
if k == "chart" and v is True:
kwargs["chart"] = True
elif (
k in kwargs["standard_params"]
and kwargs["standard_params"][k] is None
):
kwargs["standard_params"][k] = v
elif (
k in kwargs["extra_params"] and kwargs["extra_params"][k] is None
) or k not in kwargs["extra_params"]:
kwargs["extra_params"][k] = v

obbject = self._command_runner.sync_run(*args, **kwargs)
output_type = self._command_runner.user_settings.preferences.output_type
if output_type == "OBBject":
Expand All @@ -37,7 +57,7 @@ def _check_credentials(self, provider: str) -> Optional[bool]:
return all(getattr(credentials, r, None) for r in required)

def _get_provider(
self, choice: Optional[str], command: str, default_priority: Tuple[str, ...]
self, choice: Optional[str], command: str, default_priority: tuple[str, ...]
) -> str:
"""Get the provider to use in execution.
Expand Down Expand Up @@ -69,6 +89,8 @@ def _get_provider(
commands.get(command, {}).get("provider", []) or default_priority
)
tries = []
if len(providers) == 1:
return providers[0]
for p in providers:
result = self._check_credentials(p)
if result:
Expand Down

0 comments on commit fe16bc3

Please sign in to comment.