Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
fang-xing-esql committed Dec 21, 2024
1 parent a81a81e commit 5fef44d
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ public class EntryExpression extends Expression {
EntryExpression::readFrom
);

static final NamedWriteableRegistry.Entry ENTRY_EXPRESSION_ENTRY = new NamedWriteableRegistry.Entry(
EntryExpression.class,
"EntryExpression",
EntryExpression::readFrom
);

private final Expression key;

private final Expression value;
Expand Down Expand Up @@ -112,6 +106,6 @@ public boolean equals(Object obj) {

@Override
public String toString() {
return key.fold() + ":" + value.fold();
return key.toString() + ":" + value.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ public static List<NamedWriteableRegistry.Entry> attributes() {
}

public static List<NamedWriteableRegistry.Entry> mapExpressions() {
return List.of(EntryExpression.ENTRY_EXPRESSION_ENTRY, EntryExpression.ENTRY, MapExpression.ENTRY);
return List.of(EntryExpression.ENTRY, MapExpression.ENTRY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.xpack.esql.core.util.PlanStreamInput;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
Expand All @@ -37,34 +38,39 @@ public class MapExpression extends Expression implements Map<Expression, Express
MapExpression::readFrom
);

private final List<EntryExpression> entries;
private final List<EntryExpression> entryExpressions;

private final Map<Expression, Expression> map;

private final Map<Object, Expression> foldedMap;
private final Map<Object, Expression> keyFoldedMap;

public MapExpression(Source source, List<EntryExpression> entries) {
super(source, entries.stream().map(Expression.class::cast).toList());
this.entries = entries;
this.map = entries.stream()
public MapExpression(Source source, List<Expression> entries) {
super(source, entries);
int entryCount = entries.size() / 2;
this.entryExpressions = new ArrayList<>(entryCount);
for (int i = 0; i < entryCount; i++) {
Expression key = entries.get(i * 2);
entryExpressions.add(new EntryExpression(key.source(), key, entries.get(i * 2 + 1)));
}
this.map = this.entryExpressions.stream()
.collect(Collectors.toMap(EntryExpression::key, EntryExpression::value, (x, y) -> y, LinkedHashMap::new));
// create a foldedMap by removing source, it makes the retrieval of value easier
this.foldedMap = entries.stream()
// create a map with key folded and source removed to make the retrieval of value easier
this.keyFoldedMap = this.entryExpressions.stream()
.filter(e -> e.key().foldable() && e.key().fold() != null)
.collect(Collectors.toMap(e -> e.key().fold(), EntryExpression::value, (x, y) -> y, LinkedHashMap::new));
}

private static MapExpression readFrom(StreamInput in) throws IOException {
return new MapExpression(
Source.readFrom((StreamInput & PlanStreamInput) in),
in.readNamedWriteableCollectionAsList(EntryExpression.class)
in.readNamedWriteableCollectionAsList(Expression.class)
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteableCollection(entries);
out.writeNamedWriteableCollection(children());
}

@Override
Expand All @@ -74,26 +80,22 @@ public String getWriteableName() {

@Override
public MapExpression replaceChildren(List<Expression> newChildren) {
return new MapExpression(source(), newChildren.stream().map(EntryExpression.class::cast).toList());
return new MapExpression(source(), newChildren);
}

@Override
protected NodeInfo<MapExpression> info() {
return NodeInfo.create(this, MapExpression::new, entries());
return NodeInfo.create(this, MapExpression::new, children());
}

public List<EntryExpression> entries() {
return entries;
public List<EntryExpression> entryExpressions() {
return entryExpressions;
}

public Map<Expression, Expression> map() {
return map;
}

public Map<Object, Expression> foldedMap() {
return foldedMap;
}

@Override
public Nullability nullable() {
return Nullability.FALSE;
Expand All @@ -106,7 +108,7 @@ public DataType dataType() {

@Override
public int hashCode() {
return Objects.hash(entries);
return Objects.hash(entryExpressions);
}

@Override
Expand All @@ -128,8 +130,8 @@ public Expression get(Object key) {
return map.get(key);
} else {
key = key instanceof String s ? s.toLowerCase(Locale.ROOT) : key;
// the literal key could be converted to BytesRef by ConvertStringToByteRef
return foldedMap.containsKey(key) ? foldedMap.get(key) : foldedMap.get(new BytesRef(key.toString()));
// the key(literal) could be converted to BytesRef by ConvertStringToByteRef
return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(new BytesRef(key.toString()));
}
}

Expand Down Expand Up @@ -188,12 +190,12 @@ public boolean equals(Object obj) {
}

MapExpression other = (MapExpression) obj;
return Objects.equals(entries, other.entries);
return Objects.equals(entryExpressions, other.entryExpressions);
}

@Override
public String toString() {
String str = entries.stream().map(String::valueOf).collect(Collectors.joining(", "));
String str = entryExpressions.stream().map(String::valueOf).collect(Collectors.joining(", "));
return "{ " + str + " }";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ final <E> T transform(Function<? super E, ? extends E> rule, Class<E> typeToken)
List<?> children = node.children();

Function<Object, Object> realRule = p -> {
if (false == children.equals(p) && false == children.contains(p) && (p == null || typeToken.isInstance(p))) {
if (p != children && false == children.contains(p) && (p == null || typeToken.isInstance(p))) {
return rule.apply(typeToken.cast(p));
}
return p;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public Expression base() {
}

private TypeResolution validateOptions() {
for (EntryExpression entry : ((MapExpression) map).entries()) {
for (EntryExpression entry : ((MapExpression) map).entryExpressions()) {
Expression key = entry.key();
Expression value = entry.value();
TypeResolution resolution = isFoldable(key, sourceText(), SECOND).and(isFoldable(value, sourceText(), SECOND));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ protected NodeInfo<? extends Expression> info() {
@Override
public Object fold() {
if (map instanceof MapExpression me) {
return (long) me.entries().size();
return (long) me.entryExpressions().size();
} else {
throw new IllegalArgumentException(
LoggerMessageFormat.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.EntryExpression;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
Expand Down Expand Up @@ -622,7 +621,7 @@ public String visitFunctionName(EsqlBaseParser.FunctionNameContext ctx) {

@Override
public MapExpression visitFunctionArgumentWithName(EsqlBaseParser.FunctionArgumentWithNameContext ctx) {
List<EntryExpression> namedArgs = new ArrayList<>(ctx.mapExpression().entryExpression().size());
List<Expression> namedArgs = new ArrayList<>(ctx.mapExpression().entryExpression().size());
List<EsqlBaseParser.EntryExpressionContext> kvCtx = ctx.mapExpression().entryExpression();
for (EsqlBaseParser.EntryExpressionContext entry : kvCtx) {
String key = visitString(entry.string()).fold().toString().toLowerCase(Locale.ROOT); // make key case-insensitive
Expand All @@ -631,8 +630,8 @@ public MapExpression visitFunctionArgumentWithName(EsqlBaseParser.FunctionArgume
if (l.dataType() == NULL) {
throw new ParsingException(source(ctx), "Invalid named function argument [{}], NULL is not supported", l);
}
EntryExpression ee = new EntryExpression(Source.EMPTY, new Literal(source(entry.string()), key, KEYWORD), l);
namedArgs.add(ee);
namedArgs.add(new Literal(source(entry.string()), key, KEYWORD));
namedArgs.add(l);
} else {
throw new ParsingException(source(ctx), "Invalid named function argument [{}], only constant value is supported", value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2628,8 +2628,8 @@ public void testMapExpressionAsFunctionArgument() {
a = as(eval.fields().get(0), Alias.class);
LogWithBaseInMap l = as(a.child(), LogWithBaseInMap.class);
me = as(l.base(), MapExpression.class);
assertEquals(1, me.entries().size());
EntryExpression ee = as(me.entries().get(0), EntryExpression.class);
assertEquals(1, me.entryExpressions().size());
EntryExpression ee = as(me.entryExpressions().get(0), EntryExpression.class);
assertEquals(new Literal(EMPTY, "base", DataType.KEYWORD), ee.key());
assertEquals(new Literal(EMPTY, 2.0, DataType.DOUBLE), ee.value());
assertEquals(DataType.DOUBLE, ee.dataType());
Expand All @@ -2655,13 +2655,13 @@ private void verifyMapExpression(MapExpression me) {
Literal option2 = new Literal(EMPTY, "option2", DataType.KEYWORD);
Literal value2 = new Literal(EMPTY, List.of(1, 2, 3), DataType.INTEGER);

assertEquals(2, me.entries().size());
EntryExpression ee = as(me.entries().get(0), EntryExpression.class);
assertEquals(2, me.entryExpressions().size());
EntryExpression ee = as(me.entryExpressions().get(0), EntryExpression.class);
assertEquals(option1, ee.key());
assertEquals(value1, ee.value());
assertEquals(value1.dataType(), ee.dataType());

ee = as(me.entries().get(1), EntryExpression.class);
ee = as(me.entryExpressions().get(1), EntryExpression.class);
assertEquals(option2, ee.key());
assertEquals(value2, ee.value());
assertEquals(value2.dataType(), ee.dataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.xpack.esql.core.expression.EntryExpression;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
Expand Down Expand Up @@ -42,26 +41,27 @@ public static Iterable<Object[]> parameters() {
ints(suppliers);
longs(suppliers);
doubles(suppliers);
// Add null cases before the rest of the error cases, so messages are correct.
// suppliers = anyNullIsNull(true, suppliers);
// Negative cases
suppliers = anyNullIsNull(true, suppliers);

// return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
return parameterSuppliersFromTypedData(suppliers);
}

private static void ints(List<TestCaseSupplier> suppliers) {
TestCaseSupplier supplier = new TestCaseSupplier(List.of(INTEGER), () -> {
int number = randomIntBetween(2, 100);
int base = randomIntBetween(2, 100);
EntryExpression entry = new EntryExpression(
Source.EMPTY,
new Literal(Source.EMPTY, "base", KEYWORD),
new Literal(Source.EMPTY, base, INTEGER)
);
List<TestCaseSupplier.TypedData> values = new ArrayList<>();
values.add(new TestCaseSupplier.TypedData(number, INTEGER, "number"));
values.add(new TestCaseSupplier.TypedData(new MapExpression(Source.EMPTY, List.of(entry)), UNSUPPORTED, "base").forceLiteral());
values.add(
new TestCaseSupplier.TypedData(
new MapExpression(
Source.EMPTY,
List.of(new Literal(Source.EMPTY, "base", KEYWORD), new Literal(Source.EMPTY, base, INTEGER))
),
UNSUPPORTED,
"base"
).forceLiteral()
);
return new TestCaseSupplier.TestCase(
values,
"LogWithBaseInMapEvaluator[value=CastIntToDoubleEvaluator[v=Attribute[channel=0]], base=" + (double) base + "]",
Expand All @@ -76,14 +76,18 @@ private static void longs(List<TestCaseSupplier> suppliers) {
TestCaseSupplier supplier = new TestCaseSupplier(List.of(LONG), () -> {
long number = randomLongBetween(2L, 100L);
long base = randomLongBetween(2L, 100L);
EntryExpression entry = new EntryExpression(
Source.EMPTY,
new Literal(Source.EMPTY, "base", KEYWORD),
new Literal(Source.EMPTY, base, LONG)
);
List<TestCaseSupplier.TypedData> values = new ArrayList<>();
values.add(new TestCaseSupplier.TypedData(number, LONG, "number"));
values.add(new TestCaseSupplier.TypedData(new MapExpression(Source.EMPTY, List.of(entry)), UNSUPPORTED, "base").forceLiteral());
values.add(
new TestCaseSupplier.TypedData(
new MapExpression(
Source.EMPTY,
List.of(new Literal(Source.EMPTY, "base", KEYWORD), new Literal(Source.EMPTY, base, LONG))
),
UNSUPPORTED,
"base"
).forceLiteral()
);
return new TestCaseSupplier.TestCase(
values,
"LogWithBaseInMapEvaluator[value=CastLongToDoubleEvaluator[v=Attribute[channel=0]], base=" + (double) base + "]",
Expand All @@ -98,14 +102,18 @@ private static void doubles(List<TestCaseSupplier> suppliers) {
TestCaseSupplier supplier = new TestCaseSupplier(List.of(DOUBLE), () -> {
double number = Maths.round(randomDoubleBetween(2d, 100d, true), 2).doubleValue();
double base = Maths.round(randomDoubleBetween(2d, 100d, true), 2).doubleValue();
EntryExpression entry = new EntryExpression(
Source.EMPTY,
new Literal(Source.EMPTY, "base", KEYWORD),
new Literal(Source.EMPTY, base, DOUBLE)
);
List<TestCaseSupplier.TypedData> values = new ArrayList<>();
values.add(new TestCaseSupplier.TypedData(number, DOUBLE, "number"));
values.add(new TestCaseSupplier.TypedData(new MapExpression(Source.EMPTY, List.of(entry)), UNSUPPORTED, "base").forceLiteral());
values.add(
new TestCaseSupplier.TypedData(
new MapExpression(
Source.EMPTY,
List.of(new Literal(Source.EMPTY, "base", KEYWORD), new Literal(Source.EMPTY, base, DOUBLE))
),
UNSUPPORTED,
"base"
).forceLiteral()
);
return new TestCaseSupplier.TestCase(
values,
"LogWithBaseInMapEvaluator[value=Attribute[channel=0], base=" + base + "]",
Expand All @@ -118,17 +126,6 @@ private static void doubles(List<TestCaseSupplier> suppliers) {

@Override
protected Expression build(Source source, List<Expression> args) {
// return new LogWithBaseInMap(source, args.get(0), randomBoolean() ? randomMapExpression() : null);
return new LogWithBaseInMap(source, args.get(0), args.size() > 1 ? args.get(1) : null);
}

private static MapExpression randomMapExpression() {
double base = randomDoubleBetween(2d, Double.MAX_VALUE, true);
EntryExpression entry = new EntryExpression(
Source.EMPTY,
new Literal(Source.EMPTY, "base", KEYWORD),
new Literal(Source.EMPTY, base, DOUBLE)
);
return new MapExpression(Source.EMPTY, List.of(entry));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6756,8 +6756,8 @@ public void testMapExpressionAsFunctionArgument() {
Alias a = as(eval.fields().get(0), Alias.class);
LogWithBaseInMap l = as(a.child(), LogWithBaseInMap.class);
MapExpression me = as(l.base(), MapExpression.class);
assertEquals(1, me.entries().size());
EntryExpression ee = as(me.entries().get(0), EntryExpression.class);
assertEquals(1, me.entryExpressions().size());
EntryExpression ee = as(me.entryExpressions().get(0), EntryExpression.class);
BytesRef key = as(ee.key().fold(), BytesRef.class);
assertEquals("base", key.utf8ToString());
assertEquals(new Literal(EMPTY, 2.0, DataType.DOUBLE), ee.value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.core.expression.EntryExpression;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
Expand Down Expand Up @@ -129,13 +128,13 @@ static Literal literalStrings(String... strings) {
}

static MapExpression mapExpression(Map<String, Object> keyValuePairs) {
List<EntryExpression> ees = new ArrayList<>(keyValuePairs.size());
List<Expression> ees = new ArrayList<>(keyValuePairs.size());
for (Map.Entry<String, Object> entry : keyValuePairs.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
DataType type = (value instanceof List<?> l) ? DataType.fromJava(l.get(0)) : DataType.fromJava(value);
EntryExpression ee = new EntryExpression(EMPTY, new Literal(EMPTY, key, DataType.KEYWORD), new Literal(EMPTY, value, type));
ees.add(ee);
ees.add(new Literal(EMPTY, key, DataType.KEYWORD));
ees.add(new Literal(EMPTY, value, type));
}
return new MapExpression(EMPTY, ees);
}
Expand Down

0 comments on commit 5fef44d

Please sign in to comment.