Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issue where config with * could not be filled #53

Merged
merged 6 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion confection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def alias_generator(name: str) -> str:
return name



def copy_model_field(field: ModelField, type_: Any) -> ModelField:
"""Copy a model field and assign a new type, e.g. to accept an Any type
even though the original value is typed differently.
Expand All @@ -704,6 +705,7 @@ def copy_model_field(field: ModelField, type_: Any) -> ModelField:
default=field.default,
default_factory=field.default_factory,
required=field.required,
alias=field.alias,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This avoids dropping the previous alias for *

)


Expand Down Expand Up @@ -912,6 +914,15 @@ def _fill(
# created via config blocks), only use its values
validation[v_key] = list(validation[v_key].values())
final[key] = list(final[key].values())

if ARGS_FIELD_ALIAS in schema.__fields__ and not resolve:
# If we're not resolving the config, make sure that the field
# expecting the promise is typed Any so it doesn't fail
# validation if it doesn't receive the function return value
field = schema.__fields__[ARGS_FIELD_ALIAS]
schema.__fields__[ARGS_FIELD_ALIAS] = copy_model_field(
field, Any
)
Comment on lines +921 to +924
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a copy of a logic appearing earlier (if it is a promise). It seems like it might be possible to move it outside the 'if promise'-scope.

else:
filled[key] = value
# Prevent pydantic from consuming generator if part of a union
Expand All @@ -936,7 +947,12 @@ def _fill(
# manually because .construct doesn't parse anything
if schema.Config.extra in (Extra.forbid, Extra.ignore):
fields = schema.__fields__.keys()
exclude = [k for k in result.__fields_set__ if k not in fields]
# If we have a reserved field, we need to use its alias
field_set = [
k if k != ARGS_FIELD else ARGS_FIELD_ALIAS for k in result.__fields_set__
]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If validate=False this avoid removing the "*" key.

# field_set = result.__fields_set__
exclude = [k for k in field_set if k not in fields]
exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
validation.update(result.dict(exclude=exclude_validation))
filled, final = cls._update_from_parsed(validation, filled, final)
Expand Down
25 changes: 25 additions & 0 deletions confection/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,31 @@ def catsie_567(*args: Optional[str], foo: str = "bar"):
assert my_registry.resolve(cfg)["config"] == "^_^"


def test_fill_config_positional_args_w_promise():
@my_registry.cats("catsie.v568")
def catsie_568(*args: str, foo: str = "bar"):
assert args[0] == "^(*.*)^"
assert foo == "baz"
return args[0]

@my_registry.cats("cat_promise.v568")
def cat_promise() -> str:
return "^(*.*)^"

cfg = {
"test_fn": {"@cats": "catsie.v568", "*": {"test_arg": {"@registry": "factory"}}}
}
adrianeboyd marked this conversation as resolved.
Show resolved Hide resolved
cfg = {
"config": {
"@cats": "catsie.v568",
"*": {"promise": {"@cats": "cat_promise.v568"}},
}
}
filled = my_registry.fill(cfg, validate=True)
assert filled["config"]["foo"] == "bar"
assert filled["config"]["*"] == {"promise": {"@cats": "cat_promise.v568"}}


def test_make_config_positional_args_complex():
@my_registry.cats("catsie.v890")
def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):
Expand Down
Loading