Skip to content

Commit

Permalink
Refactor Access Denied Exception testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ssheikin authored and kokosing committed Sep 30, 2020
1 parent 738406d commit a6543b1
Showing 1 changed file with 14 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.prestosql.metadata.Metadata;
import io.prestosql.operator.OperatorStats;
import io.prestosql.spi.QueryId;
import io.prestosql.spi.security.AccessDeniedException;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.analyzer.FeaturesConfig;
import io.prestosql.sql.analyzer.FeaturesConfig.JoinDistributionType;
Expand Down Expand Up @@ -66,18 +65,16 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY;
import static io.prestosql.sql.ParsingUtil.createParsingOptions;
import static io.prestosql.sql.SqlFormatter.formatSql;
import static io.prestosql.transaction.TransactionBuilder.transaction;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;

public abstract class AbstractTestQueryFramework
{
Expand Down Expand Up @@ -275,6 +272,11 @@ protected void assertAccessAllowed(@Language("SQL") String sql, TestingPrivilege
}

protected void assertAccessAllowed(Session session, @Language("SQL") String sql, TestingPrivilege... deniedPrivileges)
{
executeExclusively(session, sql, deniedPrivileges);
}

private void executeExclusively(Session session, @Language("SQL") String sql, TestingPrivilege[] deniedPrivileges)
{
executeExclusively(() -> {
try {
Expand All @@ -298,19 +300,14 @@ protected void assertAccessDenied(
@Language("RegExp") String exceptionsMessageRegExp,
TestingPrivilege... deniedPrivileges)
{
executeExclusively(() -> {
try {
queryRunner.getAccessControl().deny(deniedPrivileges);
queryRunner.execute(session, sql);
fail("Expected " + AccessDeniedException.class.getSimpleName());
}
catch (RuntimeException e) {
assertExceptionMessage(sql, e, ".*Access Denied: " + exceptionsMessageRegExp);
}
finally {
queryRunner.getAccessControl().reset();
}
});
assertException(session, sql, ".*Access Denied: " + exceptionsMessageRegExp, deniedPrivileges);
}

private void assertException(Session session, @Language("SQL") String sql, @Language("RegExp") String exceptionsMessageRegExp, TestingPrivilege[] deniedPrivileges)
{
assertThatThrownBy(() -> executeExclusively(session, sql, deniedPrivileges))
.as("Query: " + sql)
.hasMessageMatching(exceptionsMessageRegExp);
}

protected void assertTableColumnNames(String tableName, String... columnNames)
Expand All @@ -323,13 +320,6 @@ protected void assertTableColumnNames(String tableName, String... columnNames)
assertEquals(actual, expected);
}

private static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex)
{
if (!nullToEmpty(exception.getMessage()).matches(regex)) {
fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception);
}
}

protected MaterializedResult computeExpected(@Language("SQL") String sql, List<? extends Type> resultTypes)
{
return h2QueryRunner.execute(getSession(), sql, resultTypes);
Expand Down

0 comments on commit a6543b1

Please sign in to comment.