Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [DCS-230] add usa state code and zip code validations #231

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions datachecks/core/datasource/sql_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def query_string_pattern_validity(
"uuid": r"^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
"usa_phone": r"^(\+1[-.\s]?)?(\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}$",
"email": r"^(?!.*\.\.)(?!.*@.*@)[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
"usa_zip_code": r"^[0-9]{5}(?:-[0-9]{4})?$",
}

if not regex_pattern and not predefined_regex_pattern:
Expand Down Expand Up @@ -488,3 +489,83 @@ def query_get_string_length_metric(

result = self.fetchone(query)[0]
return round(result, 2) if metric.lower() == "avg" else result

def query_get_usa_state_code_validity(
    self, table: str, field: str, filters: str = None
) -> Tuple[int, int]:
    """
    Get the count of valid USA state codes
    :param table: table name
    :param field: column name
    :param filters: filter condition (SQL boolean expression, without the WHERE keyword)
    :return: count of valid state codes, count of total row count
    """
    # Two-letter codes of the 50 states. DC and the territories are not in
    # this list, so rows holding them are counted as invalid.
    valid_state_codes = (
        "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA",
        "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
        "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
        "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
        "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
    )

    valid_state_codes_str = ", ".join(f"'{code}'" for code in valid_state_codes)

    filters = f"WHERE {filters}" if filters else ""

    qualified_table_name = self.qualified_table_name(table)

    # NOTE(review): `~` is the PostgreSQL regex-match operator; other SQL
    # dialects may need their own override of this query - confirm.
    regex_query = f"CASE WHEN {field} ~ '^[A-Z]{{2}}$' AND {field} IN ({valid_state_codes_str}) THEN 1 ELSE 0 END"

    query = f"""
        SELECT SUM({regex_query}) AS valid_count, COUNT(*) AS total_count
        FROM {qualified_table_name} {filters}
    """

    result = self.fetchone(query)
    # SUM() over zero rows yields NULL (None); report 0 valid rows so the
    # declared Tuple[int, int] contract holds for empty or fully filtered tables.
    return result[0] or 0, result[1]
8 changes: 8 additions & 0 deletions datachecks/core/validation/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,17 @@
CountInvalidRegex,
CountInvalidValues,
CountUSAPhoneValidation,
CountUSAStateCodeValidation,
CountUSAZipCodeValidation,
CountUUIDValidation,
CountValidRegex,
CountValidValues,
PercentEmailValidation,
PercentInvalidRegex,
PercentInvalidValues,
PercentUSAPhoneValidation,
PercentUSAStateCodeValidation,
PercentUSAZipCodeValidation,
PercentUUIDValidation,
PercentValidRegex,
PercentValidValues,
Expand Down Expand Up @@ -103,6 +107,10 @@ class ValidationManager:
ValidationFunction.STRING_LENGTH_MAX.value: "StringLengthMaxValidation",
ValidationFunction.STRING_LENGTH_MIN.value: "StringLengthMinValidation",
ValidationFunction.STRING_LENGTH_AVERAGE.value: "StringLengthAverageValidation",
ValidationFunction.COUNT_USA_STATE_CODE.value: "CountUSAStateCodeValidation",
ValidationFunction.PERCENT_USA_STATE_CODE.value: "PercentUSAStateCodeValidation",
ValidationFunction.COUNT_USA_ZIP_CODE.value: "CountUSAZipCodeValidation",
ValidationFunction.PERCENT_USA_ZIP_CODE.value: "PercentUSAZipCodeValidation",
}

def __init__(
Expand Down
68 changes: 68 additions & 0 deletions datachecks/core/validation/validity_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,71 @@ def _generate_metric_value(self, **kwargs) -> Union[float, int]:
raise ValueError(
"Unsupported data source type for StringLengthAverageValidation"
)


class CountUSAZipCodeValidation(Validation):
    """Metric: number of rows whose field matches the predefined ``usa_zip_code`` pattern."""

    def _generate_metric_value(self, **kwargs) -> Union[float, int]:
        """
        Return the count of rows holding a valid USA zip code.

        :raises NotImplementedError: when the data source is not a SQLDataSource
        """
        # Guard first: only SQL data sources expose the pattern-validity query.
        if not isinstance(self.data_source, SQLDataSource):
            raise NotImplementedError(
                "USA Zip Code validation is only supported for SQL data sources"
            )

        # The query also reports the total row count, which this metric ignores.
        matched, _total = self.data_source.query_string_pattern_validity(
            table=self.dataset_name,
            field=self.field_name,
            predefined_regex_pattern="usa_zip_code",
            filters=self.where_filter,
        )
        return matched


class PercentUSAZipCodeValidation(Validation):
    """Metric: percentage of rows whose field matches the predefined ``usa_zip_code`` pattern."""

    def _generate_metric_value(self, **kwargs) -> Union[float, int]:
        """
        Return the share of valid USA zip codes as a percentage rounded to 2 decimals.

        Returns 0 when the total row count is not positive.

        :raises NotImplementedError: when the data source is not a SQLDataSource
        """
        # Guard first: only SQL data sources expose the pattern-validity query.
        if not isinstance(self.data_source, SQLDataSource):
            raise NotImplementedError(
                "USA Zip Code validation is only supported for SQL data sources"
            )

        matched, total = self.data_source.query_string_pattern_validity(
            table=self.dataset_name,
            field=self.field_name,
            predefined_regex_pattern="usa_zip_code",
            filters=self.where_filter,
        )
        # Avoid dividing by zero on an empty (or fully filtered) table.
        if total > 0:
            return round(matched / total * 100, 2)
        return 0


class CountUSAStateCodeValidation(Validation):
    """Metric: number of rows whose field holds a valid USA state code."""

    def _generate_metric_value(self, **kwargs) -> Union[float, int]:
        """
        Return the count of rows holding a valid USA state code.

        :raises NotImplementedError: when the data source is not a SQLDataSource
        """
        # Guard first: only SQL data sources expose the state-code query.
        if not isinstance(self.data_source, SQLDataSource):
            raise NotImplementedError(
                "USA State Code validation is only supported for SQL data sources"
            )

        # The query also reports the total row count, which this metric ignores.
        matched, _total = self.data_source.query_get_usa_state_code_validity(
            table=self.dataset_name,
            field=self.field_name,
            filters=self.where_filter,
        )
        return matched


class PercentUSAStateCodeValidation(Validation):
    """Metric: percentage of rows whose field holds a valid USA state code."""

    def _generate_metric_value(self, **kwargs) -> Union[float, int]:
        """
        Return the share of valid USA state codes as a percentage rounded to 2 decimals.

        Returns 0 when the total row count is not positive.

        :raises NotImplementedError: when the data source is not a SQLDataSource
        """
        # Guard first: only SQL data sources expose the state-code query.
        if not isinstance(self.data_source, SQLDataSource):
            raise NotImplementedError(
                "USA State Code validation is only supported for SQL data sources"
            )

        matched, total = self.data_source.query_get_usa_state_code_validity(
            table=self.dataset_name,
            field=self.field_name,
            filters=self.where_filter,
        )
        # Avoid dividing by zero on an empty (or fully filtered) table.
        if total > 0:
            return round(matched / total * 100, 2)
        return 0
64 changes: 64 additions & 0 deletions tests/core/configuration/test_configuration_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,67 @@ def test_should_parse_string_length_avg_validation():
.get_validation_function
== ValidationFunction.STRING_LENGTH_AVERAGE
)


# Checks that a validation declared as `count_usa_zip_code(<field>)` in the
# YAML configuration is parsed into ValidationFunction.COUNT_USA_ZIP_CODE.
def test_should_parse_count_usa_zip_code():
yaml_string = """
validations for source.table:
- test:
on: count_usa_zip_code(usa_zip_code)
threshold: "<10"
"""
# Parse the YAML text and look up the validation named "test" under "source.table".
configuration = load_configuration_from_yaml_str(yaml_string)
assert (
configuration.validations["source.table"]
.validations["test"]
.get_validation_function
== ValidationFunction.COUNT_USA_ZIP_CODE
)


# Checks that a validation declared as `percent_usa_zip_code(<field>)` in the
# YAML configuration is parsed into ValidationFunction.PERCENT_USA_ZIP_CODE.
def test_should_parse_percent_usa_zip_code():
yaml_string = """
validations for source.table:
- test:
on: percent_usa_zip_code(usa_zip_code)
threshold: "<10"
"""
# Parse the YAML text and look up the validation named "test" under "source.table".
configuration = load_configuration_from_yaml_str(yaml_string)
assert (
configuration.validations["source.table"]
.validations["test"]
.get_validation_function
== ValidationFunction.PERCENT_USA_ZIP_CODE
)


# Checks that a validation declared as `count_usa_state_code(<field>)` in the
# YAML configuration is parsed into ValidationFunction.COUNT_USA_STATE_CODE.
def test_should_parse_count_usa_state_code():
yaml_string = """
validations for source.table:
- test:
on: count_usa_state_code(usa_state_code)
threshold: "<10"
"""
# Parse the YAML text and look up the validation named "test" under "source.table".
configuration = load_configuration_from_yaml_str(yaml_string)
assert (
configuration.validations["source.table"]
.validations["test"]
.get_validation_function
== ValidationFunction.COUNT_USA_STATE_CODE
)


# Checks that a validation declared as `percent_usa_state_code(<field>)` in the
# YAML configuration is parsed into ValidationFunction.PERCENT_USA_STATE_CODE.
def test_should_parse_percent_usa_state_code():
yaml_string = """
validations for source.table:
- test:
on: percent_usa_state_code(usa_state_code)
threshold: "<10"
"""
# Parse the YAML text and look up the validation named "test" under "source.table".
configuration = load_configuration_from_yaml_str(yaml_string)
assert (
configuration.validations["source.table"]
.validations["test"]
.get_validation_function
== ValidationFunction.PERCENT_USA_STATE_CODE
)
42 changes: 35 additions & 7 deletions tests/integration/datasource/test_sql_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def setup_tables(
name VARCHAR(50), last_fight timestamp, age INTEGER,
weight FLOAT, description VARCHAR(100), weapon_id VARCHAR(50),
usa_phone VARCHAR(50),
email VARCHAR(50)
email VARCHAR(50),
usa_state_code VARCHAR(5), usa_zip_code VARCHAR(50)
)
"""
)
Expand All @@ -133,22 +134,22 @@ def setup_tables(
INSERT INTO {self.TABLE_NAME} VALUES
('thor', '{(utc_now - datetime.timedelta(days=10)).strftime("%Y-%m-%d")}',
1500, NULL, 'thor hammer', 'e7194aaa-5516-4362-a5ff-6ff971976bec',
'123-456-7890', 'jane.doe@domain'), -- invalid email
'123-456-7890', 'jane.doe@domain', 'C2', 'ABCDE'), -- invalid email -- invalid usa_state_code -- invalid usa_zip_code
('captain america', '{(utc_now - datetime.timedelta(days=3)).strftime("%Y-%m-%d")}',
90, 80, 'shield', 'e7194aaa-5516-4362-a5ff-6ff971976b', '(123) 456-7890',
'[email protected] '), -- invalid weapon_id --invalid email
'[email protected] ', 'NY', '12-345'), -- invalid weapon_id --invalid email -- invalid usa_zip_code
('iron man', '{(utc_now - datetime.timedelta(days=4)).strftime("%Y-%m-%d")}',
50, 70, 'suit', '1739c676-6108-4dd2-8984-2459df744936', '123 456 7890',
'[email protected]'), -- invalid email
'[email protected]', 'XY', '85001'), -- invalid email -- invalid usa_state_code
('hawk eye', '{(utc_now - datetime.timedelta(days=5)).strftime("%Y-%m-%d")}',
40, 60, 'bow', '1739c676-6108-4dd2-8984-2459df746', '+1 123-456-7890',
'user@@example.com'), -- invalid weapon_id --invalid email
'user@@example.com', 'TX', '30301'), -- invalid weapon_id --invalid email
('clark kent', '{(utc_now - datetime.timedelta(days=6)).strftime("%Y-%m-%d")}',
35, 50, '', '7be61b2c-45dc-4889-97e3-9202e8', '09123.456.7890',
'[email protected]'), -- invalid weapon_id -- invalid phone
'[email protected]', 'ZZ', '123456'), -- invalid weapon_id -- invalid phone -- invalid usa_state_code -- invalid usa_zip_code
('black widow', '{(utc_now - datetime.timedelta(days=6)).strftime("%Y-%m-%d")}',
35, 50, '', '7be61b2c-45dc-4889-97e3-9202e8032c73', '+1 (123) 456-7890',
'[email protected]')
'[email protected]', 'FL', '90210')
"""

postgresql_connection.execute(text(insert_query))
Expand Down Expand Up @@ -412,3 +413,30 @@ def test_should_return_string_length_avg(
metric="avg",
)
assert result == 7.5

def test_should_return_row_count_for_valid_usa_zip_code(
    self, postgres_datasource: PostgresDataSource
):
    """The predefined ``usa_zip_code`` pattern should match 3 of the 6 rows in the test table."""
    counts = postgres_datasource.query_string_pattern_validity(
        table=self.TABLE_NAME,
        field="usa_zip_code",
        predefined_regex_pattern="usa_zip_code",
    )
    valid_count, total_row_count = counts
    assert valid_count == 3
    assert total_row_count == 6

def test_should_return_row_count_for_valid_usa_state_code(
    self, postgres_datasource: PostgresDataSource
):
    """The state-code validity query should report 3 valid rows out of the 6 in the test table."""
    counts = postgres_datasource.query_get_usa_state_code_validity(
        table=self.TABLE_NAME,
        field="usa_state_code",
    )
    valid_count, total_row_count = counts
    assert valid_count == 3
    assert total_row_count == 6
Loading