Skip to content

Commit

Permalink
[trinodb#135] Expose complete query for execute statement
Browse files Browse the repository at this point in the history
  • Loading branch information
yangjinde committed Sep 1, 2022
1 parent 6750bbf commit 2202bb1
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ private static QueryInfo immediateFailureQueryInfo(
ImmutableList.of(),
query,
preparedQuery,
Optional.empty(),
immediateFailureQueryStats(),
Optional.empty(),
Optional.empty(),
Expand Down
10 changes: 10 additions & 0 deletions core/trino-main/src/main/java/io/trino/execution/QueryInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class QueryInfo
private final List<String> fieldNames;
private final String query;
private final Optional<String> preparedQuery;
private final Optional<String> completeExecuteQuery;
private final QueryStats queryStats;
private final Optional<String> setCatalog;
private final Optional<String> setSchema;
Expand Down Expand Up @@ -94,6 +95,7 @@ public QueryInfo(
@JsonProperty("fieldNames") List<String> fieldNames,
@JsonProperty("query") String query,
@JsonProperty("preparedQuery") Optional<String> preparedQuery,
@JsonProperty("completeExecuteQuery") Optional<String> completeExecuteQuery,
@JsonProperty("queryStats") QueryStats queryStats,
@JsonProperty("setCatalog") Optional<String> setCatalog,
@JsonProperty("setSchema") Optional<String> setSchema,
Expand Down Expand Up @@ -135,6 +137,7 @@ public QueryInfo(
requireNonNull(startedTransactionId, "startedTransactionId is null");
requireNonNull(query, "query is null");
requireNonNull(preparedQuery, "preparedQuery is null");
requireNonNull(completeExecuteQuery, "completeExecuteQuery is null");
requireNonNull(outputStage, "outputStage is null");
requireNonNull(inputs, "inputs is null");
requireNonNull(output, "output is null");
Expand All @@ -154,6 +157,7 @@ public QueryInfo(
this.fieldNames = ImmutableList.copyOf(fieldNames);
this.query = query;
this.preparedQuery = preparedQuery;
this.completeExecuteQuery = completeExecuteQuery;
this.queryStats = queryStats;
this.setCatalog = setCatalog;
this.setSchema = setSchema;
Expand Down Expand Up @@ -235,6 +239,12 @@ public Optional<String> getPreparedQuery()
return preparedQuery;
}

@JsonProperty
public Optional<String> getCompleteExecuteQuery()
{
return completeExecuteQuery;
}

@JsonProperty
public QueryStats getQueryStats()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public class QueryStateMachine
private final QueryId queryId;
private final String query;
private final Optional<String> preparedQuery;
private final AtomicReference<String> completeExecuteQuery = new AtomicReference<>();
private final Session session;
private final URI self;
private final ResourceGroupId resourceGroup;
Expand Down Expand Up @@ -440,6 +441,7 @@ QueryInfo getQueryInfo(Optional<StageInfo> rootStage)
outputManager.getQueryOutputInfo().map(QueryOutputInfo::getColumnNames).orElse(ImmutableList.of()),
query,
preparedQuery,
Optional.ofNullable(completeExecuteQuery.get()),
getQueryStats(rootStage),
Optional.ofNullable(setCatalog.get()),
Optional.ofNullable(setSchema.get()),
Expand Down Expand Up @@ -828,6 +830,11 @@ public void clearTransactionId()
clearTransactionId.set(true);
}

public void setCompleteExecuteQuery(String completeQuery)
{
completeExecuteQuery.set(requireNonNull(completeQuery, "complete query is null"));
}

public void setUpdateType(String updateType)
{
this.updateType.set(updateType);
Expand Down Expand Up @@ -1142,6 +1149,7 @@ public void pruneQueryInfo()
queryInfo.getFieldNames(),
queryInfo.getQuery(),
queryInfo.getPreparedQuery(),
queryInfo.getCompleteExecuteQuery(),
pruneQueryStats(queryInfo.getQueryStats()),
queryInfo.getSetCatalog(),
queryInfo.getSetSchema(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.trino.server.protocol.Slug;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.sql.ExpressionFormatter;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.Analysis;
import io.trino.sql.analyzer.Analyzer;
Expand All @@ -58,14 +59,20 @@
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.tree.ExplainAnalyze;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeLocation;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.Query;
import io.trino.sql.tree.Statement;
import org.joda.time.DateTime;

import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

import java.util.AbstractMap.SimpleEntry;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -78,6 +85,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.units.DataSize.succinctBytes;
import static io.trino.SystemSessionProperties.isEnableDynamicFiltering;
Expand Down Expand Up @@ -258,11 +266,68 @@ private static Analysis analyze(
stateMachine.setReferencedTables(analysis.getReferencedTables());
stateMachine.setRoutines(analysis.getRoutines());

// try to generate complete query for execute statement
if (preparedQuery.getPrepareSql().isPresent()) {
try {
stateMachine.setCompleteExecuteQuery(getCompleteExecuteQuery(
preparedQuery.getPrepareSql().get(), analysis.getParameters()));
}
catch (Exception e) {
LOG.warn(e, "Fail to generate complete query for execute statement");
}
}

stateMachine.endAnalysis();

return analysis;
}

private static String getCompleteExecuteQuery(String prepareSql, Map<NodeRef<Parameter>, Expression> parameters)
{
if (parameters.isEmpty()) {
return prepareSql;
}

List<SimpleEntry<NodeLocation, Expression>> sortedParams = parameters.entrySet().stream()
.map(entry -> new SimpleEntry<>(entry.getKey().getNode().getLocation().get(), entry.getValue()))
.sorted(Map.Entry.comparingByKey(Comparator.comparing(NodeLocation::getLineNumber)
.thenComparing(NodeLocation::getColumnNumber)))
.collect(toImmutableList());

StringBuilder sb = new StringBuilder();
List<String> lines = prepareSql.lines().collect(toImmutableList());
NodeLocation last = null;
for (SimpleEntry<NodeLocation, Expression> param : sortedParams) {
appendBetween(sb, lines, last, param.getKey());
sb.append(ExpressionFormatter.formatExpression(param.getValue()));
last = param.getKey();
}
appendBetween(sb, lines, last, null);
return sb.append("\n").toString();
}

private static void appendBetween(StringBuilder sb, List<String> lines, NodeLocation start, NodeLocation end)
{
if (start == null && end == null) {
lines.forEach(line -> sb.append(line).append("\n"));
}
else {
int startLine = start == null ? 0 : start.getLineNumber() - 1;
int startCol = start == null ? 0 : start.getColumnNumber();
int endLine = end == null ? lines.size() - 1 : end.getLineNumber() - 1;
int endCol = end == null ? lines.get(lines.size() - 1).length() : end.getColumnNumber() - 1;

if (startLine == endLine) {
sb.append(lines.get(startLine), startCol, endCol);
}
else {
sb.append(lines.get(startLine).substring(startCol)).append("\n");
lines.subList(startLine + 1, endLine).forEach(line -> sb.append(line).append("\n"));
sb.append(lines.get(endLine), 0, endCol);
}
}
}

@Override
public Slug getSlug()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ public QueryInfo getFullQueryInfo()
ImmutableList.of(),
"SELECT 1",
Optional.empty(),
Optional.empty(),
new QueryStats(
new DateTime(1),
new DateTime(2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public void testConstructor()
ImmutableList.of("2", "3"),
"SELECT 4",
Optional.empty(),
Optional.empty(),
new QueryStats(
DateTime.parse("1991-09-06T05:00-05:30"),
DateTime.parse("1991-09-06T05:01-05:30"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ private QueryInfo createQueryInfo(String queryId, QueryState state, String query
ImmutableList.of("2", "3"),
query,
Optional.empty(),
Optional.empty(),
new QueryStats(
DateTime.parse("1991-09-06T05:00-05:30"),
DateTime.parse("1991-09-06T05:01-05:30"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.tests;

import io.trino.Session;
import io.trino.execution.QueryInfo;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import io.trino.testing.ResultWithQueryId;
import io.trino.tests.tpch.TpchQueryRunnerBuilder;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

public class TestCompleteExecuteQuery
extends AbstractTestQueryFramework
{
private static final String PREPARE_STATEMENT_NAME = "my_select";
private static final String PREPARE_STATEMENT_PREFIX = "PREPARE " + PREPARE_STATEMENT_NAME + " FROM ";
private static final String EXECUTE_STATEMENT_PREFIX = "EXECUTE " + PREPARE_STATEMENT_NAME + " USING ";

@Override
protected QueryRunner createQueryRunner()
throws Exception
{
return TpchQueryRunnerBuilder.builder().build();
}

@Test
public void testCompleteExecuteQuery()
{
assertCompleteExecuteQuery(
"SELECT COUNT(*) FROM customer WHERE mktsegment = ?",
EXECUTE_STATEMENT_PREFIX + "'HOUSEHOLD'",
"SELECT COUNT(*) FROM customer WHERE mktsegment = 'HOUSEHOLD'");
assertCompleteExecuteQuery(
"SELECT COUNT(*) FROM orders WHERE custkey > ?",
EXECUTE_STATEMENT_PREFIX + "2*1000/300",
"SELECT COUNT(*) FROM orders WHERE custkey > 2*1000/300");
assertCompleteExecuteQuery(
"SELECT c.name, sum(o.totalprice) " +
"FROM (orders o INNER JOIN customer c ON o.custkey = c.custkey) " +
"WHERE nationkey = ? AND mktsegment = ? " +
"GROUP BY c.name " +
"LIMIT ?",
EXECUTE_STATEMENT_PREFIX + "2, 'BUILDING', 5",
"SELECT c.name, sum(o.totalprice) " +
"FROM (orders o INNER JOIN customer c ON o.custkey = c.custkey) " +
"WHERE nationkey = 2 AND mktsegment = 'BUILDING' " +
"GROUP BY c.name " +
"LIMIT 5");
}

private void assertCompleteExecuteQuery(
@Language("SQL") String prepareQuery,
@Language("SQL") String executeQuery,
@Language("SQL") String expectedQuery)
{
Session session = Session.builder(getSession())
.addPreparedStatement(PREPARE_STATEMENT_NAME, prepareQuery(prepareQuery))
.build();
QueryInfo queryInfo = runAndGetFullQueryInfo(session, executeQuery);
assertTrue(queryInfo.getCompleteExecuteQuery().isPresent(), "missing complete execute query");
assertEquals(queryInfo.getCompleteExecuteQuery().get(), prepareQuery(expectedQuery));
}

private String prepareQuery(@Language("SQL") String query)
{
QueryInfo queryInfo = runAndGetFullQueryInfo(getSession(), PREPARE_STATEMENT_PREFIX + query);
assertTrue(queryInfo.getAddedPreparedStatements().containsKey(PREPARE_STATEMENT_NAME), "fail to prepare query");
return queryInfo.getAddedPreparedStatements().get(PREPARE_STATEMENT_NAME);
}

private QueryInfo runAndGetFullQueryInfo(Session session, @Language("SQL") String query)
{
DistributedQueryRunner queryRunner = getDistributedQueryRunner();
ResultWithQueryId<MaterializedResult> resultWithQueryId = queryRunner.executeWithQueryId(session, query);
return queryRunner.getCoordinator()
.getQueryManager()
.getFullQueryInfo(resultWithQueryId.getQueryId());
}
}

0 comments on commit 2202bb1

Please sign in to comment.