From fe16bc32a4af21055bbd45d32bf8c59c205c1145 Mon Sep 17 00:00:00 2001 From: Danglewood <85772166+deeleeramone@users.noreply.github.com> Date: Thu, 3 Oct 2024 10:58:00 -0700 Subject: [PATCH] [BugFix] Make User Preferences -> Defaults Work With Any Parameter (#6687) * make defaults work with any parameter * clear temp print * linters * unused local * fix test * add partial API support * pass defaults only if provider matches and param value is not None --------- Co-authored-by: Theodore Aptekarev --- .../core/openbb_core/api/router/commands.py | 44 +++++++++++++++++++ .../core/openbb_core/app/model/defaults.py | 7 +-- .../core/openbb_core/app/static/container.py | 26 ++++++++++- 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py index 1cbefbe1f1a6..a4f68a14bead 100644 --- a/openbb_platform/core/openbb_core/api/router/commands.py +++ b/openbb_platform/core/openbb_core/api/router/commands.py @@ -195,6 +195,50 @@ async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject: UserService.read_from_file(), ) ) + p = path.strip("/").replace("/", ".") + defaults = ( + getattr(user_settings.defaults, "__dict__", {}) + .get("commands", {}) + .get(p, {}) + ) + + if defaults: + provider_choices = getattr( + kwargs.get("provider_choices", None), "__dict__", {} + ) + _provider = defaults.pop("provider", None) + + if ( + _provider + and isinstance(_provider, list) + and _provider[0] == provider_choices.get("provider") + ): + standard_params = getattr( + kwargs.pop("standard_params", None), "__dict__", {} + ) + extra_params = getattr(kwargs.pop("extra_params", None), "__dict__", {}) + + if "chart" in defaults: + kwargs["chart"] = defaults.pop("chart", False) + + if "chart_params" in defaults: + extra_params["chart_params"] = defaults.pop("chart_params", {}) + + for k, v in defaults.items(): + if k in standard_params and standard_params[k] is None: + standard_params[k] = v + elif (k in standard_params and standard_params[k] is not None) or ( + k in extra_params and extra_params[k] is not None + ): + continue + elif k not in extra_params or ( + k in extra_params and extra_params[k] is None + ): + extra_params[k] = v + + kwargs["standard_params"] = standard_params + kwargs["extra_params"] = extra_params + execute = partial(command_runner.run, path, user_settings) output: OBBject = await execute(*args, **kwargs) diff --git a/openbb_platform/core/openbb_core/app/model/defaults.py b/openbb_platform/core/openbb_core/app/model/defaults.py index 8dda33d222d4..3b283753eea5 100644 --- a/openbb_platform/core/openbb_core/app/model/defaults.py +++ b/openbb_platform/core/openbb_core/app/model/defaults.py @@ -1,6 +1,6 @@ """Defaults model.""" -from typing import Dict, List, Optional +from typing import Any from warnings import warn from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -13,7 +13,7 @@ class Defaults(BaseModel): model_config = ConfigDict(validate_assignment=True, populate_by_name=True) - commands: Dict[str, Dict[str, Optional[List[str]]]] = Field( + commands: dict[str, dict[str, Any]] = Field( default_factory=dict, alias="routes", ) @@ -41,13 +41,14 @@ def validate_before(cls, values: dict) -> dict: ) key = "routes" - new_values: Dict[str, Dict[str, Optional[List[str]]]] = {"commands": {}} + new_values: dict = {"commands": {}} for k, v in values.get(key, {}).items(): clean_k = k.strip("/").replace("/", ".") provider = v.get("provider") if v else None if isinstance(provider, str): v["provider"] = [provider] new_values["commands"][clean_k] = v + return new_values def update(self, incoming: "Defaults"): diff --git a/openbb_platform/core/openbb_core/app/static/container.py b/openbb_platform/core/openbb_core/app/static/container.py index c958f8fe602d..eac555e4e0f2 100644 --- a/openbb_platform/core/openbb_core/app/static/container.py +++ b/openbb_platform/core/openbb_core/app/static/container.py @@ -1,6 +1,6 @@ """Container class.""" -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional from openbb_core.app.model.abstract.error import OpenBBError @@ -22,6 +22,26 @@ def __init__(self, command_runner: "CommandRunner") -> None: def _run(self, *args, **kwargs) -> Any: """Run a command in the container.""" + endpoint = args[0][1:].replace("/", ".") if args else "" + defaults = self._command_runner.user_settings.defaults.commands + + if endpoint and defaults and defaults.get(endpoint): + default_params = { + k: v for k, v in defaults[endpoint].items() if k != "provider" + } + for k, v in default_params.items(): + if k == "chart" and v is True: + kwargs["chart"] = True + elif ( + k in kwargs["standard_params"] + and kwargs["standard_params"][k] is None + ): + kwargs["standard_params"][k] = v + elif ( + k in kwargs["extra_params"] and kwargs["extra_params"][k] is None + ) or k not in kwargs["extra_params"]: + kwargs["extra_params"][k] = v + obbject = self._command_runner.sync_run(*args, **kwargs) output_type = self._command_runner.user_settings.preferences.output_type if output_type == "OBBject": @@ -37,7 +57,7 @@ def _check_credentials(self, provider: str) -> Optional[bool]: return all(getattr(credentials, r, None) for r in required) def _get_provider( - self, choice: Optional[str], command: str, default_priority: Tuple[str, ...] + self, choice: Optional[str], command: str, default_priority: tuple[str, ...] ) -> str: """Get the provider to use in execution. @@ -69,6 +89,8 @@ def _get_provider( commands.get(command, {}).get("provider", []) or default_priority ) tries = [] + if len(providers) == 1: + return providers[0] for p in providers: result = self._check_credentials(p) if result: