Remove checking non nullability via TableWriterOperator
This is now redundant as the checks are inserted by the planner.
homar authored and findepi committed Aug 10, 2022
1 parent 7667716 commit a631276
Showing 9 changed files with 3 additions and 68 deletions.
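For context: the deleted operator-side check scanned every written Block for nulls and raised CONSTRAINT_VIOLATION. With the checks inserted by the planner instead, each NOT NULL column can be wrapped in an expression that fails when it evaluates to NULL, so the violation surfaces during expression evaluation rather than inside the writer. Below is a minimal, self-contained sketch of that idea in plain Java; the class, method, and column names are illustrative, not the actual Trino planner code.

import java.util.function.UnaryOperator;

final class NotNullEnforcement
{
    private NotNullEnforcement() {}

    // Sketch: the planner-side equivalent of the removed operator check is a
    // projection that passes non-null values through and fails on NULL,
    // mirroring the CONSTRAINT_VIOLATION message the operator used to raise.
    static UnaryOperator<Object> failOnNull(String columnName)
    {
        return value -> {
            if (value == null) {
                throw new IllegalStateException(
                        "NULL value not allowed for NOT NULL column: " + columnName);
            }
            return value;
        };
    }

    public static void main(String[] args)
    {
        UnaryOperator<Object> check = failOnNull("order_id");
        System.out.println(check.apply(42)); // non-null values pass through unchanged
        check.apply(null);                   // fails before the value reaches the writer
    }
}

In SQL terms this is roughly CASE WHEN col IS NULL THEN fail(...) ELSE col END applied above the writer's source, which is why the per-block scan removed below becomes redundant.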
File 1 of 9

@@ -27,7 +27,6 @@
 import io.trino.spi.Mergeable;
 import io.trino.spi.Page;
 import io.trino.spi.PageBuilder;
-import io.trino.spi.TrinoException;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.block.RunLengthEncodedBlock;
@@ -53,7 +52,6 @@
 import static io.airlift.concurrent.MoreFutures.getFutureValue;
 import static io.airlift.concurrent.MoreFutures.toListenableFuture;
 import static io.trino.SystemSessionProperties.isStatisticsCpuTimerEnabled;
-import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION;
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.VarbinaryType.VARBINARY;
 import static io.trino.sql.planner.plan.TableWriterNode.CreateTarget;
@@ -76,7 +74,6 @@ public static class TableWriterOperatorFactory
 private final PageSinkManager pageSinkManager;
 private final WriterTarget target;
 private final List<Integer> columnChannels;
-private final List<String> notNullChannelColumnNames;
 private final Session session;
 private final OperatorFactory statisticsAggregationOperatorFactory;
 private final List<Type> types;
@@ -88,15 +85,13 @@ public TableWriterOperatorFactory(
 PageSinkManager pageSinkManager,
 WriterTarget writerTarget,
 List<Integer> columnChannels,
-List<String> notNullChannelColumnNames,
 Session session,
 OperatorFactory statisticsAggregationOperatorFactory,
 List<Type> types)
 {
 this.operatorId = operatorId;
 this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
 this.columnChannels = requireNonNull(columnChannels, "columnChannels is null");
-this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null");
 this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null");
 checkArgument(
 writerTarget instanceof CreateTarget
@@ -117,7 +112,7 @@ public Operator createOperator(DriverContext driverContext)
 OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, TableWriterOperator.class.getSimpleName());
 Operator statisticsAggregationOperator = statisticsAggregationOperatorFactory.createOperator(driverContext);
 boolean statisticsCpuTimerEnabled = !(statisticsAggregationOperator instanceof DevNullOperator) && isStatisticsCpuTimerEnabled(session);
-return new TableWriterOperator(context, createPageSink(), columnChannels, notNullChannelColumnNames, statisticsAggregationOperator, types, statisticsCpuTimerEnabled);
+return new TableWriterOperator(context, createPageSink(), columnChannels, statisticsAggregationOperator, types, statisticsCpuTimerEnabled);
 }

 private ConnectorPageSink createPageSink()
@@ -146,7 +141,7 @@ public void noMoreOperators()
 @Override
 public OperatorFactory duplicate()
 {
-return new TableWriterOperatorFactory(operatorId, planNodeId, pageSinkManager, target, columnChannels, notNullChannelColumnNames, session, statisticsAggregationOperatorFactory, types);
+return new TableWriterOperatorFactory(operatorId, planNodeId, pageSinkManager, target, columnChannels, session, statisticsAggregationOperatorFactory, types);
 }
 }

@@ -159,7 +154,6 @@ private enum State
 private final LocalMemoryContext pageSinkMemoryContext;
 private final ConnectorPageSink pageSink;
 private final List<Integer> columnChannels;
-private final List<String> notNullChannelColumnNames;
 private final AtomicLong pageSinkPeakMemoryUsage = new AtomicLong();
 private final Operator statisticAggregationOperator;
 private final List<Type> types;
@@ -181,7 +175,6 @@ public TableWriterOperator(
 OperatorContext operatorContext,
 ConnectorPageSink pageSink,
 List<Integer> columnChannels,
-List<String> notNullChannelColumnNames,
 Operator statisticAggregationOperator,
 List<Type> types,
 boolean statisticsCpuTimerEnabled)
@@ -190,8 +183,6 @@ public TableWriterOperator(
 this.pageSinkMemoryContext = operatorContext.newLocalUserMemoryContext(TableWriterOperator.class.getSimpleName());
 this.pageSink = requireNonNull(pageSink, "pageSink is null");
 this.columnChannels = requireNonNull(columnChannels, "columnChannels is null");
-this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null");
-checkArgument(columnChannels.size() == notNullChannelColumnNames.size(), "columnChannels and notNullColumnNames have different sizes");
 this.statisticAggregationOperator = requireNonNull(statisticAggregationOperator, "statisticAggregationOperator is null");
 this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
 this.statisticsCpuTimerEnabled = statisticsCpuTimerEnabled;
@@ -255,10 +246,6 @@ public void addInput(Page page)
 Block[] blocks = new Block[columnChannels.size()];
 for (int outputChannel = 0; outputChannel < columnChannels.size(); outputChannel++) {
 Block block = page.getBlock(columnChannels.get(outputChannel));
-String columnName = notNullChannelColumnNames.get(outputChannel);
-if (columnName != null) {
-verifyBlockHasNoNulls(block, columnName);
-}
 blocks[outputChannel] = block;
 }

@@ -275,18 +262,6 @@ public void addInput(Page page)
 updateWrittenBytes();
 }

-private void verifyBlockHasNoNulls(Block block, String columnName)
-{
-if (!block.mayHaveNull()) {
-return;
-}
-for (int position = 0; position < block.getPositionCount(); position++) {
-if (block.isNull(position)) {
-throw new TrinoException(CONSTRAINT_VIOLATION, "NULL value not allowed for NOT NULL column: " + columnName);
-}
-}
-}
-
 @Override
 public Page getOutput()
 {
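The deleted verifyBlockHasNoNulls shows the execution-time pattern this commit retires: Block.mayHaveNull() is a cheap, possibly conservative hint (it can return true even when no null is actually present), so it only guards the per-position scan rather than replacing it. A standalone restatement of that fast-path pattern, using a simplified stand-in interface rather than Trino's actual Block SPI:

interface SimpleBlock
{
    boolean mayHaveNull();      // cheap hint; true does not guarantee a null exists
    int getPositionCount();
    boolean isNull(int position);
}

final class NullScan
{
    private NullScan() {}

    static void verifyNoNulls(SimpleBlock block, String columnName)
    {
        if (!block.mayHaveNull()) {
            return; // fast path: the block advertises that it cannot contain nulls
        }
        for (int position = 0; position < block.getPositionCount(); position++) {
            if (block.isNull(position)) {
                throw new IllegalStateException(
                        "NULL value not allowed for NOT NULL column: " + columnName);
            }
        }
    }
}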
File 2 of 9

@@ -366,7 +366,6 @@
 import static io.trino.util.SpatialJoinUtils.ST_WITHIN;
 import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialComparisons;
 import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions;
-import static java.util.Collections.nCopies;
 import static java.util.Objects.requireNonNull;
 import static java.util.concurrent.TimeUnit.HOURS;
 import static java.util.stream.Collectors.partitioningBy;
@@ -3239,17 +3238,12 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl
 .map(source::symbolToChannel)
 .collect(toImmutableList());

-List<String> notNullChannelColumnNames = node.getColumns().stream()
-.map(symbol -> node.getNotNullColumnSymbols().contains(symbol) ? node.getColumnNames().get(source.symbolToChannel(symbol)) : null)
-.collect(Collectors.toList());
-
 OperatorFactory operatorFactory = new TableWriterOperatorFactory(
 context.getNextOperatorId(),
 node.getId(),
 pageSinkManager,
 node.getTarget(),
 inputChannels,
-notNullChannelColumnNames,
 session,
 statisticsAggregation,
 getSymbolTypes(node.getOutputSymbols(), context.getTypes()));
@@ -3412,7 +3406,6 @@ public PhysicalOperation visitTableExecute(TableExecuteNode node, LocalExecution
 pageSinkManager,
 node.getTarget(),
 inputChannels,
-nCopies(inputChannels.size(), null), // N x null means no not-null checking will be performed. This is ok as in TableExecute flow we are not changing any table data.
 session,
 new DevNullOperatorFactory(context.getNextOperatorId(), node.getId()), // statistics are not calculated
 getSymbolTypes(node.getOutputSymbols(), context.getTypes()));
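The deleted visitTableExecute argument relied on Collections.nCopies to build the "no checks" list: one null per input channel meant no column had a NOT NULL check, which was safe because the TableExecute flow does not change table data. For reference, nCopies produces an immutable list of N identical elements:

import java.util.Collections;
import java.util.List;

class NCopiesDemo
{
    public static void main(String[] args)
    {
        // One entry per input channel; a null entry meant "no NOT NULL check
        // for this channel" in the removed TableWriterOperator API.
        List<String> noChecks = Collections.nCopies(3, null);
        System.out.println(noChecks); // [null, null, null]
    }
}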
File 3 of 9

@@ -111,14 +111,12 @@
 import java.util.Map.Entry;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.Set;
 import java.util.function.Function;

 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableMap.toImmutableMap;
-import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static com.google.common.collect.Streams.zip;
 import static io.trino.SystemSessionProperties.isCollectPlanStatisticsForAllQueries;
 import static io.trino.metadata.MetadataUtil.createQualifiedObjectName;
@@ -414,7 +412,6 @@ private RelationPlan createTableCreationPlan(Analysis analysis, Query query)
 visibleFields(plan),
 new CreateReference(catalogName, tableMetadata, newTableLayout),
 columnNames,
-tableMetadata.getColumns(),
 newTableLayout,
 statisticsMetadata);
 }
@@ -515,7 +512,6 @@ private RelationPlan getInsertPlan(
 plan.getFieldMappings(),
 materializedViewRefreshWriterTarget.get(),
 insertedTableColumnNames,
-insertedColumns,
 newTableLayout,
 statisticsMetadata);
 }
@@ -530,7 +526,6 @@ private RelationPlan getInsertPlan(
 plan.getFieldMappings(),
 insertTarget,
 insertedTableColumnNames,
-insertedColumns,
 newTableLayout,
 statisticsMetadata);
 }
@@ -595,7 +590,6 @@ private RelationPlan createTableWriterPlan(
 List<Symbol> symbols,
 WriterTarget target,
 List<String> columnNames,
-List<ColumnMetadata> columnMetadataList,
 Optional<TableLayout> writeTableLayout,
 TableStatisticsMetadata statisticsMetadata)
 {
@@ -628,12 +622,6 @@
 Map<String, Symbol> columnToSymbolMap = zip(columnNames.stream(), symbols.stream(), SimpleImmutableEntry::new)
 .collect(toImmutableMap(Entry::getKey, Entry::getValue));

-Set<Symbol> notNullColumnSymbols = columnMetadataList.stream()
-.filter(column -> !column.isNullable())
-.map(ColumnMetadata::getName)
-.map(columnToSymbolMap::get)
-.collect(toImmutableSet());
-
 if (!statisticsMetadata.isEmpty()) {
 TableStatisticAggregation result = statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, columnToSymbolMap);

@@ -655,7 +643,6 @@
 symbolAllocator.newSymbol("fragment", VARBINARY),
 symbols,
 columnNames,
-notNullColumnSymbols,
 partitioningScheme,
 preferredPartitioningScheme,
 Optional.of(partialAggregation),
@@ -678,7 +665,6 @@
 symbolAllocator.newSymbol("fragment", VARBINARY),
 symbols,
 columnNames,
-notNullColumnSymbols,
 partitioningScheme,
 preferredPartitioningScheme,
 Optional.empty(),
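The surviving columnToSymbolMap line in createTableWriterPlan pairs column names with planner symbols via Guava's Streams.zip; the removed stream then looked up each non-nullable column in that map. A standalone sketch of the zip-to-map step, with plain strings standing in for Symbol:

import com.google.common.collect.Streams;

import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.List;
import java.util.Map;

import static com.google.common.collect.ImmutableMap.toImmutableMap;

class ZipToMapDemo
{
    public static void main(String[] args)
    {
        List<String> columnNames = List.of("id", "name");
        List<String> symbols = List.of("id_0", "name_1");
        // Pair element i of one stream with element i of the other, then
        // collect the pairs into an immutable name -> symbol map.
        Map<String, String> columnToSymbol = Streams.zip(
                columnNames.stream(),
                symbols.stream(),
                SimpleImmutableEntry::new)
                .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
        System.out.println(columnToSymbol); // {id=id_0, name=name_1}
    }
}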
File 4 of 9

@@ -84,7 +84,6 @@ private Result enable(TableWriterNode node)
 node.getFragmentSymbol(),
 node.getColumns(),
 node.getColumnNames(),
-node.getNotNullColumnSymbols(),
 node.getPreferredPartitioningScheme(),
 Optional.empty(),
 node.getStatisticsAggregation(),
File 5 of 9

@@ -131,7 +131,6 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Optional<W
 node.getFragmentSymbol(),
 node.getColumns(),
 node.getColumnNames(),
-node.getNotNullColumnSymbols(),
 node.getPartitioningScheme(),
 node.getPreferredPartitioningScheme(),
 node.getStatisticsAggregation(),
File 6 of 9

@@ -405,7 +405,6 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new
 map(node.getFragmentSymbol()),
 map(node.getColumns()),
 node.getColumnNames(),
-node.getNotNullColumnSymbols(),
 node.getPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols())),
 node.getPreferredPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols())),
 node.getStatisticsAggregation().map(this::map),
File 7 of 9

@@ -19,7 +19,6 @@
 import com.fasterxml.jackson.annotation.JsonSubTypes;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import io.trino.Session;
 import io.trino.metadata.InsertTableHandle;
@@ -43,7 +42,6 @@

 import java.util.List;
 import java.util.Optional;
-import java.util.Set;

 import static com.google.common.base.Preconditions.checkArgument;
 import static java.util.Objects.requireNonNull;
@@ -58,7 +56,6 @@ public class TableWriterNode
 private final Symbol fragmentSymbol;
 private final List<Symbol> columns;
 private final List<String> columnNames;
-private final Set<Symbol> notNullColumnSymbols;
 private final Optional<PartitioningScheme> partitioningScheme;
 private final Optional<PartitioningScheme> preferredPartitioningScheme;
 private final Optional<StatisticAggregations> statisticsAggregation;
@@ -74,7 +71,6 @@ public TableWriterNode(
 @JsonProperty("fragmentSymbol") Symbol fragmentSymbol,
 @JsonProperty("columns") List<Symbol> columns,
 @JsonProperty("columnNames") List<String> columnNames,
-@JsonProperty("notNullColumnSymbols") Set<Symbol> notNullColumnSymbols,
 @JsonProperty("partitioningScheme") Optional<PartitioningScheme> partitioningScheme,
 @JsonProperty("preferredPartitioningScheme") Optional<PartitioningScheme> preferredPartitioningScheme,
 @JsonProperty("statisticsAggregation") Optional<StatisticAggregations> statisticsAggregation,
@@ -92,7 +88,6 @@ public TableWriterNode(
 this.fragmentSymbol = requireNonNull(fragmentSymbol, "fragmentSymbol is null");
 this.columns = ImmutableList.copyOf(columns);
 this.columnNames = ImmutableList.copyOf(columnNames);
-this.notNullColumnSymbols = ImmutableSet.copyOf(requireNonNull(notNullColumnSymbols, "notNullColumnSymbols is null"));
 this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null");
 this.preferredPartitioningScheme = requireNonNull(preferredPartitioningScheme, "preferredPartitioningScheme is null");
 this.statisticsAggregation = requireNonNull(statisticsAggregation, "statisticsAggregation is null");
@@ -146,12 +141,6 @@ public List<String> getColumnNames()
 return columnNames;
 }

-@JsonProperty
-public Set<Symbol> getNotNullColumnSymbols()
-{
-return notNullColumnSymbols;
-}
-
 @JsonProperty
 public Optional<PartitioningScheme> getPartitioningScheme()
 {
@@ -197,7 +186,7 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context)
 @Override
 public PlanNode replaceChildren(List<PlanNode> newChildren)
 {
-return new TableWriterNode(getId(), Iterables.getOnlyElement(newChildren), target, rowCountSymbol, fragmentSymbol, columns, columnNames, notNullColumnSymbols, partitioningScheme, preferredPartitioningScheme, statisticsAggregation, statisticsAggregationDescriptor);
+return new TableWriterNode(getId(), Iterables.getOnlyElement(newChildren), target, rowCountSymbol, fragmentSymbol, columns, columnNames, partitioningScheme, preferredPartitioningScheme, statisticsAggregation, statisticsAggregationDescriptor);
 }

 @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "@type")
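TableWriterNode is serialized with Jackson (@JsonCreator constructor, @JsonProperty getters) so plan fragments can be shipped to workers, which is why removing notNullColumnSymbols also changes the node's JSON shape. A minimal toy illustration of that serialization pattern (not the real plan node):

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.List;

final class ToyWriterNode
{
    private final List<String> columnNames;

    @JsonCreator
    public ToyWriterNode(@JsonProperty("columnNames") List<String> columnNames)
    {
        this.columnNames = List.copyOf(columnNames);
    }

    @JsonProperty
    public List<String> getColumnNames()
    {
        return columnNames;
    }

    public static void main(String[] args) throws Exception
    {
        String json = new ObjectMapper().writeValueAsString(new ToyWriterNode(List.of("a", "b")));
        System.out.println(json); // {"columnNames":["a","b"]}
        // A field deleted from the node, like notNullColumnSymbols here,
        // simply disappears from the serialized plan.
    }
}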
File 8 of 9

@@ -284,8 +284,6 @@ private Operator createTableWriterOperator(
 Session session,
 DriverContext driverContext)
 {
-List<String> notNullColumnNames = new ArrayList<>(1);
-notNullColumnNames.add(null);
 SchemaTableName schemaTableName = new SchemaTableName("testSchema", "testTable");
 TableWriterOperatorFactory factory = new TableWriterOperatorFactory(
 0,
@@ -300,7 +298,6 @@
 schemaTableName,
 false),
 ImmutableList.of(0),
-notNullColumnNames,
 session,
 statisticsAggregation,
 outputTypes);
File 9 of 9

@@ -1217,7 +1217,6 @@ public TableWriterNode tableWriter(
 rowCountSymbol,
 columns,
 columnNames,
-ImmutableSet.of(),
 partitioningScheme,
 preferredPartitioningScheme,
 Optional.empty(),
@@ -1241,7 +1240,6 @@ public TableWriterNode tableWriter(
 symbol("fragment", VARBINARY),
 columns,
 columnNames,
-ImmutableSet.of(),
 partitioningScheme,
 preferredPartitioningScheme,
 statisticAggregations,
