Skip to content

Commit

Permalink
Add validation for unsupported type/identifier/commands
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 committed Dec 5, 2024
1 parent b6846ce commit b462d2c
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ErrorCapturingIdentifierContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ErrorCapturingIdentifierExtraContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext;
Expand All @@ -43,6 +45,7 @@
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LiteralTypeContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext;
Expand Down Expand Up @@ -77,7 +80,9 @@
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TypeContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsupportedHiveNativeCommandsContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;

/** This visitor validate grammar using GrammarElementValidator */
Expand Down Expand Up @@ -584,4 +589,34 @@ private void validateAllowed(SQLGrammarElement element) {
throw new IllegalArgumentException(element + " is not allowed.");
}
}

@Override
public Void visitErrorCapturingIdentifier(ErrorCapturingIdentifierContext ctx) {
ErrorCapturingIdentifierExtraContext extra = ctx.errorCapturingIdentifierExtra();
if (extra.children != null) {
throw new IllegalArgumentException("Invalid identifier: " + ctx.getText());
}
return super.visitErrorCapturingIdentifier(ctx);
}

@Override
public Void visitLiteralType(LiteralTypeContext ctx) {
if (ctx.unsupportedType != null) {
throw new IllegalArgumentException("Unsupported typed literal: " + ctx.getText());
}
return super.visitLiteralType(ctx);
}

@Override
public Void visitType(TypeContext ctx) {
if (ctx.unsupportedType != null) {
throw new IllegalArgumentException("Unsupported data type: " + ctx.getText());
}
return super.visitType(ctx);
}

@Override
public Void visitUnsupportedHiveNativeCommands(UnsupportedHiveNativeCommandsContext ctx) {
throw new IllegalArgumentException("Unsupported command.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,56 @@ void testValidateFlintExtensionQuery() {
UUID.randomUUID().toString(), DataSourceType.SECURITY_LAKE));
}

@Test
void testInvalidIdentifier() {
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);
v.ng("SELECT a.b.c as a-b-c FROM abc");
v.ok("SELECT a.b.c as `a-b-c` FROM abc");
v.ok("SELECT a.b.c as a_b_c FROM abc");

v.ng("SELECT a.b.c FROM a-b-c");
v.ng("SELECT a.b.c FROM a.b-c");
v.ok("SELECT a.b.c FROM b.c.`a-b-c`");
v.ok("SELECT a.b.c FROM `a-b-c`");
}

@Test
void testUnsupportedType() {
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);

v.ng("SELECT cast ( a as DateTime ) FROM tbl");
v.ok("SELECT cast ( a as DATE ) FROM tbl");
v.ok("SELECT cast ( a as Date ) FROM tbl");
v.ok("SELECT cast ( a as Timestamp ) FROM tbl");
}

@Test
void testUnsupportedTypedLiteral() {
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);

v.ng("SELECT DATETIME '2024-10-11'");
v.ok("SELECT DATE '2024-10-11'");
v.ok("SELECT TIMESTAMP '2024-10-11'");
}

@Test
void testUnsupportedHiveNativeCommand() {
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);

v.ng("CREATE ROLE aaa");
v.ng("SHOW GRANT");
v.ng("EXPORT TABLE");
v.ng("ALTER TABLE aaa NOT CLUSTERED");
v.ng("START TRANSACTION");
v.ng("COMMIT");
v.ng("ROLLBACK");
v.ng("DFS");
}

@AllArgsConstructor
private static class VerifyValidator {
private final SQLQueryValidator validator;
Expand All @@ -580,10 +630,18 @@ public void ok(TestElement query) {
runValidate(query.getQueries());
}

public void ok(String query) {
runValidate(query);
}

public void ng(TestElement query) {
Arrays.stream(query.getQueries()).forEach(this::ng);
}

public void ng(String query) {
assertThrows(
IllegalArgumentException.class,
() -> runValidate(query.getQueries()),
() -> runValidate(query),
"The query should throw: query=`" + query.toString() + "`");
}

Expand Down

0 comments on commit b462d2c

Please sign in to comment.