diff --git a/confection/__init__.py b/confection/__init__.py index 9c6922c..0395498 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -704,6 +704,7 @@ def copy_model_field(field: ModelField, type_: Any) -> ModelField: default=field.default, default_factory=field.default_factory, required=field.required, + alias=field.alias, ) @@ -912,6 +913,15 @@ def _fill( # created via config blocks), only use its values validation[v_key] = list(validation[v_key].values()) final[key] = list(final[key].values()) + + if ARGS_FIELD_ALIAS in schema.__fields__ and not resolve: + # If we're not resolving the config, make sure that the field + # expecting the promise is typed Any so it doesn't fail + # validation if it doesn't receive the function return value + field = schema.__fields__[ARGS_FIELD_ALIAS] + schema.__fields__[ARGS_FIELD_ALIAS] = copy_model_field( + field, Any + ) else: filled[key] = value # Prevent pydantic from consuming generator if part of a union @@ -936,7 +946,12 @@ def _fill( # manually because .construct doesn't parse anything if schema.Config.extra in (Extra.forbid, Extra.ignore): fields = schema.__fields__.keys() - exclude = [k for k in result.__fields_set__ if k not in fields] + # If we have a reserved field, we need to use its alias + field_set = [ + k if k != ARGS_FIELD else ARGS_FIELD_ALIAS + for k in result.__fields_set__ + ] + exclude = [k for k in field_set if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) validation.update(result.dict(exclude=exclude_validation)) filled, final = cls._update_from_parsed(validation, filled, final) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 5873eb2..58600cf 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -424,6 +424,28 @@ def catsie_567(*args: Optional[str], foo: str = "bar"): assert my_registry.resolve(cfg)["config"] == "^_^" +def test_fill_config_positional_args_w_promise(): + @my_registry.cats("catsie.v568") + def catsie_568(*args: str, foo: str = "bar"): + assert args[0] == "^(*.*)^" + assert foo == "baz" + return args[0] + + @my_registry.cats("cat_promise.v568") + def cat_promise() -> str: + return "^(*.*)^" + + cfg = { + "config": { + "@cats": "catsie.v568", + "*": {"promise": {"@cats": "cat_promise.v568"}}, + } + } + filled = my_registry.fill(cfg, validate=True) + assert filled["config"]["foo"] == "bar" + assert filled["config"]["*"] == {"promise": {"@cats": "cat_promise.v568"}} + + def test_make_config_positional_args_complex(): @my_registry.cats("catsie.v890") def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):