Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Set provider fallback based on credentials #6446

Merged
merged 31 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
52148a8
feat: allow provider fallback based on credentials
montezdesousa May 21, 2024
1a0748c
fix: rename prop and raise error
montezdesousa May 21, 2024
2f892d3
fix: raise if all providers fail
montezdesousa May 21, 2024
c15e4a1
Merge branch 'develop' into feature/provider-fallbacks
montezdesousa May 21, 2024
5aa90e2
fix: remove dead line
montezdesousa May 21, 2024
83482f4
fix: rename defaults field to commands
montezdesousa May 21, 2024
64a7adf
rename key and update docs
montezdesousa May 21, 2024
00003f5
fix: rename some stuff
montezdesousa May 21, 2024
d72a138
fix bug
montezdesousa May 21, 2024
8282406
fix: mypy
montezdesousa May 21, 2024
faf1180
fix: unittests
montezdesousa May 21, 2024
9943602
docstring
montezdesousa May 21, 2024
61d2043
fix: provider field description
montezdesousa May 21, 2024
2ce858e
fix: error messages
montezdesousa May 21, 2024
7d5abab
fix: msg
montezdesousa May 21, 2024
8d4002f
feat: update defaults class
montezdesousa May 22, 2024
a938fe5
feat: update defaults class
montezdesousa May 22, 2024
e0fb123
unit tests
montezdesousa May 22, 2024
56bb236
Merge branch 'develop' into feature/provider-fallbacks
montezdesousa May 22, 2024
49b5d80
fix test
montezdesousa May 22, 2024
e8a3d84
msg
montezdesousa May 22, 2024
1ca002b
Merge branch 'develop' into feature/provider-fallbacks
IgorWounds May 22, 2024
72a8c58
Merge branch 'develop' into feature/provider-fallbacks
montezdesousa May 23, 2024
4771f32
fix: add website documentation
montezdesousa May 23, 2024
e8e3bff
fix: detailed error message
montezdesousa May 23, 2024
4a39c11
update core ruff version
montezdesousa May 23, 2024
7680b81
fix test
montezdesousa May 23, 2024
22034a2
Merge branch 'develop' into feature/provider-fallbacks
montezdesousa May 23, 2024
dbbbd8c
rebuild
montezdesousa May 24, 2024
b9dc096
Merge branch 'develop' into feature/provider-fallbacks
IgorWounds May 27, 2024
8663662
Merge branch 'develop' into feature/provider-fallbacks
montezdesousa May 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 1 addition & 73 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.provider_interface import ExtraParams, ProviderInterface
from openbb_core.app.provider_interface import ExtraParams
from openbb_core.app.router import CommandMap
from openbb_core.app.service.system_service import SystemService
from openbb_core.app.service.user_service import UserService
Expand Down Expand Up @@ -117,68 +117,6 @@ def update_command_context(

return kwargs

@staticmethod
def update_provider_choices(
func: Callable,
command_coverage: Dict[str, List[str]],
route: str,
kwargs: Dict[str, Any],
route_default: Optional[Dict[str, Optional[str]]],
) -> Dict[str, Any]:
"""Update the provider choices with the available providers and set default provider."""

def _needs_provider(func: Callable) -> bool:
"""Check if the function needs a provider."""
parameters = signature(func).parameters.keys()
return "provider_choices" in parameters

def _has_provider(kwargs: Dict[str, Any]) -> bool:
"""Check if the kwargs already have a provider."""
provider_choices = kwargs.get("provider_choices")

if isinstance(provider_choices, dict): # when in python
return provider_choices.get("provider", None) is not None
if isinstance(provider_choices, object): # when running as fastapi
return getattr(provider_choices, "provider", None) is not None
return False

def _get_first_provider() -> Optional[str]:
"""Get the first available provider."""
available_providers = ProviderInterface().available_providers
return available_providers[0] if available_providers else None

def _get_default_provider(
command_coverage: Dict[str, List[str]],
route_default: Optional[Dict[str, Optional[str]]],
) -> Optional[str]:
"""
Get the default provider for the given route.

Either pick it from the user defaults or from the command coverage.
"""
cmd_cov_given_route = command_coverage.get(route)
command_cov_provider = (
cmd_cov_given_route[0] if cmd_cov_given_route else None
)

if route_default:
return route_default.get("provider", None) or command_cov_provider # type: ignore

return command_cov_provider

if not _has_provider(kwargs) and _needs_provider(func):
provider = (
_get_default_provider(
command_coverage,
route_default,
)
if route in command_coverage
else _get_first_provider()
)
kwargs["provider_choices"] = {"provider": provider}

return kwargs

@staticmethod
def _warn_kwargs(
extra_params: Dict[str, Any],
Expand Down Expand Up @@ -246,14 +184,12 @@ def build(
args: Tuple[Any, ...],
execution_context: ExecutionContext,
func: Callable,
route: str,
kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""Build the parameters for a function."""
func = cls.get_polished_func(func=func)
system_settings = execution_context.system_settings
user_settings = execution_context.user_settings
command_map = execution_context.command_map

kwargs = cls.merge_args_and_kwargs(
func=func,
Expand All @@ -266,13 +202,6 @@ def build(
system_settings=system_settings,
user_settings=user_settings,
)
kwargs = cls.update_provider_choices(
func=func,
command_coverage=command_map.command_coverage,
route=route,
kwargs=kwargs,
route_default=user_settings.defaults.routes.get(route, None),
)
kwargs = cls.validate_kwargs(
func=func,
kwargs=kwargs,
Expand Down Expand Up @@ -364,7 +293,6 @@ async def _execute_func(
args=args,
execution_context=execution_context,
func=func,
route=route,
kwargs=kwargs,
)

Expand Down
54 changes: 28 additions & 26 deletions openbb_platform/core/openbb_core/app/model/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import traceback
import warnings
from typing import Dict, Optional, Set, Tuple
from typing import Dict, List, Optional, Tuple

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -36,37 +36,39 @@ class LoadingError(Exception):
class CredentialsLoader:
"""Here we create the Credentials model."""

credentials: Dict[str, Set[str]] = {}
credentials: Dict[str, List[str]] = {}

@staticmethod
def prepare(
credentials: Dict[str, Set[str]],
) -> Dict[str, Tuple[object, None]]:
def format_credentials(self) -> Dict[str, Tuple[object, None]]:
"""Prepare credentials map to be used in the Credentials model."""
formatted: Dict[str, Tuple[object, None]] = {}
for origin, creds in credentials.items():
for c in creds:
# Not sure we should do this, if you require the same credential it breaks
# if c in formatted:
# raise ValueError(f"Credential '{c}' already in use.")
formatted[c] = (
for c_origin, c_list in self.credentials.items():
for c_name in c_list:
if c_name in formatted:
warnings.warn(
message=f"Skipping '{c_name}', credential already in use.",
category=OpenBBWarning,
)
continue
formatted[c_name] = (
Optional[OBBSecretStr],
Field(
default=None, description=origin, alias=c.upper()
), # register the credential origin (obbject, providers)
Field(default=None, description=c_origin, alias=c_name.upper()),
)

return formatted
return dict(sorted(formatted.items()))

def from_obbject(self) -> None:
"""Load credentials from OBBject extensions."""
self.credentials["obbject"] = set()
for name, entry in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
for ext_name, ext in ExtensionLoader().obbject_objects.items(): # type: ignore[attr-defined]
try:
for c in entry.credentials:
self.credentials["obbject"].add(c)
if ext_name in self.credentials:
warnings.warn(
message=f"Skipping '{ext_name}', name already in user.",
category=OpenBBWarning,
)
continue
self.credentials[ext_name] = ext.credentials
except Exception as e:
msg = f"Error loading extension: {name}\n"
msg = f"Error loading extension: {ext_name}\n"
if Env().DEBUG_MODE:
traceback.print_exception(type(e), e, e.__traceback__)
raise LoadingError(msg + f"\033[91m{e}\033[0m") from e
Expand All @@ -77,20 +79,20 @@ def from_obbject(self) -> None:

def from_providers(self) -> None:
"""Load credentials from providers."""
self.credentials["providers"] = set()
for c in ProviderInterface().credentials:
self.credentials["providers"].add(c)
self.credentials = ProviderInterface().credentials

def load(self) -> BaseModel:
"""Load credentials from providers."""
# We load providers first to give them priority choosing credential names
self.from_providers()
self.from_obbject()
return create_model( # type: ignore
model = create_model( # type: ignore
"Credentials",
__config__=ConfigDict(validate_assignment=True, populate_by_name=True),
**self.prepare(self.credentials),
**self.format_credentials(),
)
model.origins = self.credentials
return model


_Credentials = CredentialsLoader().load()
Expand Down
35 changes: 31 additions & 4 deletions openbb_platform/core/openbb_core/app/model/defaults.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,46 @@
"""Defaults model."""

from typing import Dict, Optional
from typing import Dict, List, Optional
from warnings import warn

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from openbb_core.app.model.abstract.warning import OpenBBWarning


class Defaults(BaseModel):
"""Defaults."""

model_config = ConfigDict(validate_assignment=True)
model_config = ConfigDict(validate_assignment=True, populate_by_name=True)

routes: Dict[str, Dict[str, Optional[str]]] = Field(default_factory=dict)
commands: Dict[str, Dict[str, Optional[List[str]]]] = Field(
default_factory=dict,
alias="routes",
)

def __repr__(self) -> str:
"""Return string representation."""
return f"{self.__class__.__name__}\n\n" + "\n".join(
f"{k}: {v}" for k, v in self.model_dump().items()
)

@model_validator(mode="before")
@classmethod
def validate_before(cls, values: dict) -> dict:
"""Validate model (before)."""
key = "commands"
if "routes" in values:
warn(
message="'routes' is deprecated. Use 'commands' instead.",
category=OpenBBWarning,
)
key = "routes"

new_values: Dict[str, Dict[str, Optional[List[str]]]] = {"commands": {}}
for k, v in values.get(key, {}).items():
clean_k = k.strip("/").replace("/", ".")
provider = v.get("provider") if v else None
if isinstance(provider, str):
v["provider"] = [provider]
new_values["commands"][clean_k] = v
return new_values
4 changes: 2 additions & 2 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def map(self) -> MapType:
return self._map

@property
def credentials(self) -> List[str]:
"""Dictionary of required credentials by provider."""
def credentials(self) -> Dict[str, List[str]]:
"""Map providers to credentials."""
return self._registry_map.credentials

@property
Expand Down
63 changes: 52 additions & 11 deletions openbb_platform/core/openbb_core/app/static/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,59 @@ def _run(self, *args, **kwargs) -> Any:
return obbject
return getattr(obbject, "to_" + output_type)()

def _check_credentials(self, provider: str) -> Optional[bool]:
"""Check required credentials are populated."""
credentials = self._command_runner.user_settings.credentials
if provider not in credentials.origins:
return None
required = credentials.origins.get(provider)
return all(getattr(credentials, r, None) for r in required)

def _get_provider(
self, choice: Optional[str], cmd: str, available: Tuple[str, ...]
self, choice: Optional[str], command: str, default_priority: Tuple[str, ...]
) -> str:
"""Get the provider to use in execution."""
"""Get the provider to use in execution.

If no choice is specified, the configured priority list is used. A provider is used
when all of its required credentials are populated.

Parameters
----------
choice: Optional[str]
The provider choice, for example 'fmp'.
command: str
The command to get the provider for, for example 'equity.price.historical'
default_priority: Tuple[str, ...]
A tuple of available providers for the given command to use as default priority list.

Returns
-------
str
The provider to use in the command.

Raises
------
OpenBBError
Raises error when all the providers in the priority list failed.
"""
if choice is None:
if config_default := self._command_runner.user_settings.defaults.routes.get(
cmd, {}
).get("provider"):
if config_default in available:
return config_default
raise OpenBBError(
f"provider '{config_default}' is not available. Choose from: {', '.join(available)}."
)
return available[0]
commands = self._command_runner.user_settings.defaults.commands
providers = (
commands.get(command, {}).get("provider", []) or default_priority
)
tries = []
for p in providers:
result = self._check_credentials(p)
if result:
return p
elif result is False:
tries.append((p, "missing credentials"))
else:
tries.append((p, "not found"))

msg = "\n ".join([f"* '{pair[0]}' -> {pair[1]}" for pair in tries])
raise OpenBBError(
f"Provider fallback failed, please specify the provider or update credentials.\n"
f"[Providers]\n {msg}"
)
return choice
Loading
Loading