Skip to content

Commit

Permalink
[Feature] Set provider fallback based on credentials (#6446)
Browse files Browse the repository at this point in the history
* feat: allow provider fallback based on credentials

* fix: rename prop and raise error

* fix: raise if all providers fail

* fix: remove dead line

* fix: rename defaults field to commands

* rename key and update docs

* fix: rename some stuff

* fix bug

* fix: mypy

* fix: unittests

* docstring

* fix: provider field description

* fix: error messages

* fix: msg

* feat: update defaults class

* feat: update defaults class

* unit tests

* fix test

* msg

* fix: add website documentation

* fix: detailed error message

* update core ruff version

* fix test

* rebuild

---------

Co-authored-by: Igor Radovanovic <[email protected]>
  • Loading branch information
montezdesousa and IgorWounds authored May 29, 2024
1 parent d3802c5 commit 01a71a8
Show file tree
Hide file tree
Showing 40 changed files with 950 additions and 1,086 deletions.
74 changes: 1 addition & 73 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.provider_interface import ExtraParams, ProviderInterface
from openbb_core.app.provider_interface import ExtraParams
from openbb_core.app.router import CommandMap
from openbb_core.app.service.system_service import SystemService
from openbb_core.app.service.user_service import UserService
Expand Down Expand Up @@ -117,68 +117,6 @@ def update_command_context(

return kwargs

@staticmethod
def update_provider_choices(
func: Callable,
command_coverage: Dict[str, List[str]],
route: str,
kwargs: Dict[str, Any],
route_default: Optional[Dict[str, Optional[str]]],
) -> Dict[str, Any]:
"""Update the provider choices with the available providers and set default provider."""

def _needs_provider(func: Callable) -> bool:
"""Check if the function needs a provider."""
parameters = signature(func).parameters.keys()
return "provider_choices" in parameters

def _has_provider(kwargs: Dict[str, Any]) -> bool:
"""Check if the kwargs already have a provider."""
provider_choices = kwargs.get("provider_choices")

if isinstance(provider_choices, dict): # when in python
return provider_choices.get("provider", None) is not None
if isinstance(provider_choices, object): # when running as fastapi
return getattr(provider_choices, "provider", None) is not None
return False

def _get_first_provider() -> Optional[str]:
"""Get the first available provider."""
available_providers = ProviderInterface().available_providers
return available_providers[0] if available_providers else None

def _get_default_provider(
command_coverage: Dict[str, List[str]],
route_default: Optional[Dict[str, Optional[str]]],
) -> Optional[str]:
"""
Get the default provider for the given route.
Either pick it from the user defaults or from the command coverage.
"""
cmd_cov_given_route = command_coverage.get(route)
command_cov_provider = (
cmd_cov_given_route[0] if cmd_cov_given_route else None
)

if route_default:
return route_default.get("provider", None) or command_cov_provider # type: ignore

return command_cov_provider

if not _has_provider(kwargs) and _needs_provider(func):
provider = (
_get_default_provider(
command_coverage,
route_default,
)
if route in command_coverage
else _get_first_provider()
)
kwargs["provider_choices"] = {"provider": provider}

return kwargs

@staticmethod
def _warn_kwargs(
extra_params: Dict[str, Any],
Expand Down Expand Up @@ -246,14 +184,12 @@ def build(
args: Tuple[Any, ...],
execution_context: ExecutionContext,
func: Callable,
route: str,
kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""Build the parameters for a function."""
func = cls.get_polished_func(func=func)
system_settings = execution_context.system_settings
user_settings = execution_context.user_settings
command_map = execution_context.command_map

kwargs = cls.merge_args_and_kwargs(
func=func,
Expand All @@ -266,13 +202,6 @@ def build(
system_settings=system_settings,
user_settings=user_settings,
)
kwargs = cls.update_provider_choices(
func=func,
command_coverage=command_map.command_coverage,
route=route,
kwargs=kwargs,
route_default=user_settings.defaults.routes.get(route, None),
)
kwargs = cls.validate_kwargs(
func=func,
kwargs=kwargs,
Expand Down Expand Up @@ -364,7 +293,6 @@ async def _execute_func(
args=args,
execution_context=execution_context,
func=func,
route=route,
kwargs=kwargs,
)

Expand Down
54 changes: 28 additions & 26 deletions openbb_platform/core/openbb_core/app/model/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import traceback
import warnings
from typing import Dict, Optional, Set, Tuple
from typing import Dict, List, Optional, Tuple

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -36,37 +36,39 @@ class LoadingError(Exception):
class CredentialsLoader:
"""Here we create the Credentials model."""

credentials: Dict[str, Set[str]] = {}
credentials: Dict[str, List[str]] = {}

@staticmethod
def prepare(
credentials: Dict[str, Set[str]],
) -> Dict[str, Tuple[object, None]]:
def format_credentials(self) -> Dict[str, Tuple[object, None]]:
"""Prepare credentials map to be used in the Credentials model."""
formatted: Dict[str, Tuple[object, None]] = {}
for origin, creds in credentials.items():
for c in creds:
# Not sure we should do this, if you require the same credential it breaks
# if c in formatted:
# raise ValueError(f"Credential '{c}' already in use.")
formatted[c] = (
for c_origin, c_list in self.credentials.items():
for c_name in c_list:
if c_name in formatted:
warnings.warn(
message=f"Skipping '{c_name}', credential already in use.",
category=OpenBBWarning,
)
continue
formatted[c_name] = (
Optional[OBBSecretStr],
Field(
default=None, description=origin, alias=c.upper()
), # register the credential origin (obbject, providers)
Field(default=None, description=c_origin, alias=c_name.upper()),
)

return formatted
return dict(sorted(formatted.items()))

def from_obbject(self) -> None:
"""Load credentials from OBBject extensions."""
self.credentials["obbject"] = set()
for name, entry in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
for ext_name, ext in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
try:
for c in entry.credentials:
self.credentials["obbject"].add(c)
if ext_name in self.credentials:
warnings.warn(
message=f"Skipping '{ext_name}', name already in user.",
category=OpenBBWarning,
)
continue
self.credentials[ext_name] = ext.credentials
except Exception as e:
msg = f"Error loading extension: {name}\n"
msg = f"Error loading extension: {ext_name}\n"
if Env().DEBUG_MODE:
traceback.print_exception(type(e), e, e.__traceback__)
raise LoadingError(msg + f"\033[91m{e}\033[0m") from e
Expand All @@ -77,20 +79,20 @@ def from_obbject(self) -> None:

def from_providers(self) -> None:
"""Load credentials from providers."""
self.credentials["providers"] = set()
for c in ProviderInterface().credentials:
self.credentials["providers"].add(c)
self.credentials = ProviderInterface().credentials

def load(self) -> BaseModel:
"""Load credentials from providers."""
# We load providers first to give them priority choosing credential names
self.from_providers()
self.from_obbject()
return create_model( # type: ignore
model = create_model( # type: ignore
"Credentials",
__config__=ConfigDict(validate_assignment=True, populate_by_name=True),
**self.prepare(self.credentials),
**self.format_credentials(),
)
model.origins = self.credentials
return model


_Credentials = CredentialsLoader().load()
Expand Down
35 changes: 31 additions & 4 deletions openbb_platform/core/openbb_core/app/model/defaults.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,46 @@
"""Defaults model."""

from typing import Dict, Optional
from typing import Dict, List, Optional
from warnings import warn

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from openbb_core.app.model.abstract.warning import OpenBBWarning


class Defaults(BaseModel):
"""Defaults."""

model_config = ConfigDict(validate_assignment=True)
model_config = ConfigDict(validate_assignment=True, populate_by_name=True)

routes: Dict[str, Dict[str, Optional[str]]] = Field(default_factory=dict)
commands: Dict[str, Dict[str, Optional[List[str]]]] = Field(
default_factory=dict,
alias="routes",
)

def __repr__(self) -> str:
"""Return string representation."""
return f"{self.__class__.__name__}\n\n" + "\n".join(
f"{k}: {v}" for k, v in self.model_dump().items()
)

@model_validator(mode="before")
@classmethod
def validate_before(cls, values: dict) -> dict:
"""Validate model (before)."""
key = "commands"
if "routes" in values:
warn(
message="'routes' is deprecated. Use 'commands' instead.",
category=OpenBBWarning,
)
key = "routes"

new_values: Dict[str, Dict[str, Optional[List[str]]]] = {"commands": {}}
for k, v in values.get(key, {}).items():
clean_k = k.strip("/").replace("/", ".")
provider = v.get("provider") if v else None
if isinstance(provider, str):
v["provider"] = [provider]
new_values["commands"][clean_k] = v
return new_values
4 changes: 2 additions & 2 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def map(self) -> MapType:
return self._map

@property
def credentials(self) -> List[str]:
"""Dictionary of required credentials by provider."""
def credentials(self) -> Dict[str, List[str]]:
"""Map providers to credentials."""
return self._registry_map.credentials

@property
Expand Down
63 changes: 52 additions & 11 deletions openbb_platform/core/openbb_core/app/static/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,59 @@ def _run(self, *args, **kwargs) -> Any:
return obbject
return getattr(obbject, "to_" + output_type)()

def _check_credentials(self, provider: str) -> Optional[bool]:
"""Check required credentials are populated."""
credentials = self._command_runner.user_settings.credentials
if provider not in credentials.origins:
return None
required = credentials.origins.get(provider)
return all(getattr(credentials, r, None) for r in required)

def _get_provider(
self, choice: Optional[str], cmd: str, available: Tuple[str, ...]
self, choice: Optional[str], command: str, default_priority: Tuple[str, ...]
) -> str:
"""Get the provider to use in execution."""
"""Get the provider to use in execution.
If no choice is specified, the configured priority list is used. A provider is used
when all of its required credentials are populated.
Parameters
----------
choice: Optional[str]
The provider choice, for example 'fmp'.
command: str
The command to get the provider for, for example 'equity.price.historical'
default_priority: Tuple[str, ...]
A tuple of available providers for the given command to use as default priority list.
Returns
-------
str
The provider to use in the command.
Raises
------
OpenBBError
Raises error when all the providers in the priority list failed.
"""
if choice is None:
if config_default := self._command_runner.user_settings.defaults.routes.get(
cmd, {}
).get("provider"):
if config_default in available:
return config_default
raise OpenBBError(
f"provider '{config_default}' is not available. Choose from: {', '.join(available)}."
)
return available[0]
commands = self._command_runner.user_settings.defaults.commands
providers = (
commands.get(command, {}).get("provider", []) or default_priority
)
tries = []
for p in providers:
result = self._check_credentials(p)
if result:
return p
elif result is False:
tries.append((p, "missing credentials"))
else:
tries.append((p, "not found"))

msg = "\n ".join([f"* '{pair[0]}' -> {pair[1]}" for pair in tries])
raise OpenBBError(
f"Provider fallback failed, please specify the provider or update credentials.\n"
f"[Providers]\n {msg}"
)
return choice
Loading

0 comments on commit 01a71a8

Please sign in to comment.