Skip to content

Commit

Permalink
Add option to set default catalog and schema
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Sep 15, 2019
1 parent b006af8 commit 3399d63
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static io.prestosql.Session.SessionBuilder;
import static io.prestosql.spi.type.TimeZoneKey.getTimeZoneKey;
import static java.util.Map.Entry;
Expand All @@ -44,6 +45,8 @@ public class QuerySessionSupplier
private final SessionPropertyManager sessionPropertyManager;
private final Optional<String> path;
private final Optional<TimeZoneKey> forcedSessionTimeZone;
private final Optional<String> defaultCatalog;
private final Optional<String> defaultSchema;

@Inject
public QuerySessionSupplier(
Expand All @@ -58,6 +61,10 @@ public QuerySessionSupplier(
requireNonNull(config, "config is null");
this.path = requireNonNull(config.getPath(), "path is null");
this.forcedSessionTimeZone = requireNonNull(config.getForcedSessionTimeZone(), "forcedSessionTimeZone is null");
this.defaultCatalog = requireNonNull(config.getDefaultCatalog(), "defaultCatalog is null");
this.defaultSchema = requireNonNull(config.getDefaultSchema(), "defaultSchema is null");

checkArgument(defaultCatalog.isPresent() || !defaultSchema.isPresent(), "Default schema cannot be set if catalog is not set");
}

@Override
Expand All @@ -70,8 +77,6 @@ public Session createSession(QueryId queryId, SessionContext context)
.setQueryId(queryId)
.setIdentity(identity)
.setSource(context.getSource())
.setCatalog(context.getCatalog())
.setSchema(context.getSchema())
.setPath(new SqlPath(path))
.setRemoteUserAddress(context.getRemoteUserAddress())
.setUserAgent(context.getUserAgent())
Expand All @@ -81,6 +86,17 @@ public Session createSession(QueryId queryId, SessionContext context)
.setTraceToken(context.getTraceToken())
.setResourceEstimates(context.getResourceEstimates());

defaultCatalog.ifPresent(sessionBuilder::setCatalog);
defaultSchema.ifPresent(sessionBuilder::setSchema);

if (context.getCatalog() != null) {
sessionBuilder.setCatalog(context.getCatalog());
}

if (context.getSchema() != null) {
sessionBuilder.setSchema(context.getSchema());
}

if (context.getPath() != null) {
sessionBuilder.setPath(new SqlPath(Optional.of(context.getPath())));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
public class SqlEnvironmentConfig
{
private Optional<String> path = Optional.empty();
private Optional<String> defaultCatalog = Optional.empty();
private Optional<String> defaultSchema = Optional.empty();
private Optional<TimeZoneKey> forcedSessionTimeZone = Optional.empty();

@NotNull
Expand All @@ -40,6 +42,32 @@ public SqlEnvironmentConfig setPath(String path)
return this;
}

@NotNull
public Optional<String> getDefaultCatalog()
{
return defaultCatalog;
}

@Config("sql.default-catalog")
public SqlEnvironmentConfig setDefaultCatalog(String catalog)
{
this.defaultCatalog = Optional.ofNullable(catalog);
return this;
}

@NotNull
public Optional<String> getDefaultSchema()
{
return defaultSchema;
}

@Config("sql.default-schema")
public SqlEnvironmentConfig setDefaultSchema(String schema)
{
this.defaultSchema = Optional.ofNullable(schema);
return this;
}

@NotNull
public Optional<TimeZoneKey> getForcedSessionTimeZone()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import io.prestosql.Session;
import io.prestosql.metadata.SessionPropertyManager;
import io.prestosql.security.AllowAllAccessControl;
Expand Down Expand Up @@ -53,7 +54,9 @@
import static io.prestosql.dispatcher.DispatcherConfig.HeaderSupport.WARN;
import static io.prestosql.spi.type.TimeZoneKey.getTimeZoneKey;
import static io.prestosql.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;

public class TestQuerySessionSupplier
{
Expand All @@ -78,11 +81,7 @@ public class TestQuerySessionSupplier
public void testCreateSession()
{
HttpRequestSessionContext context = new HttpRequestSessionContext(WARN, TEST_REQUEST);
QuerySessionSupplier sessionSupplier = new QuerySessionSupplier(
createTestTransactionManager(),
new AllowAllAccessControl(),
new SessionPropertyManager(),
new SqlEnvironmentConfig());
QuerySessionSupplier sessionSupplier = createSessionSupplier(new SqlEnvironmentConfig());
Session session = sessionSupplier.createSession(new QueryId("test_query_id"), context);

assertEquals(session.getQueryId(), new QueryId("test_query_id"));
Expand Down Expand Up @@ -159,11 +158,7 @@ public void testInvalidTimeZone()
.build(),
"testRemote");
HttpRequestSessionContext context = new HttpRequestSessionContext(WARN, request);
QuerySessionSupplier sessionSupplier = new QuerySessionSupplier(
createTestTransactionManager(),
new AllowAllAccessControl(),
new SessionPropertyManager(),
new SqlEnvironmentConfig());
QuerySessionSupplier sessionSupplier = createSessionSupplier(new SqlEnvironmentConfig());
sessionSupplier.createSession(new QueryId("test_query_id"), context);
}

Expand Down Expand Up @@ -193,4 +188,79 @@ public void testSqlPathCreation()
assertEquals(path.getParsedPath(), expected);
assertEquals(path.toString(), Joiner.on(", ").join(expected));
}

@Test
public void testDefaultCatalogAndSchema()
{
// none specified
Session session = createSession(
ImmutableListMultimap.<String, String>builder()
.put(PRESTO_USER, "testUser")
.build(),
new SqlEnvironmentConfig());
assertFalse(session.getCatalog().isPresent());
assertFalse(session.getSchema().isPresent());

// catalog
session = createSession(
ImmutableListMultimap.<String, String>builder()
.put(PRESTO_USER, "testUser")
.build(),
new SqlEnvironmentConfig()
.setDefaultCatalog("catalog"));
assertEquals(session.getCatalog(), Optional.of("catalog"));
assertFalse(session.getSchema().isPresent());

// catalog and schema
session = createSession(
ImmutableListMultimap.<String, String>builder()
.put(PRESTO_USER, "testUser")
.build(),
new SqlEnvironmentConfig()
.setDefaultCatalog("catalog")
.setDefaultSchema("schema"));
assertEquals(session.getCatalog(), Optional.of("catalog"));
assertEquals(session.getSchema(), Optional.of("schema"));

// only schema
assertThatThrownBy(() -> createSessionSupplier(new SqlEnvironmentConfig().setDefaultSchema("schema")))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Default schema cannot be set if catalog is not set");
}

@Test
public void testCatalogAndSchemaOverrides()
{
// none specified
Session session = createSession(
ImmutableListMultimap.<String, String>builder()
.put(PRESTO_USER, "testUser")
.put(PRESTO_CATALOG, "catalog")
.put(PRESTO_SCHEMA, "schema")
.build(),
new SqlEnvironmentConfig()
.setDefaultCatalog("default-catalog")
.setDefaultSchema("default-schema"));
assertEquals(session.getCatalog(), Optional.of("catalog"));
assertEquals(session.getSchema(), Optional.of("schema"));
}

private Session createSession(ListMultimap<String, String> headers, SqlEnvironmentConfig config)
{
HttpRequestSessionContext context = new HttpRequestSessionContext(
WARN,
new MockHttpServletRequest(headers, "testRemote"));

QuerySessionSupplier sessionSupplier = createSessionSupplier(config);
return sessionSupplier.createSession(new QueryId("test_query_id"), context);
}

private static QuerySessionSupplier createSessionSupplier(SqlEnvironmentConfig config)
{
return new QuerySessionSupplier(
createTestTransactionManager(),
new AllowAllAccessControl(),
new SessionPropertyManager(),
config);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public void testDefaults()
{
assertRecordedDefaults(recordDefaults(SqlEnvironmentConfig.class)
.setPath(null)
.setDefaultCatalog(null)
.setDefaultSchema(null)
.setForcedSessionTimeZone(null));
}

Expand All @@ -38,11 +40,15 @@ public void testExplicitPropertyMappings()
{
Map<String, String> properties = new ImmutableMap.Builder<String, String>()
.put("sql.path", "a.b, c.d")
.put("sql.default-catalog", "some-catalog")
.put("sql.default-schema", "some-schema")
.put("sql.forced-session-time-zone", "UTC")
.build();

SqlEnvironmentConfig expected = new SqlEnvironmentConfig()
.setPath("a.b, c.d")
.setDefaultCatalog("some-catalog")
.setDefaultSchema("some-schema")
.setForcedSessionTimeZone("UTC");

assertFullMapping(properties, expected);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.prestosql.tests.tpch;

import com.google.common.collect.ImmutableMap;
import io.airlift.log.Logger;
import io.airlift.log.Logging;
import io.prestosql.tests.DistributedQueryRunner;
Expand All @@ -26,7 +27,11 @@ public static void main(String[] args)
{
Logging.initialize();
DistributedQueryRunner queryRunner = TpchQueryRunnerBuilder.builder()
.setSingleExtraProperty("http-server.http.port", "8080")
.setExtraProperties(ImmutableMap.<String, String>builder()
.put("http-server.http.port", "8080")
.put("sql.default-catalog", "tpch")
.put("sql.default-schema", "tiny")
.build())
.build();
Thread.sleep(10);
Logger log = Logger.get(TpchQueryRunner.class);
Expand Down

0 comments on commit 3399d63

Please sign in to comment.