Skip to content

Commit

Permalink
Fix NameId Mapper in PlanStreamInput (elastic#99295)
Browse files Browse the repository at this point in the history
This commit fixes the NameId mapper in PlanStreamInput.

The current mapper function just materializes a NameId instance with the stream's primitive long value. This can lead to a potential issue during re-planning if the local data node planner just happens to use an id, retrieved automatically from the global counter, during the course of its planning. This can lead to confusion, where the same id is used for different logical attributes.

The updated mapper keeps a map of seen stream ids to NameId instances. The no-args NameId constructor is used for absent entries, as it automatically selects and increments an id from the global counter, thus avoiding potential conflicts between the id in the stream and ids generated during local re-planning on the data node.
  • Loading branch information
ChrisHegarty authored Sep 8, 2023
1 parent e8f8a0a commit d3f5584
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,34 @@

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.function.LongFunction;
import java.util.function.Supplier;

/**
* A customized stream input used to deserialize ESQL physical plan fragments. Complements stream
* input with methods that read plan nodes, Attributes, Expressions, etc.
*/
public final class PlanStreamInput extends NamedWriteableAwareStreamInput {

private static final LongFunction<NameId> DEFAULT_NAME_ID_FUNC = NameId::new;
/**
 * Maps a name id read from the stream (a primitive long) to a {@link NameId} instance.
 * Absent entries are created through the no-args NameId constructor, which draws a
 * fresh id from the global counter; this avoids collisions between ids carried in the
 * stream and ids minted during local re-planning on the data node.
 */
static final class NameIdMapper implements LongFunction<NameId> {
    // stream id -> locally minted NameId; package-private so tests can inspect it
    final Map<Long, NameId> seen = new HashMap<>();

    @Override
    public NameId apply(long streamNameId) {
        NameId mapped = seen.get(streamNameId);
        if (mapped == null) {
            mapped = new NameId();
            seen.put(streamNameId, mapped);
        }
        return mapped;
    }
}

private static final Supplier<LongFunction<NameId>> DEFAULT_NAME_ID_FUNC = NameIdMapper::new;

private final PlanNameRegistry registry;

Expand All @@ -51,21 +69,11 @@ public PlanStreamInput(
PlanNameRegistry registry,
NamedWriteableRegistry namedWriteableRegistry,
EsqlConfiguration configuration
) {
this(streamInput, registry, namedWriteableRegistry, configuration, DEFAULT_NAME_ID_FUNC);
}

public PlanStreamInput(
StreamInput streamInput,
PlanNameRegistry registry,
NamedWriteableRegistry namedWriteableRegistry,
EsqlConfiguration configuration,
LongFunction<NameId> nameIdFunction
) {
super(streamInput, namedWriteableRegistry);
this.registry = registry;
this.nameIdFunction = nameIdFunction;
this.configuration = configuration;
this.nameIdFunction = DEFAULT_NAME_ID_FUNC.get();
}

NameId nameIdFromLongValue(long value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,19 @@ public void testUnsupportedAttributeSimple() throws IOException {
"foo",
new UnsupportedEsField("foo", "keyword"),
"field not supported",
new NameId(53)
new NameId()
);
BytesStreamOutput bso = new BytesStreamOutput();
PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry);
PlanNamedTypes.writeUnsupportedAttr(out, orig);
var deser = PlanNamedTypes.readUnsupportedAttr(planStreamInput(bso));
var in = planStreamInput(bso);
var deser = PlanNamedTypes.readUnsupportedAttr(in);
EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser);
assertThat(deser.id(), equalTo(orig.id()));
assertThat(deser.id(), equalTo(in.nameIdFromLongValue(Long.parseLong(orig.id().toString()))));
}

public void testUnsupportedAttribute() {
Stream.generate(PlanNamedTypesTests::randomUnsupportedAttribute)
.limit(100)
.forEach(PlanNamedTypesTests::assertNamedExpressionAndId);
Stream.generate(PlanNamedTypesTests::randomUnsupportedAttribute).limit(100).forEach(PlanNamedTypesTests::assertNamedExpression);
}

public void testFieldAttributeSimple() throws IOException {
Expand All @@ -194,19 +193,20 @@ public void testFieldAttributeSimple() throws IOException {
randomEsField(),
null, // qualifier, can be null
Nullability.TRUE,
new NameId(53),
new NameId(),
true // synthetic
);
BytesStreamOutput bso = new BytesStreamOutput();
PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry);
PlanNamedTypes.writeFieldAttribute(out, orig);
var deser = PlanNamedTypes.readFieldAttribute(planStreamInput(bso));
var in = planStreamInput(bso);
var deser = PlanNamedTypes.readFieldAttribute(in);
EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser);
assertThat(deser.id(), equalTo(orig.id()));
assertThat(deser.id(), equalTo(in.nameIdFromLongValue(Long.parseLong(orig.id().toString()))));
}

public void testFieldAttribute() {
Stream.generate(PlanNamedTypesTests::randomFieldAttribute).limit(100).forEach(PlanNamedTypesTests::assertNamedExpressionAndId);
Stream.generate(PlanNamedTypesTests::randomFieldAttribute).limit(100).forEach(PlanNamedTypesTests::assertNamedExpression);
}

public void testKeywordEsFieldSimple() throws IOException {
Expand Down Expand Up @@ -353,9 +353,10 @@ public void testAliasSimple() throws IOException {
BytesStreamOutput bso = new BytesStreamOutput();
PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry);
PlanNamedTypes.writeAlias(out, orig);
var deser = PlanNamedTypes.readAlias(planStreamInput(bso));
var in = planStreamInput(bso);
var deser = PlanNamedTypes.readAlias(in);
EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser);
assertThat(orig.id(), equalTo(deser.id()));
assertThat(deser.id(), equalTo(in.nameIdFromLongValue(Long.parseLong(orig.id().toString()))));
}

public void testLiteralSimple() throws IOException {
Expand Down Expand Up @@ -404,10 +405,9 @@ public void testDissectParserSimple() throws IOException {
EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser);
}

private static void assertNamedExpressionAndId(NamedExpression origObj) {
private static void assertNamedExpression(NamedExpression origObj) {
var deserObj = serializeDeserialize(origObj, PlanStreamOutput::writeExpression, PlanStreamInput::readNamedExpression);
EqualsHashCodeTestUtils.checkEqualsAndHashCode(origObj, unused -> deserObj);
assertThat(deserObj.id(), equalTo(origObj.id()));
}

private static <T> void assertNamedType(Class<T> type, T origObj) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.io.stream;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ql.expression.NameId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;

public class PlanStreamInputTests extends ESTestCase {

    /** Number of distinct stream ids exercised by the randomized tests. */
    private static final int NUM_IDS = 100;

    public void testMapperSimple() {
        var mapper = new PlanStreamInput.NameIdMapper();

        // Repeated lookups of the same stream id must yield the same NameId.
        NameId first = mapper.apply(1L);
        NameId second = mapper.apply(1L);
        assertThat(second, equalTo(first));

        // A different stream id must map to a distinct NameId.
        NameId third = mapper.apply(2L);
        NameId fourth = mapper.apply(2L);
        assertThat(third, not(equalTo(second)));
        assertThat(fourth, equalTo(third));

        assertThat(mapper.seen.size(), is(2));
    }

    public void testMapper() {
        List<Long> longs = randomLongsListOfSize(NUM_IDS);
        List<Long> nameIds = new ArrayList<>();
        for (long l : longs) {
            nameIds.add(l);
            if (randomBoolean()) { // randomly insert additional values from the known list
                int idx = randomIntBetween(0, longs.size() - 1);
                nameIds.add(longs.get(idx));
            }
        }

        var mapper = new PlanStreamInput.NameIdMapper();
        List<NameId> mappedIds = nameIds.stream().map(mapper::apply).toList();
        assertThat(mappedIds.size(), is(nameIds.size()));
        // there must be exactly NUM_IDS distinct elements
        assertThat(mapper.seen.size(), is(NUM_IDS));
        assertThat(mappedIds.stream().distinct().count(), is((long) NUM_IDS));

        // The pre-mapped name id pattern must match that of the mapped one: each stream
        // id must occur at exactly the same positions as its corresponding NameId.
        Map<Long, List<Long>> nameIdsSeen = new LinkedHashMap<>(); // insertion order
        for (int i = 0; i < nameIds.size(); i++) {
            long value = nameIds.get(i);
            nameIdsSeen.computeIfAbsent(value, k -> new ArrayList<>()).add((long) i);
        }
        assertThat(nameIdsSeen.size(), is(NUM_IDS));

        Map<NameId, List<Long>> mappedSeen = new LinkedHashMap<>(); // insertion order
        for (int i = 0; i < mappedIds.size(); i++) {
            mappedSeen.computeIfAbsent(mappedIds.get(i), k -> new ArrayList<>()).add((long) i);
        }
        assertThat(mappedSeen.size(), is(NUM_IDS));

        var mappedSeenItr = mappedSeen.values().iterator();
        for (List<Long> indexes : nameIdsSeen.values()) {
            assertThat(indexes, equalTo(mappedSeenItr.next()));
        }
    }

    /** Returns an unmodifiable list of exactly {@code size} distinct random longs. */
    List<Long> randomLongsListOfSize(int size) {
        Set<Long> longs = new HashSet<>();
        while (longs.size() < size) {
            longs.add(randomLong());
        }
        return longs.stream().toList();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ public NameId() {
this.id = COUNTER.incrementAndGet();
}

public NameId(long id) {
this.id = id;
}

@Override
public int hashCode() {
return Objects.hash(id);
Expand Down

0 comments on commit d3f5584

Please sign in to comment.