Skip to content

Commit

Permalink
Add table & column name suggestions to presto validator (pinterest#1330)
Browse files Browse the repository at this point in the history
* Add table & column name suggestions to presto validator
  • Loading branch information
kgopal492 authored and aidenprice committed Jan 3, 2024
1 parent 150906a commit 2edc691
Show file tree
Hide file tree
Showing 7 changed files with 579 additions and 93 deletions.
67 changes: 67 additions & 0 deletions querybook/server/lib/elasticsearch/search_table.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import Dict, List, Tuple
from lib.elasticsearch.query_utils import (
match_filters,
highlight_fields,
order_by_fields,
combine_keyword_and_filter_query,
)
from lib.elasticsearch.search_utils import (
ES_CONFIG,
get_matching_objects,
)

FILTERS_TO_AND = ["tags", "data_elements"]

Expand Down Expand Up @@ -173,3 +178,65 @@ def construct_tables_query_by_table_names(
}

return query


def get_column_name_suggestion(
    fuzzy_column_name: str, full_table_names: List[str]
) -> Tuple[Dict, int]:
    """Fuzzy-search the tables index for a column resembling a misspelled name.

    Args:
        fuzzy_column_name: The invalid column name taken from the query.
        full_table_names: Tables to restrict the search to, normally written
            as ``schema.table``. Only the last two dot-separated components
            are used; a name without a schema part is matched on the table
            name alone.

    Returns:
        The value returned by ``get_matching_objects`` for the tables index.
        Callers in this code base unpack it as ``(results, count)``.
    """
    table_clauses = []
    for full_table_name in full_table_names:
        # Unpacking ``split(".")`` into exactly two names raises ValueError for
        # an unqualified or catalog-prefixed table, which would abort the whole
        # lookup. Read the components from the right instead.
        name_parts = full_table_name.split(".")
        table_filters = [{"match": {"name": name_parts[-1]}}]
        if len(name_parts) > 1:
            table_filters.append({"match": {"schema": name_parts[-2]}})
        table_clauses.append({"bool": {"must": table_filters}})

    search_query = {
        "query": {
            "bool": {
                "must": {
                    "match": {
                        "columns": {"query": fuzzy_column_name, "fuzziness": "AUTO"}
                    }
                },
                # The document must belong to at least one of the given tables.
                "should": table_clauses,
                "minimum_should_match": 1,
            },
        },
        # Empty pre/post tags so the highlight fragments carry no markup.
        "highlight": {"pre_tags": [""], "post_tags": [""], "fields": {"columns": {}}},
    }

    # NOTE(review): the trailing True presumably asks for the hit count as well
    # as the hits -- confirm against get_matching_objects.
    return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)


def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]:
    """Fuzzy-search the tables index for a table resembling a misspelled name.

    When ``fuzzy_table_name`` has exactly two dot-separated parts they are
    treated as ``schema.table``: the table part is matched fuzzily and a
    non-empty schema part is added as a second match clause. Any other shape
    is matched fuzzily as a whole against the table name.

    Returns:
        The value returned by ``get_matching_objects`` for the tables index.
        Callers in this code base unpack it as ``(results, count)``.
    """
    name_parts = fuzzy_table_name.split(".")
    if len(name_parts) == 2:
        schema_name, fuzzy_name = name_parts
    else:
        schema_name, fuzzy_name = None, fuzzy_table_name

    name_match = {
        "match": {
            "name": {"query": fuzzy_name, "fuzziness": "AUTO"},
        }
    }
    # An empty schema (e.g. ".table") adds no clause, same as a missing one.
    schema_matches = [{"match": {"schema": schema_name}}] if schema_name else []

    search_query = {
        "query": {"bool": {"must": [name_match, *schema_matches]}},
    }

    # NOTE(review): the trailing True presumably asks for the hit count as well
    # as the hits -- confirm against get_matching_objects.
    return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def validate(
query: str,
uid: int, # who is doing the syntax check
engine_id: int, # which engine they are checking against
**kwargs,
) -> List[QueryValidationResult]:
raise NotImplementedError()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
from abc import abstractmethod
from typing import Any, Dict, List, Tuple
from sqlglot import Tokenizer
from sqlglot.tokens import Token

Expand All @@ -8,9 +8,13 @@
QueryValidationResultObjectType,
QueryValidationSeverity,
)
from lib.query_analysis.validation.base_query_validator import BaseQueryValidator


class BaseSQLGlotValidator(metaclass=ABCMeta):
class BaseSQLGlotValidator(BaseQueryValidator):
def __init__(self, name: str = "", config: Dict[str, Any] = None):
    """Initialise the validator by delegating to the parent initialiser.

    Args:
        name: Validator name, forwarded unchanged to the parent class.
        config: Validator configuration. When omitted (or None) a new empty
            dict is passed on.
    """
    # A ``{}`` default is evaluated once at definition time and would be
    # shared by every instance built without a config; create one per call.
    super().__init__(name, {} if config is None else config)

@property
@abstractmethod
def message(self) -> str:
Expand All @@ -33,6 +37,12 @@ def _get_query_coordinate_by_index(self, query: str, index: int) -> Tuple[int, i
rows = query[: index + 1].splitlines(keepends=False)
return len(rows) - 1, len(rows[-1]) - 1

def _get_query_index_by_coordinate(
self, query: str, start_line: int, start_ch: int
) -> int:
rows = query.splitlines(keepends=True)[:start_line]
return sum([len(row) for row in rows]) + start_ch

def _get_query_validation_result(
self,
query: str,
Expand All @@ -56,7 +66,28 @@ def _get_query_validation_result(
)

@abstractmethod
def get_query_validation_results(
self, query: str, raw_tokens: List[Token] = None
def validate(
    self,
    query: str,
    uid: int,
    engine_id: int,
    raw_tokens: List[Token] = None,
    **kwargs,
) -> List[QueryValidationResult]:
    """Validate ``query`` and return the issues found; subclasses must implement.

    Args:
        query: The query text to check.
        uid: Id of the user requesting the check.
        engine_id: Id of the query engine the query is checked against.
        raw_tokens: Optional tokens for ``query``; defaults to None.
        **kwargs: Extra keyword arguments accepted for signature compatibility.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()


class BaseSQLGlotDecorator(BaseSQLGlotValidator):
    """Wraps another validator so subclasses can post-process its results."""

    def __init__(self, validator: BaseQueryValidator):
        # NOTE(review): the parent initialiser is not invoked here, so any
        # state it would set up is absent on decorators -- confirm intended.
        self._validator = validator

    def validate(
        self,
        query: str,
        uid: int,
        engine_id: int,
        raw_tokens: List[Token] = None,
        **kwargs,
    ) -> List[QueryValidationResult]:
        """Override this method to add suggestions to validation results.

        The base implementation delegates to the wrapped validator and returns
        its results unchanged.
        """
        # Hand on tokens the caller already produced instead of dropping them;
        # when none were given the wrapped validator is called as before.
        if raw_tokens is not None:
            kwargs["raw_tokens"] = raw_tokens
        return self._validator.validate(query, uid, engine_id, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from abc import abstractmethod
from itertools import chain
from typing import List, Optional

from lib.elasticsearch import search_table
from lib.query_analysis.lineage import process_query
from lib.query_analysis.validation.base_query_validator import (
QueryValidationResult,
QueryValidationSeverity,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
BaseSQLGlotDecorator,
)
from logic.admin import get_query_engine_by_id


class BaseColumnNameSuggester(BaseSQLGlotDecorator):
    """Decorator that adds a column name suggestion to column-name errors
    reported by the wrapped validator."""

    @property
    def severity(self):
        return QueryValidationSeverity.WARNING  # Unused, severity is not changed

    @property
    def message(self):
        return ""  # Unused, message is not changed

    @abstractmethod
    def get_column_name_from_error(
        self, validation_result: QueryValidationResult
    ) -> Optional[str]:
        """Returns invalid column name if the validation result is a column name error, otherwise
        returns None"""
        raise NotImplementedError()

    def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]:
        """Return the table names referenced by ``query``, flattened across statements."""
        engine = get_query_engine_by_id(engine_id)
        tables_per_statement, _ = process_query(query, language=engine.language)
        return list(chain.from_iterable(tables_per_statement))

    def _search_columns_for_suggestion(self, columns: List[str], suggestion: str):
        """Return the case-sensitive column name by searching the table's columns for the
        suggestion text; fall back to the suggestion itself when nothing matches."""
        lowered_suggestion = suggestion.lower()
        # ``columns`` comes from a dict ``.get`` and may be None.
        for col in columns or []:
            if col.lower() == lowered_suggestion:
                return col
        return suggestion

    def _suggest_column_name_if_needed(
        self,
        validation_result: QueryValidationResult,
        tables_in_query: List[str],
    ) -> None:
        """Update ``validation_result`` in place with a column name suggestion.

        Nothing happens unless the result is a column name error and the
        search yields exactly one table with exactly one highlighted column.
        """
        fuzzy_column_name = self.get_column_name_from_error(validation_result)
        if not fuzzy_column_name:
            return
        results, count = search_table.get_column_name_suggestion(
            fuzzy_column_name, tables_in_query
        )
        if count != 1:  # Only suggest column if there's a single match
            return
        table_result = results[0]
        highlights = table_result.get("highlight", {}).get("columns", [])
        if len(highlights) != 1:
            return
        validation_result.suggestion = self._search_columns_for_suggestion(
            table_result.get("columns"), highlights[0]
        )
        # Mark the span of the misspelled name on its starting line.
        validation_result.end_line = validation_result.start_line
        validation_result.end_ch = (
            validation_result.start_ch + len(fuzzy_column_name) - 1
        )

    def validate(
        self,
        query: str,
        uid: int,
        engine_id: int,
        raw_tokens: Optional[List] = None,  # tokens for query, not validation results
        **kwargs,
    ) -> List[QueryValidationResult]:
        """Run the wrapped validator, then add column name suggestions to its results.

        Args:
            query: The query text to check.
            uid: Id of the user requesting the check.
            engine_id: Id of the query engine the query is checked against.
            raw_tokens: Tokens for ``query``; produced with
                ``self._tokenize_query`` when omitted.
            **kwargs: Forwarded unchanged to the wrapped validator.
        """
        if raw_tokens is None:
            raw_tokens = self._tokenize_query(query)
        validation_results = self._validator.validate(
            query, uid, engine_id, raw_tokens=raw_tokens, **kwargs
        )
        # Skip the engine lookup and query parsing when there is nothing to
        # annotate.
        if not validation_results:
            return validation_results
        tables_in_query = self._get_tables_in_query(query, engine_id)
        for result in validation_results:
            self._suggest_column_name_if_needed(result, tables_in_query)
        return validation_results


class BaseTableNameSuggester(BaseSQLGlotDecorator):
    """Decorator that adds a table name suggestion to table-name errors
    reported by the wrapped validator."""

    @property
    def severity(self):
        return QueryValidationSeverity.WARNING  # Unused, severity is not changed

    @property
    def message(self):
        return ""  # Unused, message is not changed

    @abstractmethod
    def get_full_table_name_from_error(
        self, validation_result: QueryValidationResult
    ) -> Optional[str]:
        """Returns invalid table name if the validation result is a table name error, otherwise
        returns None"""
        raise NotImplementedError()

    def _suggest_table_name_if_needed(
        self, validation_result: QueryValidationResult
    ) -> None:
        """Update ``validation_result`` in place with a table name suggestion.

        Nothing happens unless the result is a table name error and the search
        returns at least one hit; the first hit is used as the suggestion.
        """
        fuzzy_table_name = self.get_full_table_name_from_error(validation_result)
        if not fuzzy_table_name:
            return
        results, count = search_table.get_table_name_suggestion(fuzzy_table_name)
        if count <= 0:
            return
        table_result = results[0]  # Get top match
        validation_result.suggestion = (
            f"{table_result['schema']}.{table_result['name']}"
        )
        # Mark the span of the misspelled name on its starting line.
        validation_result.end_line = validation_result.start_line
        validation_result.end_ch = (
            validation_result.start_ch + len(fuzzy_table_name) - 1
        )

    def validate(
        self,
        query: str,
        uid: int,
        engine_id: int,
        raw_tokens: Optional[List] = None,  # tokens for query, not validation results
        **kwargs,
    ) -> List[QueryValidationResult]:
        """Run the wrapped validator, then add table name suggestions to its results.

        Args:
            query: The query text to check.
            uid: Id of the user requesting the check.
            engine_id: Id of the query engine the query is checked against.
            raw_tokens: Tokens for ``query``; produced with
                ``self._tokenize_query`` when omitted.
            **kwargs: Forwarded unchanged to the wrapped validator.
        """
        if raw_tokens is None:
            raw_tokens = self._tokenize_query(query)
        validation_results = self._validator.validate(
            query, uid, engine_id, raw_tokens=raw_tokens, **kwargs
        )
        for result in validation_results:
            self._suggest_table_name_if_needed(result)
        return validation_results
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def validate(
query: str,
uid: int, # who is doing the syntax check
engine_id: int, # which engine they are checking against
**kwargs,
) -> List[QueryValidationResult]:
validation_errors = []
(
Expand Down
Loading

0 comments on commit 2edc691

Please sign in to comment.