Skip to content

Commit

Permalink
Add a test function for MapExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
fang-xing-esql committed Dec 17, 2024
1 parent dae1ae4 commit 42dd838
Show file tree
Hide file tree
Showing 14 changed files with 502 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,30 @@ public class EntryExpression extends Expression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"EntryExpression",
EntryExpression::new
EntryExpression::readFrom
);

private final Literal key;
static final NamedWriteableRegistry.Entry ENTRY_EXPRESSION_ENTRY = new NamedWriteableRegistry.Entry(
EntryExpression.class,
"EntryExpression",
EntryExpression::readFrom
);

private final Expression key;

private final Literal value;
private final Expression value;

public EntryExpression(Source source, Literal key, Literal value) {
public EntryExpression(Source source, Expression key, Expression value) {
super(source, List.of(key, value));
this.key = key;
this.value = value;
}

private EntryExpression(StreamInput in) throws IOException {
this(
private static EntryExpression readFrom(StreamInput in) throws IOException {
return new EntryExpression(
Source.readFrom((StreamInput & PlanStreamInput) in),
in.readNamedWriteable(Literal.class),
in.readNamedWriteable(Literal.class)
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class)
);
}

Expand All @@ -60,7 +66,7 @@ public String getWriteableName() {

@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new EntryExpression(source(), (Literal) newChildren.get(0), (Literal) newChildren.get(1));
return new EntryExpression(source(), newChildren.get(0), newChildren.get(1));
}

@Override
Expand All @@ -81,16 +87,6 @@ public DataType dataType() {
return value.dataType();
}

@Override
public boolean foldable() {
return key.foldable() && value.foldable();
}

@Override
public Object fold() {
return toString();
}

@Override
public Nullability nullable() {
return Nullability.FALSE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.addAll(expressions());
entries.addAll(namedExpressions());
entries.addAll(attributes());
entries.addAll(mapExpressions());
return entries;
}

Expand All @@ -30,6 +29,7 @@ public static List<NamedWriteableRegistry.Entry> expressions() {
entries.add(new NamedWriteableRegistry.Entry(Expression.class, e.name, in -> (Expression) e.reader.read(in)));
}
entries.add(Literal.ENTRY);
entries.addAll(mapExpressions());
return entries;
}

Expand All @@ -48,6 +48,6 @@ public static List<NamedWriteableRegistry.Entry> attributes() {
}

public static List<NamedWriteableRegistry.Entry> mapExpressions() {
return List.of(EntryExpression.ENTRY, MapExpression.ENTRY);
return List.of(EntryExpression.ENTRY_EXPRESSION_ENTRY, EntryExpression.ENTRY, MapExpression.ENTRY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.esql.core.expression;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -30,7 +31,7 @@ public class MapExpression extends Expression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"MapExpression",
MapExpression::new
MapExpression::readFrom
);

private final List<EntryExpression> entries;
Expand All @@ -44,14 +45,17 @@ public MapExpression(Source source, List<EntryExpression> entries) {
.collect(Collectors.toMap(EntryExpression::key, EntryExpression::value, (x, y) -> y, LinkedHashMap::new));
}

private MapExpression(StreamInput in) throws IOException {
this(Source.readFrom((StreamInput & PlanStreamInput) in), in.readNamedWriteableCollectionAsList(EntryExpression.class));
private static MapExpression readFrom(StreamInput in) throws IOException {
return new MapExpression(
Source.readFrom((StreamInput & PlanStreamInput) in),
in.readNamedWriteableCollectionAsList(EntryExpression.class)
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteableCollection(children());
out.writeNamedWriteableCollection(entries);
}

@Override
Expand All @@ -77,24 +81,25 @@ public Map<Expression, Expression> map() {
return map;
}

@Override
public DataType dataType() {
return UNSUPPORTED;
}

@Override
public boolean foldable() {
for (EntryExpression ee : entries) {
if (ee.foldable() == false) {
return false;
public Expression getKey(String key) {
for (EntryExpression entry : entries) {
Expression k = entry.key();
if (k.foldable()) {
Object o = k.fold();
if (o instanceof BytesRef br) {
o = br.utf8ToString();
}
if (o.toString().equalsIgnoreCase(key)) {
return entry.value();
}
}
}
return true;
return null;
}

@Override
public Object fold() {
return map();
public DataType dataType() {
return UNSUPPORTED;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,61 +50,62 @@ c:long
1
;

mapKeysEvalIndex
logWithBaseInMapEval
required_capability: optional_named_argument_map_for_function
FROM employees
| EVAL k = map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[3.0, 4.0]})
| WHERE emp_no == 10001
| KEEP emp_no, k
ROW value = 8.0
| EVAL l = log_with_base_in_map(value, {"base":2.0})
;

emp_no:integer |k:keyword
10001 |"option1, option2, option3, option4, option5"
value: double |l:double
8.0 |3.0
;

mapKeysWhereTrueIndex
logWithBaseInMapEvalIndex
required_capability: optional_named_argument_map_for_function
FROM employees
| WHERE map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[1, 2]}) like "option*"
AND emp_no == 10001
| KEEP emp_no
| WHERE emp_no IN (10001, 10003)
| EVAL l = log_with_base_in_map(languages, {"base":2.0})
| KEEP emp_no, languages, l
| SORT emp_no
;

emp_no:integer
10001
emp_no:integer |languages:integer |l:double
10001 |2 |1.0
10003 |4 |2.0
;

mapKeysWhereFalseIndex
logWithBaseInMapWhereTrueIndex
required_capability: optional_named_argument_map_for_function
FROM employees
| WHERE map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":["a", "b"]}) like "false*"
AND emp_no == 10001
| KEEP emp_no
| WHERE emp_no IN (10001, 10003) AND log_with_base_in_map(languages, {"base":2.0}) > 1
| KEEP emp_no, languages
| SORT emp_no
;

emp_no:integer
emp_no:integer |languages:integer
10003 |4
;

mapKeysSortIndex
logWithBaseInMapWhereFalseIndex
required_capability: optional_named_argument_map_for_function
FROM employees
| SORT emp_no, map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[true, false]}) desc
| KEEP emp_no
| LIMIT 2
| WHERE emp_no IN (10001, 10003) AND log_with_base_in_map(languages, {"base":2.0}) < 0
| KEEP emp_no, languages
| SORT emp_no
;

emp_no:integer
10001
10002
emp_no:integer |languages:integer
;

mapKeysStatsIndex
logWithBaseInMapSortIndex
required_capability: optional_named_argument_map_for_function
FROM employees
| STATS c = count(*) BY map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[3.0, 4.0]})
| KEEP c
| WHERE emp_no IN (10001, 10003)
| SORT log_with_base_in_map(languages, {"base":2.0}) desc
| KEEP emp_no
;

c:long
100
emp_no:integer
10003
10001
;

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 42dd838

Please sign in to comment.