Skip to content

Commit

Permalink
Merge branch 'develop' into bugfix/pywry-plot-paperbg
Browse files Browse the repository at this point in the history
  • Loading branch information
deeleeramone authored May 9, 2024
2 parents ea36be7 + 0b9b12d commit 9300f08
Show file tree
Hide file tree
Showing 104 changed files with 760 additions and 332 deletions.
26 changes: 18 additions & 8 deletions cli/openbb_cli/argparse_translator/argparse_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ def __init__(
def _handle_argument_in_groups(self, argument, group):
"""Handle the argument and add it to the parser."""

def _in_optional_arguments(arg):
def _in_group(arg, group_title):
for action_group in self._parser._action_groups:
if action_group.title == "optional arguments":
if action_group.title == group_title:
for action in action_group._group_actions:
opts = action.option_strings
if (opts and opts[0] == arg) or action.dest == arg:
Expand Down Expand Up @@ -286,16 +286,26 @@ def _update_providers(
# extend choices
choices = tuple(set(_get_arg_choices(argument.name) + model_choices))

# check if the argument is in the required arguments
if _in_group(argument.name, group_title="required arguments"):
for action in self._required._group_actions:
if action.dest == argument.name and choices:
# update choices
action.choices = choices
return

# check if the argument is in the optional arguments
if _in_optional_arguments(argument.name):
if _in_group(argument.name, group_title="optional arguments"):
for action in self._parser._actions:
if action.dest == argument.name:
# update choices
action.choices = choices
# update help
action.help = _update_providers(
action.help or "", [group.title]
)
if choices:
action.choices = choices
if argument.name not in self.signature.parameters:
# update help
action.help = _update_providers(
action.help or "", [group.title]
)
return

# if the argument is in use, remove it from all groups
Expand Down
29 changes: 18 additions & 11 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,24 @@ def _create_field(
annotation = field.annotation

additional_description = ""
if (extra := field.json_schema_extra) and (
multiple := extra.get("multiple_items_allowed") # type: ignore
):
if provider_name:
additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
" Multiple comma separated items allowed for provider(s): "
+ ", ".join(multiple) # type: ignore[arg-type]
+ "."
)
if extra := field.json_schema_extra:
providers = []
for p, v in extra.items(): # type: ignore[union-attr]
if isinstance(v, dict) and v.get("multiple_items_allowed"):
providers.append(p)
elif isinstance(v, list) and "multiple_items_allowed" in v:
# For backwards compatibility, before this was a list
providers.append(p)

if providers:
if provider_name:
additional_description += " Multiple comma separated items allowed."
else:
additional_description += (
" Multiple comma separated items allowed for provider(s): "
+ ", ".join(providers) # type: ignore[arg-type]
+ "."
)

provider_field = (
f"(provider: {provider_name})" if provider_name != "openbb" else ""
Expand Down
32 changes: 25 additions & 7 deletions openbb_platform/core/openbb_core/app/static/package_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,15 @@ def get_expanded_type(
original_type: Optional[type] = None,
) -> object:
"""Expand the original field type."""
if extra and "multiple_items_allowed" in extra:
if extra and any(
(
v.get("multiple_items_allowed")
if isinstance(v, dict)
# For backwards compatibility, before this was a list
else "multiple_items_allowed" in v
)
for v in extra.values()
):
if original_type is None:
raise ValueError(
"multiple_items_allowed requires the original type to be specified."
Expand Down Expand Up @@ -1450,6 +1458,10 @@ def _get_provider_field_params(
expanded_types = MethodDefinition.TYPE_EXPANSION
model_map = cls.pi.map[model]

# TODO: Change this to read the package data instead of pi.map directly
# We change some items (types, descriptions), so the reference.json
# does not reflect entirely the package code.

for field, field_info in model_map[provider][params_type]["fields"].items():
# Determine the field type, expanding it if necessary and if params_type is "Parameters"
field_type = field_info.annotation
Expand All @@ -1470,12 +1482,18 @@ def _get_provider_field_params(
) # fmt: skip

# Add information for the providers supporting multiple symbols
if params_type == "QueryParams" and field_info.json_schema_extra:
multiple_items_list = field_info.json_schema_extra.get(
"multiple_items_allowed", None
)
if multiple_items_list:
multiple_items = ", ".join(multiple_items_list)
if params_type == "QueryParams" and (extra := field_info.json_schema_extra):

providers = []
for p, v in extra.items(): # type: ignore[union-attr]
if isinstance(v, dict) and v.get("multiple_items_allowed"):
providers.append(p)
elif isinstance(v, list) and "multiple_items_allowed" in v:
# For backwards compatibility, before this was a list
providers.append(p)

if providers:
multiple_items = ", ".join(providers)
cleaned_description += (
f" Multiple items allowed for provider(s): {multiple_items}."
)
Expand Down
54 changes: 32 additions & 22 deletions openbb_platform/core/openbb_core/app/static/utils/filters.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""OpenBB filters."""

from typing import Dict, List, Optional
from typing import Any, Dict, Optional

from openbb_core.app.utils import check_single_item, convert_to_basemodel


def filter_inputs(
data_processing: bool = False,
info: Optional[Dict[str, Dict[str, List[str]]]] = None,
info: Optional[Dict[str, Dict[str, Any]]] = None,
**kwargs,
) -> dict:
"""Filter command inputs."""
Expand All @@ -16,32 +16,42 @@ def filter_inputs(
kwargs[key] = convert_to_basemodel(value)

if info:
PROPERTY = "multiple_items_allowed"

# Here we check if list items are passed and multiple items allowed for
# the given provider/input combination. In that case we transform the list
# into a comma-separated string
for field, props in info.items():
if PROPERTY in props and (
provider := kwargs.get("provider_choices", {}).get("provider")
):
for p in ("standard_params", "extra_params"):
if field in kwargs.get(p, {}):
current = kwargs[p][field]
new = (
",".join(map(str, current))
if isinstance(current, list)
else current
provider = kwargs.get("provider_choices", {}).get("provider")
for field, properties in info.items():

for p in ("standard_params", "extra_params"):
if field in kwargs.get(p, {}):
current = kwargs[p][field]
new = (
",".join(map(str, current))
if isinstance(current, list)
else current
)

provider_properties = properties.get(provider, {})
if isinstance(provider_properties, dict):
multiple_items_allowed = provider_properties.get(
"multiple_items_allowed"
)
elif isinstance(provider_properties, list):
# For backwards compatibility, before this was a list
multiple_items_allowed = (
"multiple_items_allowed" in provider_properties
)
else:
multiple_items_allowed = True

if provider and provider not in props[PROPERTY]:
check_single_item(
new,
f"{field} -> multiple items not allowed for '{provider}'",
)
if not multiple_items_allowed:
check_single_item(
new,
f"{field} -> multiple items not allowed for '{provider}'",
)

kwargs[p][field] = new
break
kwargs[p][field] = new
break
else:
provider = kwargs.get("provider_choices", {}).get("provider")
for param_category in ("standard_params", "extra_params"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,24 @@ class QueryParams(BaseModel):
Merge different json schema extra, identified by provider.
Example:
FMP fetcher:
__json_schema_extra__ = {"symbol": ["multiple_items_allowed"]}
__json_schema_extra__ = {"symbol": {"multiple_items_allowed": True}}
Intrinio fetcher
__json_schema_extra__ = {"symbol": ["multiple_items_allowed"]}
__json_schema_extra__ = {"symbol": {"multiple_items_allowed": False}}
Creates a new field in the `symbol` schema with:
Creates new fields in the `symbol` schema:
{
...,
"multiple_items_allowed": ["fmp", "intrinio"],
"type": "string",
"description": "Symbol to get data for.",
"fmp": {"multiple_items_allowed": True},
"intrinio": {"multiple_items_allowed": False}
...,
}
Multiple fields can be tagged with the same or multiple properties.
Example:
__json_schema_extra__ = {
"<field_name_A>": ["some_prop", "another_prop"],
"<field_name_B>": ["yet_another_prop"]
"<field_name_A>": {"foo": 123, "bar": 456},
"<field_name_B>": {"foo": 789}
}
Attributes:
Expand Down
81 changes: 41 additions & 40 deletions openbb_platform/core/openbb_core/provider/registry_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,38 +104,34 @@ def _get_maps(self, registry: Registry) -> Tuple[MapType, Dict[str, Dict]]:
}
)

self._merge_json_schema_extra(p, fetcher, standard_extra[model_name])
self._update_json_schema_extra(p, fetcher, standard_extra[model_name])

return standard_extra, original_models

def _merge_json_schema_extra(
def _update_json_schema_extra(
self,
provider: str,
fetcher: Fetcher,
model_map: dict,
):
"""Merge json schema extra for different providers."""
model: BaseModel = RegistryMap._get_model(fetcher, "query_params")
std_fields = model_map["openbb"]["QueryParams"]["fields"]
standard_fields = model_map["openbb"]["QueryParams"]["fields"]
extra_fields = model_map[provider]["QueryParams"]["fields"]
for f, props in getattr(model, "__json_schema_extra__", {}).items():
for p in props:
if f in std_fields:
model_field = std_fields[f]
elif f in extra_fields:
model_field = extra_fields[f]

for field, properties in getattr(model, "__json_schema_extra__", {}).items():
if properties:
if field in standard_fields:
model_field = standard_fields[field]
elif field in extra_fields:
model_field = extra_fields[field]
else:
continue

if model_field.json_schema_extra is None:
model_field.json_schema_extra = {}

if p not in model_field.json_schema_extra:
model_field.json_schema_extra[p] = []

providers = model_field.json_schema_extra[p]
if provider not in providers:
providers.append(provider)
model_field.json_schema_extra[provider] = properties

def _get_models(self, map_: MapType) -> List[str]:
"""Get available models."""
Expand All @@ -152,33 +148,38 @@ def _extract_info(
) -> tuple:
"""Extract info (fields and docstring) from fetcher query params or data."""
model: BaseModel = RegistryMap._get_model(fetcher, type_)
all_fields = {}
standard_info: Dict[str, Any] = {"fields": {}, "docstring": None}
found_top_level = False
extra_info: Dict[str, Any] = {"fields": {}, "docstring": model.__doc__}
found_first_standard = False

for c in RegistryMap._class_hierarchy(model):
if c.__name__ in SKIP:
family = RegistryMap._get_class_family(model)
for i, child in enumerate(family):
if child.__name__ in SKIP:
continue
if (Path(getfile(c)).parent == STANDARD_MODELS_FOLDER) or found_top_level:
if not found_top_level:
# We might update the standard_info more than once to account for
# nested standard models, but we only want to update the docstring
# once with the __doc__ of the top-level standard model.
standard_info["docstring"] = c.__doc__
found_top_level = True
standard_info["fields"].update(c.model_fields)
else:
all_fields.update(c.model_fields)

extra_info: Dict[str, Any] = {
"fields": {},
"docstring": model.__doc__,
}

# We ignore fields that are already in the standard model
for name, field in all_fields.items():
if name not in standard_info["fields"]:
extra_info["fields"][name] = field
parent = family[i + 1] if family[i + 1] not in SKIP else BaseModel

fields = {
name: field
for name, field in child.model_fields.items()
# This ensures fields inherited by c are discarded.
# We need to compare child and parent __annotations__
# because this attribute is redirected to the parent class
# when the child simply inherits the parent and does not
# define any attributes.
# TLDR: Only fields defined in c are included
if name in child.__annotations__
and child.__annotations__ is not parent.__annotations__
}

if Path(getfile(child)).parent == STANDARD_MODELS_FOLDER:
if not found_first_standard:
# If standard uses inheritance we just use the first docstring
standard_info["docstring"] = child.__doc__
found_first_standard = True
standard_info["fields"].update(fields)
else:
extra_info["fields"].update(fields)

return standard_info, extra_info

Expand All @@ -204,6 +205,6 @@ def _validate(model: Any, type_: Literal["query_params", "data"]) -> None:
)

@staticmethod
def _class_hierarchy(class_) -> tuple:
"""Return the class hierarchy starting with the class itself until `object`."""
def _get_class_family(class_) -> tuple:
"""Return the class family starting with the class itself until `object`."""
return getattr(class_, "__mro__", ())
Loading

0 comments on commit 9300f08

Please sign in to comment.