From 3b0750a3f80a3fcc34b430770bb6875713975ece Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Rami=CC=81rez=20Mondrago=CC=81n?= Date: Wed, 29 Jun 2022 21:24:11 -0500 Subject: [PATCH] feat: Validate unsupported config options --- singer_sdk/exceptions.py | 21 +++++++ singer_sdk/plugin_base.py | 5 +- singer_sdk/tap_base.py | 24 +++++--- singer_sdk/typing.py | 11 ++-- tests/core/test_jsonschema_helpers.py | 84 +++++++++++++++++++++------ tests/core/test_streams.py | 55 +++++++++++++++--- 6 files changed, 162 insertions(+), 38 deletions(-) diff --git a/singer_sdk/exceptions.py b/singer_sdk/exceptions.py index eddc7af52f..7834462487 100644 --- a/singer_sdk/exceptions.py +++ b/singer_sdk/exceptions.py @@ -1,10 +1,31 @@ """Defines a common set of exceptions which developers can raise and/or catch.""" + +from __future__ import annotations + import requests class ConfigValidationError(Exception): """Raised when a user's config settings fail validation.""" + def __init__( + self, + message: str, + *, + errors: list[str] | None = None, + warnings: list[str] | None = None, + ) -> None: + """Initialize a ConfigValidationError. + + Args: + message: A message describing the error. + errors: A list of errors which caused the validation error. + warnings: A list of warnings which caused the validation error. + """ + super().__init__(message) + self.errors = errors + self.warnings = warnings + class FatalAPIError(Exception): """Exception raised when a failed request should not be considered retriable.""" diff --git a/singer_sdk/plugin_base.py b/singer_sdk/plugin_base.py index 7ef4d48577..df93dcf7ae 100644 --- a/singer_sdk/plugin_base.py +++ b/singer_sdk/plugin_base.py @@ -250,7 +250,7 @@ def _validate_config( f"JSONSchema was: {config_jsonschema}" ) if raise_errors: - raise ConfigValidationError(summary) + raise ConfigValidationError(summary, errors=errors) log_fn = self.logger.warning else: @@ -259,7 +259,8 @@ def _validate_config( summary += f"\n{warning}" if warnings_as_errors and raise_errors and warnings: raise ConfigValidationError( - f"One or more warnings ocurred during validation: {warnings}" + f"One or more warnings ocurred during validation: {warnings}", + warnings=warnings, ) log_fn(summary) return warnings, errors diff --git a/singer_sdk/tap_base.py b/singer_sdk/tap_base.py index 75f1b481eb..fd31a23611 100644 --- a/singer_sdk/tap_base.py +++ b/singer_sdk/tap_base.py @@ -9,7 +9,7 @@ import click from singer_sdk.cli import common_options -from singer_sdk.exceptions import MaxRecordsLimitException +from singer_sdk.exceptions import ConfigValidationError, MaxRecordsLimitException from singer_sdk.helpers import _state from singer_sdk.helpers._classproperty import classproperty from singer_sdk.helpers._compat import final @@ -453,6 +453,7 @@ def cli( Raises: FileNotFoundError: If the config file does not exist. + Abort: If the configuration is not valid. """ if version: cls.print_version() @@ -486,13 +487,20 @@ def cli( config_files.append(Path(config_path)) - tap = cls( # type: ignore # Ignore 'type not callable' - config=config_files or None, - state=state, - catalog=catalog, - parse_env_config=parse_env_config, - validate_config=validate_config, - ) + try: + tap = cls( # type: ignore # Ignore 'type not callable' + config=config_files or None, + state=state, + catalog=catalog, + parse_env_config=parse_env_config, + validate_config=validate_config, + ) + except ConfigValidationError as exc: + for error in exc.errors: + click.secho(error, fg="red") + for warning in exc.warnings: + click.secho(warning, fg="warning") + raise click.Abort("Configuration is not valid.") if discover: tap.run_discovery() diff --git a/singer_sdk/typing.py b/singer_sdk/typing.py index 77417a6200..fc031ce962 100644 --- a/singer_sdk/typing.py +++ b/singer_sdk/typing.py @@ -403,14 +403,14 @@ class ObjectType(JSONTypeHelper): def __init__( self, *properties: Property, - additional_properties: W | type[W] | None = None, + additional_properties: W | type[W] | bool | None = None, ) -> None: """Initialize ObjectType from its list of properties. Args: properties: Zero or more attributes for this JSON object. additional_properties: A schema to match against unnamed properties in - this object. + this object or a boolean indicating if extra properties are allowed. """ self.wrapped: list[Property] = list(properties) self.additional_properties = additional_properties @@ -433,8 +433,11 @@ def type_dict(self) -> dict: # type: ignore # OK: @classproperty vs @property if required: result["required"] = required - if self.additional_properties: - result["additionalProperties"] = self.additional_properties.type_dict + if self.additional_properties is not None: + if isinstance(self.additional_properties, bool): + result["additionalProperties"] = self.additional_properties + else: + result["additionalProperties"] = self.additional_properties.type_dict return result diff --git a/tests/core/test_jsonschema_helpers.py b/tests/core/test_jsonschema_helpers.py index 175d0b577e..966c311883 100644 --- a/tests/core/test_jsonschema_helpers.py +++ b/tests/core/test_jsonschema_helpers.py @@ -1,7 +1,8 @@ """Test sample sync.""" +from __future__ import annotations + import re -from typing import List import pytest @@ -47,7 +48,7 @@ class ConfigTestTap(Tap): Property("batch_size", IntegerType, default=-1), ).to_dict() - def discover_streams(self) -> List[Stream]: + def discover_streams(self) -> list[Stream]: return [] @@ -291,7 +292,7 @@ def test_array_type(): @pytest.mark.parametrize( - "properties,addtional_properties", + "properties,additional_properties", [ ( [ @@ -311,6 +312,15 @@ def test_array_type(): ], StringType, ), + ( + [ + Property("id", StringType), + Property("email", StringType), + Property("username", StringType), + Property("phone_number", StringType), + ], + False, + ), ( [ Property("id", StringType), @@ -331,6 +341,16 @@ def test_array_type(): ], StringType, ), + ( + [ + Property("id", StringType), + Property("id", StringType), + Property("email", StringType), + Property("username", StringType), + Property("phone_number", StringType), + ], + False, + ), ( [ Property("id", StringType), @@ -349,6 +369,15 @@ def test_array_type(): ], StringType, ), + ( + [ + Property("id", StringType), + Property("email", StringType, True), + Property("username", StringType, True), + Property("phone_number", StringType), + ], + False, + ), ( [ Property("id", StringType), @@ -369,28 +398,49 @@ def test_array_type(): ], StringType, ), + ( + [ + Property("id", StringType), + Property("email", StringType, True), + Property("email", StringType, True), + Property("username", StringType, True), + Property("phone_number", StringType), + ], + False, + ), ], ids=[ - "no requried, no duplicates, no additional properties", - "no requried, no duplicates, additional properties", - "no requried, duplicates, no additional properties", - "no requried, duplicates, additional properties", - "requried, no duplicates, no additional properties", - "requried, no duplicates, additional properties", - "requried, duplicates, no additional properties", - "requried, duplicates, additional properties", + "no required, no duplicates, no additional properties", + "no required, no duplicates, additional properties", + "no required, no duplicates, no additional properties allowed", + "no required, duplicates, no additional properties", + "no required, duplicates, additional properties", + "no required, duplicates, no additional properties allowed", + "required, no duplicates, no additional properties", + "required, no duplicates, additional properties", + "required, no duplicates, no additional properties allowed", + "required, duplicates, no additional properties", + "required, duplicates, additional properties", + "required, duplicates, no additional properties allowed", ], ) -def test_object_type(properties: List[Property], addtional_properties: JSONTypeHelper): +def test_object_type( + properties: list[Property], + additional_properties: JSONTypeHelper | bool, +): merged_property_schemas = { name: schema for p in properties for name, schema in p.to_dict().items() } required = [p.name for p in properties if not p.optional] required_schema = {"required": required} if required else {} - addtional_properties_schema = ( - {"additionalProperties": addtional_properties.type_dict} - if addtional_properties + additional_properties_schema = ( + { + "additionalProperties": additional_properties + if isinstance(additional_properties, bool) + else additional_properties.type_dict + } + if additional_properties is not None else {} ) @@ -398,10 +448,10 @@ def test_object_type(properties: List[Property], addtional_properties: JSONTypeH "type": "object", "properties": merged_property_schemas, **required_schema, - **addtional_properties_schema, + **additional_properties_schema, } - object_type = ObjectType(*properties, additional_properties=addtional_properties) + object_type = ObjectType(*properties, additional_properties=additional_properties) assert object_type.type_dict == expected_json_schema diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 2b87dd75bd..66e2850db2 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -1,12 +1,15 @@ """Stream tests.""" +from __future__ import annotations + import logging -from typing import Any, Dict, Iterable, List, Optional, cast +from typing import Any, Iterable, cast import pendulum import pytest import requests +from singer_sdk.exceptions import ConfigValidationError from singer_sdk.helpers._classproperty import classproperty from singer_sdk.helpers.jsonpath import _compile_jsonpath from singer_sdk.streams.core import ( @@ -41,7 +44,7 @@ def __init__(self, tap: Tap): """Create a new stream.""" super().__init__(tap, schema=self.schema, name=self.name) - def get_records(self, context: Optional[dict]) -> Iterable[Dict[str, Any]]: + def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]: """Generate records.""" yield {"id": 1, "value": "Egypt"} yield {"id": 2, "value": "Germany"} @@ -78,9 +81,14 @@ class SimpleTestTap(Tap): """Test tap class.""" name = "test-tap" - settings_jsonschema = PropertiesList(Property("start_date", DateTimeType)).to_dict() + config_jsonschema = PropertiesList( + Property("username", StringType, required=True), + Property("password", StringType, required=True), + Property("start_date", DateTimeType), + additional_properties=False, + ).to_dict() - def discover_streams(self) -> List[Stream]: + def discover_streams(self) -> list[Stream]: """List all streams.""" return [SimpleTestStream(self)] @@ -101,7 +109,11 @@ def tap() -> SimpleTestTap: ] } return SimpleTestTap( - config={"start_date": "2021-01-01"}, + config={ + "username": "utest", + "password": "ptest", + "start_date": "2021-01-01", + }, parse_env_config=False, catalog=catalog_dict, ) @@ -214,7 +226,7 @@ def test_stream_starting_timestamp(tap: SimpleTestTap, stream: SimpleTestStream) ], ) def test_jsonpath_rest_stream( - tap: SimpleTestTap, path: str, content: str, result: List[dict] + tap: SimpleTestTap, path: str, content: str, result: list[dict] ): """Validate records are extracted correctly from the API response.""" fake_response = requests.Response() @@ -370,7 +382,7 @@ def test_sync_costs_calculation(tap: SimpleTestTap, caplog): def calculate_test_cost( request: requests.PreparedRequest, response: requests.Response, - context: Optional[Dict], + context: dict | None, ): return {"dim1": 1, "dim2": 2} @@ -387,3 +399,32 @@ def calculate_test_cost( for record in caplog.records: assert record.levelname == "INFO" assert f"Total Sync costs for stream {stream.name}" in record.message + + +@pytest.mark.parametrize( + "config_dict,errors", + [ + ( + {}, + ["'username' is a required property"], + ), + ( + {"username": "utest"}, + ["'password' is a required property"], + ), + ( + {"username": "utest", "password": "ptest", "extra": "not valid"}, + ["Additional properties are not allowed ('extra' was unexpected)"], + ), + ], + ids=[ + "missing_username", + "missing_password", + "extra_property", + ], +) +def test_config_errors(config_dict: dict, errors: list[str]): + with pytest.raises(ConfigValidationError, match="Config validation failed") as exc: + SimpleTestTap(config_dict, validate_config=True) + + assert exc.value.errors == errors