Skip to content

Commit

Permalink
fixup! Add support for WITH SESSION clause
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Sep 23, 2024
1 parent bad8492 commit f61fb9f
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 52 deletions.
2 changes: 1 addition & 1 deletion core/trino-main/src/main/java/io/trino/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND;
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey;
import static io.trino.sql.SqlPath.EMPTY_PATH;
import static io.trino.util.Failures.checkCondition;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -435,6 +434,7 @@ public Session withProperties(Map<String, String> systemProperties, Map<String,
schema,
path,
traceToken,
// This is required to override a timezone using a WITH SESSION timezone
Optional.ofNullable(systemProperties.get(TIME_ZONE_ID))
.map(TimeZoneKey::getTimeZoneKey)
.orElse(timeZoneKey),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import io.trino.server.protocol.Slug;
import io.trino.spi.TrinoException;
import io.trino.spi.resourcegroups.ResourceGroupId;
import io.trino.sql.SessionSpecificationEvaluator;
import io.trino.sql.SessionPropertyInterpreter;
import io.trino.sql.tree.Statement;
import io.trino.transaction.TransactionId;
import io.trino.transaction.TransactionManager;
Expand All @@ -60,7 +60,7 @@ public class LocalDispatchQueryFactory
private final TransactionManager transactionManager;
private final AccessControl accessControl;
private final Metadata metadata;
private final SessionSpecificationEvaluator sessionSpecificationEvaluator;
private final SessionPropertyInterpreter sessionPropertyInterpreter;
private final QueryMonitor queryMonitor;
private final LocationFactory locationFactory;

Expand All @@ -79,7 +79,7 @@ public LocalDispatchQueryFactory(
QueryManager queryManager,
QueryManagerConfig queryManagerConfig,
TransactionManager transactionManager,
SessionSpecificationEvaluator sessionSpecificationEvaluator,
SessionPropertyInterpreter sessionPropertyInterpreter,
AccessControl accessControl,
Metadata metadata,
QueryMonitor queryMonitor,
Expand All @@ -95,7 +95,7 @@ public LocalDispatchQueryFactory(
this.transactionManager = requireNonNull(transactionManager, "transactionManager is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.sessionSpecificationEvaluator = requireNonNull(sessionSpecificationEvaluator, "sessionSpecificationEvaluator is null");
this.sessionPropertyInterpreter = requireNonNull(sessionPropertyInterpreter, "sessionPropertyInterpreter is null");
this.queryMonitor = requireNonNull(queryMonitor, "queryMonitor is null");
this.locationFactory = requireNonNull(locationFactory, "locationFactory is null");
this.executionFactories = requireNonNull(executionFactories, "executionFactories is null");
Expand Down Expand Up @@ -136,7 +136,7 @@ public DispatchQuery createDispatchQuery(
planOptimizersStatsCollector,
getQueryType(preparedQuery.getStatement()),
faultTolerantExecutionExchangeEncryptionEnabled,
Optional.of(sessionSpecificationEvaluator.getSessionSpecificationApplier(preparedQuery)),
Optional.of(sessionPropertyInterpreter.getSessionPropertiesApplier(preparedQuery)),
version);

// It is important that `queryCreatedEvent` is called here. Moving it past the `executor.submit` below
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
import io.trino.spi.resourcegroups.ResourceGroupId;
import io.trino.spi.security.SelectedRole;
import io.trino.spi.type.Type;
import io.trino.sql.SessionSpecificationEvaluator.SessionSpecificationsApplier;
import io.trino.sql.SessionPropertyInterpreter.SessionPropertiesApplier;
import io.trino.sql.analyzer.Output;
import io.trino.sql.planner.PlanFragment;
import io.trino.tracing.TrinoAttributes;
Expand Down Expand Up @@ -244,7 +244,7 @@ public static QueryStateMachine begin(
PlanOptimizersStatsCollector queryStatsCollector,
Optional<QueryType> queryType,
boolean faultTolerantExecutionExchangeEncryptionEnabled,
Optional<SessionSpecificationsApplier> sessionSpecificationsApplier,
Optional<SessionPropertiesApplier> sessionPropertiesApplier,
NodeVersion version)
{
return beginWithTicker(
Expand All @@ -264,7 +264,7 @@ public static QueryStateMachine begin(
queryStatsCollector,
queryType,
faultTolerantExecutionExchangeEncryptionEnabled,
sessionSpecificationsApplier,
sessionPropertiesApplier,
version);
}

Expand All @@ -285,7 +285,7 @@ static QueryStateMachine beginWithTicker(
PlanOptimizersStatsCollector queryStatsCollector,
Optional<QueryType> queryType,
boolean faultTolerantExecutionExchangeEncryptionEnabled,
Optional<SessionSpecificationsApplier> sessionSpecificationsApplier,
Optional<SessionPropertiesApplier> sessionPropertiesApplier,
NodeVersion version)
{
// if there is an existing transaction, activate it
Expand All @@ -312,9 +312,9 @@ static QueryStateMachine beginWithTicker(
session = session.withExchangeEncryption(serializeAesEncryptionKey(createRandomAesEncryptionKey()));
}

// Apply WITH SESSION specifications which require transaction to be started to resolve catalog handles
if (sessionSpecificationsApplier.isPresent()) {
session = sessionSpecificationsApplier.orElseThrow().apply(session);
// Apply WITH SESSION properties which require transaction to be started to resolve catalog handles
if (sessionPropertiesApplier.isPresent()) {
session = sessionPropertiesApplier.orElseThrow().apply(session);
}

Span querySpan = session.getQuerySpan();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
import io.trino.server.ui.WorkerResource;
import io.trino.spi.VersionEmbedder;
import io.trino.sql.PlannerContext;
import io.trino.sql.SessionSpecificationEvaluator;
import io.trino.sql.SessionPropertyInterpreter;
import io.trino.sql.analyzer.AnalyzerFactory;
import io.trino.sql.analyzer.QueryExplainerFactory;
import io.trino.sql.planner.OptimizerStatsMBeanExporter;
Expand Down Expand Up @@ -212,7 +212,7 @@ protected void setup(Binder binder)
// dispatcher
binder.bind(DispatchManager.class).in(Scopes.SINGLETON);
// WITH SESSION interpreter
binder.bind(SessionSpecificationEvaluator.class).in(Scopes.SINGLETON);
binder.bind(SessionPropertyInterpreter.class).in(Scopes.SINGLETON);
// export under the old name, for backwards compatibility
newExporter(binder).export(DispatchManager.class).as(generator -> generator.generatedNameOf(QueryManager.class));
binder.bind(FailedDispatchQueryFactory.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Query;
import io.trino.sql.tree.SessionSpecification;
import io.trino.sql.tree.SessionProperty;

import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -52,92 +52,92 @@
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class SessionSpecificationEvaluator
public class SessionPropertyInterpreter
{
private final PlannerContext plannerContext;
private final AccessControl accessControl;
private final SessionPropertyManager sessionPropertyManager;

@Inject
public SessionSpecificationEvaluator(PlannerContext plannerContext, AccessControl accessControl, SessionPropertyManager sessionPropertyManager)
public SessionPropertyInterpreter(PlannerContext plannerContext, AccessControl accessControl, SessionPropertyManager sessionPropertyManager)
{
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
}

public SessionSpecificationsApplier getSessionSpecificationApplier(PreparedQuery preparedQuery)
public SessionPropertiesApplier getSessionPropertiesApplier(PreparedQuery preparedQuery)
{
if (!(preparedQuery.getStatement() instanceof Query queryStatement)) {
return session -> session;
}
return session -> prepareSession(session, queryStatement.getSessionProperties(), bindParameters(preparedQuery.getStatement(), preparedQuery.getParameters()));
}

private Session prepareSession(Session session, List<SessionSpecification> specifications, Map<NodeRef<Parameter>, Expression> parameters)
private Session prepareSession(Session session, List<SessionProperty> sessionProperties, Map<NodeRef<Parameter>, Expression> parameters)
{
ResolvedSessionSpecifications resolvedSessionSpecifications = resolve(session, parameters, specifications);
return overrideProperties(session, resolvedSessionSpecifications);
ResolvedSessionProperties resolvedSessionProperties = resolve(session, parameters, sessionProperties);
return overrideProperties(session, resolvedSessionProperties);
}

private ResolvedSessionSpecifications resolve(Session session, Map<NodeRef<Parameter>, Expression> parameters, List<SessionSpecification> specifications)
private ResolvedSessionProperties resolve(Session session, Map<NodeRef<Parameter>, Expression> parameters, List<SessionProperty> sessionProperties)
{
ImmutableMap.Builder<String, String> sessionProperties = ImmutableMap.builder();
ImmutableMap.Builder<String, String> systemProperties = ImmutableMap.builder();
Table<String, String, String> catalogProperties = HashBasedTable.create();
Set<QualifiedName> seenPropertyNames = new HashSet<>();

for (SessionSpecification specification : specifications) {
List<String> nameParts = specification.getName().getParts();
for (SessionProperty sessionProperty : sessionProperties) {
List<String> nameParts = sessionProperty.getName().getParts();

if (!seenPropertyNames.add(specification.getName())) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Session property %s already set", specification.getName());
if (!seenPropertyNames.add(sessionProperty.getName())) {
throw semanticException(INVALID_SESSION_PROPERTY, sessionProperty, "Session property %s already set", sessionProperty.getName());
}

if (nameParts.size() == 1) {
Optional<PropertyMetadata<?>> systemSessionPropertyMetadata = sessionPropertyManager.getSystemSessionPropertyMetadata(nameParts.getFirst());
if (systemSessionPropertyMetadata.isEmpty()) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Session property %s does not exist", specification.getName());
throw semanticException(INVALID_SESSION_PROPERTY, sessionProperty, "Session property %s does not exist", sessionProperty.getName());
}
sessionProperties.put(nameParts.getFirst(), toSessionValue(session, parameters, specification, systemSessionPropertyMetadata.get()));
systemProperties.put(nameParts.getFirst(), toSessionValue(session, parameters, sessionProperty, systemSessionPropertyMetadata.get()));
}
else if (nameParts.size() == 2) {
String catalogName = nameParts.getFirst();
String propertyName = nameParts.getLast();

CatalogHandle catalogHandle = getRequiredCatalogHandle(plannerContext.getMetadata(), session, specification, catalogName);
CatalogHandle catalogHandle = getRequiredCatalogHandle(plannerContext.getMetadata(), session, sessionProperty, catalogName);
Optional<PropertyMetadata<?>> connectorSessionPropertyMetadata = sessionPropertyManager.getConnectorSessionPropertyMetadata(catalogHandle, propertyName);
if (connectorSessionPropertyMetadata.isEmpty()) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Session property %s does not exist", specification.getName());
throw semanticException(INVALID_SESSION_PROPERTY, sessionProperty, "Session property %s does not exist", sessionProperty.getName());
}
catalogProperties.put(catalogName, propertyName, toSessionValue(session, parameters, specification, connectorSessionPropertyMetadata.get()));
catalogProperties.put(catalogName, propertyName, toSessionValue(session, parameters, sessionProperty, connectorSessionPropertyMetadata.get()));
}
else {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "Invalid session property '%s'", specification.getName());
throw semanticException(INVALID_SESSION_PROPERTY, sessionProperty, "Invalid session property '%s'", sessionProperty.getName());
}
}

return new ResolvedSessionSpecifications(sessionProperties.buildOrThrow(), catalogProperties.rowMap());
return new ResolvedSessionProperties(systemProperties.buildOrThrow(), catalogProperties.rowMap());
}

private Session overrideProperties(Session session, ResolvedSessionSpecifications resolvedSessionSpecifications)
private Session overrideProperties(Session session, ResolvedSessionProperties resolvedSessionProperties)
{
requireNonNull(resolvedSessionSpecifications, "resolvedSessionSpecifications is null");
requireNonNull(resolvedSessionProperties, "resolvedSessionProperties is null");

// TODO Consider moving validation to Session.withProperties method
validateSystemProperties(session, resolvedSessionSpecifications.systemProperties());
validateSystemProperties(session, resolvedSessionProperties.systemProperties());

// Catalog session properties were already evaluated so we need to evaluate overrides
if (session.getTransactionId().isPresent()) {
validateCatalogProperties(session, resolvedSessionSpecifications.catalogProperties());
validateCatalogProperties(session, resolvedSessionProperties.catalogProperties());
}

// NOTE: properties are validated before calling overrideProperties
Map<String, String> systemProperties = new HashMap<>();
systemProperties.putAll(session.getSystemProperties());
systemProperties.putAll(resolvedSessionSpecifications.systemProperties());
systemProperties.putAll(resolvedSessionProperties.systemProperties());

Map<String, Map<String, String>> catalogProperties = new HashMap<>(session.getCatalogProperties());
for (Map.Entry<String, Map<String, String>> catalogEntry : resolvedSessionSpecifications.catalogProperties().entrySet()) {
for (Map.Entry<String, Map<String, String>> catalogEntry : resolvedSessionProperties.catalogProperties().entrySet()) {
catalogProperties.computeIfAbsent(catalogEntry.getKey(), id -> new HashMap<>())
.putAll(catalogEntry.getValue());
}
Expand All @@ -146,18 +146,18 @@ private Session overrideProperties(Session session, ResolvedSessionSpecification
}

// TODO Consider extracting a method from SetSessionTask and reusing it here
private String toSessionValue(Session session, Map<NodeRef<Parameter>, Expression> parameters, SessionSpecification specification, PropertyMetadata<?> propertyMetadata)
private String toSessionValue(Session session, Map<NodeRef<Parameter>, Expression> parameters, SessionProperty sessionProperty, PropertyMetadata<?> propertyMetadata)
{
Type type = propertyMetadata.getSqlType();
Object objectValue;

try {
objectValue = evaluatePropertyValue(specification.getValue(), type, session, plannerContext, accessControl, parameters);
objectValue = evaluatePropertyValue(sessionProperty.getValue(), type, session, plannerContext, accessControl, parameters);
}
catch (TrinoException e) {
throw new TrinoException(
INVALID_SESSION_PROPERTY,
format("Unable to set session property '%s' to '%s': %s", specification.getName(), specification.getValue(), e.getRawMessage()));
format("Unable to set session property '%s' to '%s': %s", sessionProperty.getName(), sessionProperty.getValue(), e.getRawMessage()));
}

String value = serializeSessionProperty(type, objectValue);
Expand All @@ -166,7 +166,7 @@ private String toSessionValue(Session session, Map<NodeRef<Parameter>, Expressio
propertyMetadata.decode(objectValue);
}
catch (RuntimeException e) {
throw semanticException(INVALID_SESSION_PROPERTY, specification, "%s", e.getMessage());
throw semanticException(INVALID_SESSION_PROPERTY, sessionProperty, "%s", e.getMessage());
}

return value;
Expand Down Expand Up @@ -198,17 +198,17 @@ private void validateCatalogProperties(Session session, Map<String, Map<String,
}
}

public record ResolvedSessionSpecifications(Map<String, String> systemProperties, Map<String, Map<String, String>> catalogProperties)
public record ResolvedSessionProperties(Map<String, String> systemProperties, Map<String, Map<String, String>> catalogProperties)
{
public ResolvedSessionSpecifications
public ResolvedSessionProperties
{
systemProperties = ImmutableMap.copyOf(requireNonNull(systemProperties, "systemProperties is null"));
catalogProperties = ImmutableMap.copyOf(requireNonNull(catalogProperties, "catalogProperties is null"));
}
}

@FunctionalInterface
public interface SessionSpecificationsApplier
public interface SessionPropertiesApplier
extends Function<Session, Session>
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.PlannerContext;
import io.trino.sql.SessionSpecificationEvaluator;
import io.trino.sql.SessionPropertyInterpreter;
import io.trino.sql.parser.SqlParser;
import io.trino.transaction.TestingTransactionManager;
import io.trino.transaction.TransactionManager;
Expand All @@ -55,7 +55,7 @@

@TestInstance(PER_CLASS)
@Execution(CONCURRENT)
final class TestSessionSpecifications
final class TestSessionProperties
{
private static final SqlParser SQL_PARSER = new SqlParser();
private static final SessionPropertyManager SESSION_PROPERTY_MANAGER = new SessionPropertyManager(
Expand Down Expand Up @@ -117,9 +117,9 @@ private static Session analyze(@Language("SQL") String statement)

return transaction(transactionManager, plannerContext.getMetadata(), new AllowAllAccessControl())
.execute(testSession(), transactionSession -> {
SessionSpecificationEvaluator evaluator = new SessionSpecificationEvaluator(plannerContext, new AllowAllAccessControl(), SESSION_PROPERTY_MANAGER);
SessionPropertyInterpreter evaluator = new SessionPropertyInterpreter(plannerContext, new AllowAllAccessControl(), SESSION_PROPERTY_MANAGER);
QueryPreparer.PreparedQuery preparedQuery = new QueryPreparer.PreparedQuery(SQL_PARSER.createStatement(statement), ImmutableList.of(), Optional.empty());
return evaluator.getSessionSpecificationApplier(preparedQuery).apply(transactionSession);
return evaluator.getSessionPropertiesApplier(preparedQuery).apply(transactionSession);
});
}

Expand Down

0 comments on commit f61fb9f

Please sign in to comment.