From a6543b1b9c2ec190004d2a4bcc91579b317c778b Mon Sep 17 00:00:00 2001 From: zaz968m Date: Fri, 18 Sep 2020 15:39:38 +0200 Subject: [PATCH] Refactor Access Denied Exception testing --- .../testing/AbstractTestQueryFramework.java | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/presto-testing/src/main/java/io/prestosql/testing/AbstractTestQueryFramework.java b/presto-testing/src/main/java/io/prestosql/testing/AbstractTestQueryFramework.java index 7d0ba1068780a..427b32a26597e 100644 --- a/presto-testing/src/main/java/io/prestosql/testing/AbstractTestQueryFramework.java +++ b/presto-testing/src/main/java/io/prestosql/testing/AbstractTestQueryFramework.java @@ -30,7 +30,6 @@ import io.prestosql.metadata.Metadata; import io.prestosql.operator.OperatorStats; import io.prestosql.spi.QueryId; -import io.prestosql.spi.security.AccessDeniedException; import io.prestosql.spi.type.Type; import io.prestosql.sql.analyzer.FeaturesConfig; import io.prestosql.sql.analyzer.FeaturesConfig.JoinDistributionType; @@ -66,18 +65,16 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Strings.nullToEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.prestosql.sql.ParsingUtil.createParsingOptions; import static io.prestosql.sql.SqlFormatter.formatSql; import static io.prestosql.transaction.TransactionBuilder.transaction; -import static java.lang.String.format; import static java.util.Collections.emptyList; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.fail; public abstract class AbstractTestQueryFramework { @@ -275,6 +272,11 @@ protected void assertAccessAllowed(@Language("SQL") String sql, TestingPrivilege } protected void assertAccessAllowed(Session session, @Language("SQL") String sql, TestingPrivilege... deniedPrivileges) + { + executeExclusively(session, sql, deniedPrivileges); + } + + private void executeExclusively(Session session, @Language("SQL") String sql, TestingPrivilege[] deniedPrivileges) { executeExclusively(() -> { try { @@ -298,19 +300,14 @@ protected void assertAccessDenied( @Language("RegExp") String exceptionsMessageRegExp, TestingPrivilege... deniedPrivileges) { - executeExclusively(() -> { - try { - queryRunner.getAccessControl().deny(deniedPrivileges); - queryRunner.execute(session, sql); - fail("Expected " + AccessDeniedException.class.getSimpleName()); - } - catch (RuntimeException e) { - assertExceptionMessage(sql, e, ".*Access Denied: " + exceptionsMessageRegExp); - } - finally { - queryRunner.getAccessControl().reset(); - } - }); + assertException(session, sql, ".*Access Denied: " + exceptionsMessageRegExp, deniedPrivileges); + } + + private void assertException(Session session, @Language("SQL") String sql, @Language("RegExp") String exceptionsMessageRegExp, TestingPrivilege[] deniedPrivileges) + { + assertThatThrownBy(() -> executeExclusively(session, sql, deniedPrivileges)) + .as("Query: " + sql) + .hasMessageMatching(exceptionsMessageRegExp); } protected void assertTableColumnNames(String tableName, String... columnNames) @@ -323,13 +320,6 @@ protected void assertTableColumnNames(String tableName, String... columnNames) assertEquals(actual, expected); } - private static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex) - { - if (!nullToEmpty(exception.getMessage()).matches(regex)) { - fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception); - } - } - protected MaterializedResult computeExpected(@Language("SQL") String sql, List resultTypes) { return h2QueryRunner.execute(getSession(), sql, resultTypes);