From 3399d63ac6d6f89b5efb929ba5d5903614f758ed Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Sat, 14 Sep 2019 18:42:25 -0700 Subject: [PATCH] Add option to set default catalog and schema --- .../server/QuerySessionSupplier.java | 20 ++++- .../prestosql/sql/SqlEnvironmentConfig.java | 28 ++++++ .../server/TestQuerySessionSupplier.java | 90 ++++++++++++++++--- .../sql/TestSqlEnvironmentConfig.java | 6 ++ .../prestosql/tests/tpch/TpchQueryRunner.java | 7 +- 5 files changed, 138 insertions(+), 13 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/server/QuerySessionSupplier.java b/presto-main/src/main/java/io/prestosql/server/QuerySessionSupplier.java index 5c5ec4aa919ae..37ed0e7a2cd8b 100644 --- a/presto-main/src/main/java/io/prestosql/server/QuerySessionSupplier.java +++ b/presto-main/src/main/java/io/prestosql/server/QuerySessionSupplier.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.Session.SessionBuilder; import static io.prestosql.spi.type.TimeZoneKey.getTimeZoneKey; import static java.util.Map.Entry; @@ -44,6 +45,8 @@ public class QuerySessionSupplier private final SessionPropertyManager sessionPropertyManager; private final Optional path; private final Optional forcedSessionTimeZone; + private final Optional defaultCatalog; + private final Optional defaultSchema; @Inject public QuerySessionSupplier( @@ -58,6 +61,10 @@ public QuerySessionSupplier( requireNonNull(config, "config is null"); this.path = requireNonNull(config.getPath(), "path is null"); this.forcedSessionTimeZone = requireNonNull(config.getForcedSessionTimeZone(), "forcedSessionTimeZone is null"); + this.defaultCatalog = requireNonNull(config.getDefaultCatalog(), "defaultCatalog is null"); + this.defaultSchema = requireNonNull(config.getDefaultSchema(), "defaultSchema is null"); + + checkArgument(defaultCatalog.isPresent() || !defaultSchema.isPresent(), "Default schema cannot be set if catalog is not set"); } @Override @@ -70,8 +77,6 @@ public Session createSession(QueryId queryId, SessionContext context) .setQueryId(queryId) .setIdentity(identity) .setSource(context.getSource()) - .setCatalog(context.getCatalog()) - .setSchema(context.getSchema()) .setPath(new SqlPath(path)) .setRemoteUserAddress(context.getRemoteUserAddress()) .setUserAgent(context.getUserAgent()) @@ -81,6 +86,17 @@ public Session createSession(QueryId queryId, SessionContext context) .setTraceToken(context.getTraceToken()) .setResourceEstimates(context.getResourceEstimates()); + defaultCatalog.ifPresent(sessionBuilder::setCatalog); + defaultSchema.ifPresent(sessionBuilder::setSchema); + + if (context.getCatalog() != null) { + sessionBuilder.setCatalog(context.getCatalog()); + } + + if (context.getSchema() != null) { + sessionBuilder.setSchema(context.getSchema()); + } + if (context.getPath() != null) { sessionBuilder.setPath(new SqlPath(Optional.of(context.getPath()))); } diff --git a/presto-main/src/main/java/io/prestosql/sql/SqlEnvironmentConfig.java b/presto-main/src/main/java/io/prestosql/sql/SqlEnvironmentConfig.java index 1789f547fe020..8f02459d6b0bd 100644 --- a/presto-main/src/main/java/io/prestosql/sql/SqlEnvironmentConfig.java +++ b/presto-main/src/main/java/io/prestosql/sql/SqlEnvironmentConfig.java @@ -25,6 +25,8 @@ public class SqlEnvironmentConfig { private Optional path = Optional.empty(); + private Optional defaultCatalog = Optional.empty(); + private Optional defaultSchema = Optional.empty(); private Optional forcedSessionTimeZone = Optional.empty(); @NotNull @@ -40,6 +42,32 @@ public SqlEnvironmentConfig setPath(String path) return this; } + @NotNull + public Optional getDefaultCatalog() + { + return defaultCatalog; + } + + @Config("sql.default-catalog") + public SqlEnvironmentConfig setDefaultCatalog(String catalog) + { + this.defaultCatalog = Optional.ofNullable(catalog); + return this; + } + + @NotNull + public Optional getDefaultSchema() + { + return defaultSchema; + } + + @Config("sql.default-schema") + public SqlEnvironmentConfig setDefaultSchema(String schema) + { + this.defaultSchema = Optional.ofNullable(schema); + return this; + } + @NotNull public Optional getForcedSessionTimeZone() { diff --git a/presto-main/src/test/java/io/prestosql/server/TestQuerySessionSupplier.java b/presto-main/src/test/java/io/prestosql/server/TestQuerySessionSupplier.java index 87ebbde4ff496..741a9bd4e1e05 100644 --- a/presto-main/src/test/java/io/prestosql/server/TestQuerySessionSupplier.java +++ b/presto-main/src/test/java/io/prestosql/server/TestQuerySessionSupplier.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; import io.prestosql.Session; import io.prestosql.metadata.SessionPropertyManager; import io.prestosql.security.AllowAllAccessControl; @@ -53,7 +54,9 @@ import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.WARN; import static io.prestosql.spi.type.TimeZoneKey.getTimeZoneKey; import static io.prestosql.transaction.InMemoryTransactionManager.createTestTransactionManager; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; public class TestQuerySessionSupplier { @@ -78,11 +81,7 @@ public class TestQuerySessionSupplier public void testCreateSession() { HttpRequestSessionContext context = new HttpRequestSessionContext(WARN, TEST_REQUEST); - QuerySessionSupplier sessionSupplier = new QuerySessionSupplier( - createTestTransactionManager(), - new AllowAllAccessControl(), - new SessionPropertyManager(), - new SqlEnvironmentConfig()); + QuerySessionSupplier sessionSupplier = createSessionSupplier(new SqlEnvironmentConfig()); Session session = sessionSupplier.createSession(new QueryId("test_query_id"), context); assertEquals(session.getQueryId(), new QueryId("test_query_id")); @@ -159,11 +158,7 @@ public void testInvalidTimeZone() .build(), "testRemote"); HttpRequestSessionContext context = new HttpRequestSessionContext(WARN, request); - QuerySessionSupplier sessionSupplier = new QuerySessionSupplier( - createTestTransactionManager(), - new AllowAllAccessControl(), - new SessionPropertyManager(), - new SqlEnvironmentConfig()); + QuerySessionSupplier sessionSupplier = createSessionSupplier(new SqlEnvironmentConfig()); sessionSupplier.createSession(new QueryId("test_query_id"), context); } @@ -193,4 +188,79 @@ public void testSqlPathCreation() assertEquals(path.getParsedPath(), expected); assertEquals(path.toString(), Joiner.on(", ").join(expected)); } + + @Test + public void testDefaultCatalogAndSchema() + { + // none specified + Session session = createSession( + ImmutableListMultimap.builder() + .put(PRESTO_USER, "testUser") + .build(), + new SqlEnvironmentConfig()); + assertFalse(session.getCatalog().isPresent()); + assertFalse(session.getSchema().isPresent()); + + // catalog + session = createSession( + ImmutableListMultimap.builder() + .put(PRESTO_USER, "testUser") + .build(), + new SqlEnvironmentConfig() + .setDefaultCatalog("catalog")); + assertEquals(session.getCatalog(), Optional.of("catalog")); + assertFalse(session.getSchema().isPresent()); + + // catalog and schema + session = createSession( + ImmutableListMultimap.builder() + .put(PRESTO_USER, "testUser") + .build(), + new SqlEnvironmentConfig() + .setDefaultCatalog("catalog") + .setDefaultSchema("schema")); + assertEquals(session.getCatalog(), Optional.of("catalog")); + assertEquals(session.getSchema(), Optional.of("schema")); + + // only schema + assertThatThrownBy(() -> createSessionSupplier(new SqlEnvironmentConfig().setDefaultSchema("schema"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Default schema cannot be set if catalog is not set"); + } + + @Test + public void testCatalogAndSchemaOverrides() + { + // none specified + Session session = createSession( + ImmutableListMultimap.builder() + .put(PRESTO_USER, "testUser") + .put(PRESTO_CATALOG, "catalog") + .put(PRESTO_SCHEMA, "schema") + .build(), + new SqlEnvironmentConfig() + .setDefaultCatalog("default-catalog") + .setDefaultSchema("default-schema")); + assertEquals(session.getCatalog(), Optional.of("catalog")); + assertEquals(session.getSchema(), Optional.of("schema")); + } + + private Session createSession(ListMultimap headers, SqlEnvironmentConfig config) + { + HttpRequestSessionContext context = new HttpRequestSessionContext( + WARN, + new MockHttpServletRequest(headers, "testRemote")); + + QuerySessionSupplier sessionSupplier = createSessionSupplier(config); + return sessionSupplier.createSession(new QueryId("test_query_id"), context); + } + + private static QuerySessionSupplier createSessionSupplier(SqlEnvironmentConfig config) + { + return new QuerySessionSupplier( + createTestTransactionManager(), + new AllowAllAccessControl(), + new SessionPropertyManager(), + config); + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/TestSqlEnvironmentConfig.java b/presto-main/src/test/java/io/prestosql/sql/TestSqlEnvironmentConfig.java index bcfe1969e6a1c..003834a22c85d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestSqlEnvironmentConfig.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestSqlEnvironmentConfig.java @@ -30,6 +30,8 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(SqlEnvironmentConfig.class) .setPath(null) + .setDefaultCatalog(null) + .setDefaultSchema(null) .setForcedSessionTimeZone(null)); } @@ -38,11 +40,15 @@ public void testExplicitPropertyMappings() { Map properties = new ImmutableMap.Builder() .put("sql.path", "a.b, c.d") + .put("sql.default-catalog", "some-catalog") + .put("sql.default-schema", "some-schema") .put("sql.forced-session-time-zone", "UTC") .build(); SqlEnvironmentConfig expected = new SqlEnvironmentConfig() .setPath("a.b, c.d") + .setDefaultCatalog("some-catalog") + .setDefaultSchema("some-schema") .setForcedSessionTimeZone("UTC"); assertFullMapping(properties, expected); diff --git a/presto-tests/src/test/java/io/prestosql/tests/tpch/TpchQueryRunner.java b/presto-tests/src/test/java/io/prestosql/tests/tpch/TpchQueryRunner.java index c885483f25aa1..4b99dcef77048 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/tpch/TpchQueryRunner.java +++ b/presto-tests/src/test/java/io/prestosql/tests/tpch/TpchQueryRunner.java @@ -13,6 +13,7 @@ */ package io.prestosql.tests.tpch; +import com.google.common.collect.ImmutableMap; import io.airlift.log.Logger; import io.airlift.log.Logging; import io.prestosql.tests.DistributedQueryRunner; @@ -26,7 +27,11 @@ public static void main(String[] args) { Logging.initialize(); DistributedQueryRunner queryRunner = TpchQueryRunnerBuilder.builder() - .setSingleExtraProperty("http-server.http.port", "8080") + .setExtraProperties(ImmutableMap.builder() + .put("http-server.http.port", "8080") + .put("sql.default-catalog", "tpch") + .put("sql.default-schema", "tiny") + .build()) .build(); Thread.sleep(10); Logger log = Logger.get(TpchQueryRunner.class);