From 42dd838831360cae78a43b286d8c10cc1a57c90c Mon Sep 17 00:00:00 2001 From: Fang Xing Date: Mon, 16 Dec 2024 20:48:47 -0500 Subject: [PATCH] Add a test function for MapExpression --- .../esql/core/expression/EntryExpression.java | 34 ++- .../expression/ExpressionCoreWritables.java | 4 +- .../esql/core/expression/MapExpression.java | 39 ++-- .../src/main/resources/map-functions.csv-spec | 61 +++--- .../scalar/map/LogWithBaseInMapEvaluator.java | 139 ++++++++++++ .../esql/expression/ExpressionWritables.java | 2 + .../function/EsqlFunctionRegistry.java | 8 +- .../esql/expression/function/MapParam.java | 10 +- .../function/scalar/map/LogWithBaseInMap.java | 204 ++++++++++++++++++ .../function/scalar/map/MapCount.java | 21 +- .../function/scalar/map/MapKeys.java | 123 ----------- .../xpack/esql/analysis/AnalyzerTests.java | 62 +++--- .../function/AbstractFunctionTestCase.java | 2 +- .../optimizer/LogicalPlanOptimizerTests.java | 33 +++ 14 files changed, 502 insertions(+), 240 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMapEvaluator.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMap.java delete mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapKeys.java diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EntryExpression.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EntryExpression.java index 44efa5141d8ae..012ba869a69c0 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EntryExpression.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EntryExpression.java @@ -25,24 +25,30 @@ public class EntryExpression extends Expression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "EntryExpression", - EntryExpression::new + EntryExpression::readFrom ); - private final Literal key; + static final NamedWriteableRegistry.Entry ENTRY_EXPRESSION_ENTRY = new NamedWriteableRegistry.Entry( + EntryExpression.class, + "EntryExpression", + EntryExpression::readFrom + ); + + private final Expression key; - private final Literal value; + private final Expression value; - public EntryExpression(Source source, Literal key, Literal value) { + public EntryExpression(Source source, Expression key, Expression value) { super(source, List.of(key, value)); this.key = key; this.value = value; } - private EntryExpression(StreamInput in) throws IOException { - this( + private static EntryExpression readFrom(StreamInput in) throws IOException { + return new EntryExpression( Source.readFrom((StreamInput & PlanStreamInput) in), - in.readNamedWriteable(Literal.class), - in.readNamedWriteable(Literal.class) + in.readNamedWriteable(Expression.class), + in.readNamedWriteable(Expression.class) ); } @@ -60,7 +66,7 @@ public String getWriteableName() { @Override public Expression replaceChildren(List newChildren) { - return new EntryExpression(source(), (Literal) newChildren.get(0), (Literal) newChildren.get(1)); + return new EntryExpression(source(), newChildren.get(0), newChildren.get(1)); } @Override @@ -81,16 +87,6 @@ public DataType dataType() { return value.dataType(); } - @Override - public boolean foldable() { - return key.foldable() && value.foldable(); - } - - @Override - public Object fold() { - return toString(); - } - @Override public Nullability nullable() { return Nullability.FALSE; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ExpressionCoreWritables.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ExpressionCoreWritables.java index 19a9b460a405f..1a6467175dbef 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ExpressionCoreWritables.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ExpressionCoreWritables.java @@ -19,7 +19,6 @@ public static List getNamedWriteables() { entries.addAll(expressions()); entries.addAll(namedExpressions()); entries.addAll(attributes()); - entries.addAll(mapExpressions()); return entries; } @@ -30,6 +29,7 @@ public static List expressions() { entries.add(new NamedWriteableRegistry.Entry(Expression.class, e.name, in -> (Expression) e.reader.read(in))); } entries.add(Literal.ENTRY); + entries.addAll(mapExpressions()); return entries; } @@ -48,6 +48,6 @@ public static List attributes() { } public static List mapExpressions() { - return List.of(EntryExpression.ENTRY, MapExpression.ENTRY); + return List.of(EntryExpression.ENTRY_EXPRESSION_ENTRY, EntryExpression.ENTRY, MapExpression.ENTRY); } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java index eb0dfa86318db..81966ea0ed5ba 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -30,7 +31,7 @@ public class MapExpression extends Expression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "MapExpression", - MapExpression::new + MapExpression::readFrom ); private final List entries; @@ -44,14 +45,17 @@ public MapExpression(Source source, List entries) { .collect(Collectors.toMap(EntryExpression::key, EntryExpression::value, (x, y) -> y, LinkedHashMap::new)); } - private MapExpression(StreamInput in) throws IOException { - this(Source.readFrom((StreamInput & PlanStreamInput) in), in.readNamedWriteableCollectionAsList(EntryExpression.class)); + private static MapExpression readFrom(StreamInput in) throws IOException { + return new MapExpression( + Source.readFrom((StreamInput & PlanStreamInput) in), + in.readNamedWriteableCollectionAsList(EntryExpression.class) + ); } @Override public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); - out.writeNamedWriteableCollection(children()); + out.writeNamedWriteableCollection(entries); } @Override @@ -77,24 +81,25 @@ public Map map() { return map; } - @Override - public DataType dataType() { - return UNSUPPORTED; - } - - @Override - public boolean foldable() { - for (EntryExpression ee : entries) { - if (ee.foldable() == false) { - return false; + public Expression getKey(String key) { + for (EntryExpression entry : entries) { + Expression k = entry.key(); + if (k.foldable()) { + Object o = k.fold(); + if (o instanceof BytesRef br) { + o = br.utf8ToString(); + } + if (o.toString().equalsIgnoreCase(key)) { + return entry.value(); + } } } - return true; + return null; } @Override - public Object fold() { - return map(); + public DataType dataType() { + return UNSUPPORTED; } @Override diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/map-functions.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/map-functions.csv-spec index a729b9a48dbea..ed8ca66633ef4 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/map-functions.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/map-functions.csv-spec @@ -50,61 +50,62 @@ c:long 1 ; -mapKeysEvalIndex +logWithBaseInMapEval required_capability: optional_named_argument_map_for_function -FROM employees -| EVAL k = map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[3.0, 4.0]}) -| WHERE emp_no == 10001 -| KEEP emp_no, k +ROW value = 8.0 +| EVAL l = log_with_base_in_map(value, {"base":2.0}) ; -emp_no:integer |k:keyword -10001 |"option1, option2, option3, option4, option5" +value: double |l:double +8.0 |3.0 ; -mapKeysWhereTrueIndex +logWithBaseInMapEvalIndex required_capability: optional_named_argument_map_for_function FROM employees -| WHERE map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[1, 2]}) like "option*" - AND emp_no == 10001 -| KEEP emp_no +| WHERE emp_no IN (10001, 10003) +| EVAL l = log_with_base_in_map(languages, {"base":2.0}) +| KEEP emp_no, languages, l +| SORT emp_no ; -emp_no:integer -10001 +emp_no:integer |languages:integer |l:double +10001 |2 |1.0 +10003 |4 |2.0 ; -mapKeysWhereFalseIndex +logWithBaseInMapWhereTrueIndex required_capability: optional_named_argument_map_for_function FROM employees -| WHERE map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":["a", "b"]}) like "false*" - AND emp_no == 10001 -| KEEP emp_no +| WHERE emp_no IN (10001, 10003) AND log_with_base_in_map(languages, {"base":2.0}) > 1 +| KEEP emp_no, languages +| SORT emp_no ; -emp_no:integer +emp_no:integer |languages:integer +10003 |4 ; -mapKeysSortIndex +logWithBaseInMapWhereFalseIndex required_capability: optional_named_argument_map_for_function FROM employees -| SORT emp_no, map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[true, false]}) desc -| KEEP emp_no -| LIMIT 2 +| WHERE emp_no IN (10001, 10003) AND log_with_base_in_map(languages, {"base":2.0}) < 0 +| KEEP emp_no, languages +| SORT emp_no ; -emp_no:integer -10001 -10002 +emp_no:integer |languages:integer ; -mapKeysStatsIndex +logWithBaseInMapSortIndex required_capability: optional_named_argument_map_for_function FROM employees -| STATS c = count(*) BY map_keys({"option1":"value1", "option2":2, "option3":3.0, "option4":true, "option5":[3.0, 4.0]}) -| KEEP c +| WHERE emp_no IN (10001, 10003) +| SORT log_with_base_in_map(languages, {"base":2.0}) desc +| KEEP emp_no ; -c:long -100 +emp_no:integer +10003 +10001 ; diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMapEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMapEvaluator.java new file mode 100644 index 0000000000000..11c28c2a1f692 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMapEvaluator.java @@ -0,0 +1,139 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.map; + +import java.lang.ArithmeticException; +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link LogWithBaseInMap}. + * This class is generated. Do not edit it. + */ +public final class LogWithBaseInMapEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator value; + + private final double base; + + private final DriverContext driverContext; + + private Warnings warnings; + + public LogWithBaseInMapEvaluator(Source source, EvalOperator.ExpressionEvaluator value, + double base, DriverContext driverContext) { + this.source = source; + this.value = value; + this.base = base; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (DoubleBlock valueBlock = (DoubleBlock) value.eval(page)) { + DoubleVector valueVector = valueBlock.asVector(); + if (valueVector == null) { + return eval(page.getPositionCount(), valueBlock); + } + return eval(page.getPositionCount(), valueVector); + } + } + + public DoubleBlock eval(int positionCount, DoubleBlock valueBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (valueBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (valueBlock.getValueCount(p) != 1) { + if (valueBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + try { + result.appendDouble(LogWithBaseInMap.process(valueBlock.getDouble(valueBlock.getFirstValueIndex(p)), this.base)); + } catch (ArithmeticException e) { + warnings().registerException(e); + result.appendNull(); + } + } + return result.build(); + } + } + + public DoubleBlock eval(int positionCount, DoubleVector valueVector) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + try { + result.appendDouble(LogWithBaseInMap.process(valueVector.getDouble(p), this.base)); + } catch (ArithmeticException e) { + warnings().registerException(e); + result.appendNull(); + } + } + return result.build(); + } + } + + @Override + public String toString() { + return "LogWithBaseInMapEvaluator[" + "value=" + value + ", base=" + base + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(value); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory value; + + private final double base; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory value, double base) { + this.source = source; + this.value = value; + this.base = base; + } + + @Override + public LogWithBaseInMapEvaluator get(DriverContext context) { + return new LogWithBaseInMapEvaluator(source, value.get(context), base, context); + } + + @Override + public String toString() { + return "LogWithBaseInMapEvaluator[" + "value=" + value + ", base=" + base + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index febeccdad9d78..98032ee48be56 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToVersion; +import org.elasticsearch.xpack.esql.expression.function.scalar.map.LogWithBaseInMap; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Acos; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Asin; @@ -161,6 +162,7 @@ public static List unaryScalars() { entries.add(IsNull.ENTRY); entries.add(Length.ENTRY); entries.add(Log10.ENTRY); + entries.add(LogWithBaseInMap.ENTRY); entries.add(LTrim.ENTRY); entries.add(Neg.ENTRY); entries.add(Not.ENTRY); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index bcd7f6693152e..3d840b2a9e3b4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -71,8 +71,8 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.date.Now; import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; import org.elasticsearch.xpack.esql.expression.function.scalar.ip.IpPrefix; +import org.elasticsearch.xpack.esql.expression.function.scalar.map.LogWithBaseInMap; import org.elasticsearch.xpack.esql.expression.function.scalar.map.MapCount; -import org.elasticsearch.xpack.esql.expression.function.scalar.map.MapKeys; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Abs; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Acos; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Asin; @@ -429,8 +429,10 @@ private static FunctionDefinition[][] snapshotFunctions() { // This is an experimental function and can be removed without notice. def(Delay.class, Delay::new, "delay"), def(Kql.class, uni(Kql::new), "kql"), + // The map_count and log_with_base_in_map are for debug/snapshot environments only + // and should never be enabled in a non-snapshot build. They are for the purpose of testing MapExpression only. def(MapCount.class, MapCount::new, "map_count"), - def(MapKeys.class, MapKeys::new, "map_keys"), + def(LogWithBaseInMap.class, LogWithBaseInMap::new, "log_with_base_in_map"), def(Rate.class, Rate::withUnresolvedTimestamp, "rate"), def(Term.class, bi(Term::new), "term") } }; } @@ -549,7 +551,7 @@ public static FunctionDescription description(FunctionDefinition def) { MapParam mapParamInfo = params[i].getAnnotation(MapParam.class); // refactor this if (mapParamInfo != null) { String name = mapParamInfo == null ? params[i].getName() : mapParamInfo.name(); - String[] valueType = mapParamInfo == null ? new String[] { "?" } : removeUnderConstruction(mapParamInfo.valueType()); + String[] valueType = mapParamInfo == null ? new String[] { "?" } : removeUnderConstruction(mapParamInfo.type()); String desc = mapParamInfo == null ? "" : mapParamInfo.description().replace('\n', ' '); boolean optional = mapParamInfo == null ? false : mapParamInfo.optional(); DataType targetDataType = getTargetType(valueType); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/MapParam.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/MapParam.java index 279c9617d5d07..ea86ec8b6a326 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/MapParam.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/MapParam.java @@ -20,9 +20,7 @@ public @interface MapParam { String name(); - String keyType() default "keyword"; - - String[] valueType() default { "keyword", "integer", "double", "boolean" }; + String[] type() default "map"; MapEntry[] paramHint() default {}; @@ -33,8 +31,10 @@ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.PARAMETER) @interface MapEntry { - String key(); + String key() default ""; + + String[] type() default {}; - String value(); + String[] value() default {}; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMap.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMap.java new file mode 100644 index 0000000000000..618622a9a477c --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/LogWithBaseInMap.java @@ -0,0 +1,204 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.map; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.ann.Fixed; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.xpack.esql.core.expression.EntryExpression; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.MapParam; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cast; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNumeric; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; + +public class LogWithBaseInMap extends EsqlScalarFunction implements OptionalArgument { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "LogWithBaseInMap", + LogWithBaseInMap::new + ); + + private final Expression number; + + private final Expression map; + + private static final String BASE = "base"; + + @FunctionInfo( + returnType = "double", + description = "Returns the logarithm of a value to a base. The input can be any numeric value, " + + "the return value is always a double.\n" + + "\n" + + "Logs of zero, negative numbers, and base of one return `null` as well as a warning." + ) + public LogWithBaseInMap( + Source source, + @Param( + name = "number", + type = { "double", "integer", "long", "unsigned_long" }, + description = "Numeric expression. If `null`, the function returns `null`." + ) Expression number, + @MapParam( + name = "map", + paramHint = { @MapParam.MapEntry(key = "base", type = { "double", "integer", "long", "unsigned_long" }) }, + description = "Input value. The input is a valid constant map expression.", + optional = true + ) Expression option + ) { + super(source, option == null ? Collections.singletonList(number) : List.of(number, option)); + this.number = number; + this.map = option; + } + + private LogWithBaseInMap(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readNamedWriteable(Expression.class), + in.readOptionalNamedWriteable(Expression.class) + ); + } + + @Override + public final void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeNamedWriteable(number); + out.writeOptionalNamedWriteable(map); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + // validate field type + TypeResolution resolution = isNumeric(number, sourceText(), FIRST); + if (resolution.unresolved()) { + return resolution; + } + + if (map != null) { + // MapExpression does not have a DataType associated with it + return isMapExpression(map, sourceText(), SECOND).and(validateOptions()); + } + return TypeResolution.TYPE_RESOLVED; + } + + @Override + public DataType dataType() { + return DOUBLE; + } + + @Override + public boolean foldable() { + return number.foldable(); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new LogWithBaseInMap(source(), newChildren.get(0), newChildren.size() > 1 ? newChildren.get(1) : null); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, LogWithBaseInMap::new, number, map); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { + var valueEval = Cast.cast(source(), number.dataType(), DataType.DOUBLE, toEvaluator.apply(number)); + double base = Math.E; + if (map instanceof MapExpression me) { + Expression b = me.getKey(BASE); + if (b != null && b.foldable()) { + Object v = b.fold(); + if (v instanceof BytesRef br) { + v = br.utf8ToString(); + } + base = Double.parseDouble(v.toString()); + } + } + return new LogWithBaseInMapEvaluator.Factory(source(), valueEval, base); + } + + @Evaluator(warnExceptions = { ArithmeticException.class }) + static double process(double value, @Fixed double base) throws ArithmeticException { + if (base <= 0d || value <= 0d) { + throw new ArithmeticException("Log of non-positive number"); + } + if (base == 1d) { + throw new ArithmeticException("Log of base 1"); + } + return Math.log10(value) / Math.log10(base); + } + + public Expression number() { + return number; + } + + public Expression map() { + return map; + } + + private TypeResolution validateOptions() { + for (EntryExpression entry : ((MapExpression) map).entries()) { + Expression key = entry.key(); + Expression value = entry.value(); + TypeResolution resolution = isFoldable(key, sourceText(), SECOND).and(isFoldable(value, sourceText(), SECOND)); + if (resolution.unresolved()) { + return resolution; + } + Object k = key.fold(); + Object v = value.fold(); + String base = k instanceof BytesRef br ? br.utf8ToString() : k.toString(); + String number = v instanceof BytesRef br ? br.utf8ToString() : v.toString(); + // validate the key is in SUPPORTED_OPTIONS + if (base.equalsIgnoreCase(BASE) == false) { + return new TypeResolution(format(null, "Invalid option key in [{}], expected base but got [{}]", sourceText(), key)); + } + // validate the value is valid for the key provided + try { + Double.parseDouble(number); + } catch (NumberFormatException e) { + return new TypeResolution( + format(null, "Invalid option value in [{}], expected a numeric number but got [{}]", sourceText(), v) + ); + } + + } + return TypeResolution.TYPE_RESOLVED; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapCount.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapCount.java index 46abdedba8fad..d6723fc2e8b03 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapCount.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapCount.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.map; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -19,20 +17,16 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.MapParam; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import java.io.IOException; import java.util.Collections; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; public class MapCount extends ScalarFunction { - public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "MapCount", MapCount::new); - private final Expression map; @FunctionInfo(returnType = "long", description = "Count the number of entries in a map") @@ -48,19 +42,14 @@ public MapCount( this.map = v; } - private MapCount(StreamInput in) throws IOException { - this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class)); - } - @Override public final void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); - out.writeNamedWriteable(map); + throw new UnsupportedOperationException("not serialized"); } @Override public String getWriteableName() { - return ENTRY.name; + throw new UnsupportedOperationException("not serialized"); } @Override @@ -69,7 +58,7 @@ protected TypeResolution resolveType() { return new TypeResolution("Unresolved children"); } // MapExpression does not have a DataType associated with it - return isMapExpression(map, sourceText(), DEFAULT).and(isFoldable(map, sourceText(), DEFAULT)); + return isMapExpression(map, sourceText(), DEFAULT); } @Override @@ -79,7 +68,7 @@ public DataType dataType() { @Override public boolean foldable() { - return map.foldable(); + return true; } @Override @@ -102,7 +91,7 @@ public Object fold() { null, "Invalid format for [{}], expect a map of constant values but got {}", sourceText(), - map.fold() + map.toString() ) ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapKeys.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapKeys.java deleted file mode 100644 index 748d8c4ea8776..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/map/MapKeys.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.expression.function.scalar.map; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.logging.LoggerMessageFormat; -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.MapExpression; -import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; -import org.elasticsearch.xpack.esql.expression.function.MapParam; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; -import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; - -public class MapKeys extends ScalarFunction { - public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "MapKeys", MapKeys::new); - - private final Expression map; - - @FunctionInfo(returnType = "keyword", description = "Return the keys of a map") - public MapKeys( - Source source, - @MapParam( - name = "map", - paramHint = { @MapParam.MapEntry(key = "option1", value = "value1"), @MapParam.MapEntry(key = "option2", value = "value2") }, - description = "Input value. The input is a valid constant map expression." - ) Expression v - ) { - super(source, Collections.singletonList(v)); - this.map = v; - } - - private MapKeys(StreamInput in) throws IOException { - this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class)); - } - - @Override - public final void writeTo(StreamOutput out) throws IOException { - Source.EMPTY.writeTo(out); - out.writeNamedWriteable(map); - } - - @Override - public String getWriteableName() { - return ENTRY.name; - } - - @Override - protected TypeResolution resolveType() { - if (childrenResolved() == false) { - return new TypeResolution("Unresolved children"); - } - // MapExpression does not have a DataType associated with it - return isMapExpression(map, sourceText(), DEFAULT).and(isFoldable(map, sourceText(), DEFAULT)); - } - - @Override - public DataType dataType() { - return KEYWORD; - } - - @Override - public boolean foldable() { - return map.foldable(); - } - - @Override - public Expression replaceChildren(List newChildren) { - return new MapKeys(source(), newChildren.get(0)); - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, MapKeys::new, map); - } - - @Override - public Object fold() { - if (map instanceof MapExpression me) { - List result = new ArrayList<>(me.entries().size()); - for (Expression key : me.map().keySet()) { - if (key.foldable()) { - Object k = key.fold(); - result.add(k instanceof BytesRef b ? b.utf8ToString() : k.toString()); - } - } - return String.join(", ", result); - } else { - throw new IllegalArgumentException( - LoggerMessageFormat.format( - null, - "Invalid format for [{}], expect a map of constant values but got {}", - sourceText(), - map.fold() - ) - ); - } - } - - public Expression map() { - return this.map; - } -} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 7e12e1f961f8a..8c14a21d07dd6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -35,8 +35,8 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; +import org.elasticsearch.xpack.esql.expression.function.scalar.map.LogWithBaseInMap; import org.elasticsearch.xpack.esql.expression.function.scalar.map.MapCount; -import org.elasticsearch.xpack.esql.expression.function.scalar.map.MapKeys; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; import org.elasticsearch.xpack.esql.parser.ParsingException; @@ -2532,46 +2532,60 @@ public void testMapExpressionAsFunctionArgument() { // positive LogicalPlan plan = analyze(""" from test - | EVAL c = map_count({"option1":"value1", "option2":[1,2,3]}), k = map_keys({"option1":"value1", "option2":[1,2,3]}) - | KEEP c, k + | EVAL c = map_count({"option1":"value1", "option2":[1,2,3]}) + | KEEP c """, "mapping-default.json"); var limit = as(plan, Limit.class); var proj = as(limit.child(), EsqlProject.class); List fields = proj.projections(); - assertEquals(2, fields.size()); + assertEquals(1, fields.size()); ReferenceAttribute ra = as(fields.get(0), ReferenceAttribute.class); assertEquals("c", ra.name()); assertEquals(DataType.LONG, ra.dataType()); - ra = as(fields.get(1), ReferenceAttribute.class); - assertEquals("k", ra.name()); - assertEquals(DataType.KEYWORD, ra.dataType()); var eval = as(proj.child(), Eval.class); - assertEquals(2, eval.fields().size()); + assertEquals(1, eval.fields().size()); Alias a = as(eval.fields().get(0), Alias.class); MapCount mc = as(a.child(), MapCount.class); MapExpression me = as(mc.map(), MapExpression.class); verifyMapExpression(me); - a = as(eval.fields().get(1), Alias.class); - MapKeys mk = as(a.child(), MapKeys.class); - me = as(mk.map(), MapExpression.class); - verifyMapExpression(me); var esRelation = as(eval.child(), EsRelation.class); assertEquals(esRelation.index().name(), "test"); + plan = analyze(""" + from test + | EVAL l = log_with_base_in_map(languages, {"base":2.0}) + | KEEP l + """, "mapping-default.json"); + limit = as(plan, Limit.class); + proj = as(limit.child(), EsqlProject.class); + fields = proj.projections(); + assertEquals(1, fields.size()); + ra = as(fields.get(0), ReferenceAttribute.class); + assertEquals("l", ra.name()); + assertEquals(DataType.DOUBLE, ra.dataType()); + eval = as(proj.child(), Eval.class); + assertEquals(1, eval.fields().size()); + a = as(eval.fields().get(0), Alias.class); + LogWithBaseInMap l = as(a.child(), LogWithBaseInMap.class); + me = as(l.map(), MapExpression.class); + assertEquals(1, me.entries().size()); + EntryExpression ee = as(me.entries().get(0), EntryExpression.class); + assertEquals(new Literal(EMPTY, "base", DataType.KEYWORD), ee.key()); + assertEquals(new Literal(EMPTY, 2.0, DataType.DOUBLE), ee.value()); + assertEquals(DataType.DOUBLE, ee.dataType()); + esRelation = as(eval.child(), EsRelation.class); + assertEquals(esRelation.index().name(), "test"); + // negative MapCount and MapKeys do not take fields, alias, non-map constants or null as inputs for (String arg : List.of("1", "emp_no", "x", "null")) { - for (String function : List.of("map_count", "map_keys")) { - Exception e = expectThrows( - VerificationException.class, - () -> analyze("from test | EVAL x = languages, f = " + function + "(" + arg + ")") - ); - assertThat( - e.getMessage(), - containsString( - "line 1:37: argument of [" + function + "(" + arg + ")] must be a map expression, received [" + arg + "]" - ) - ); - } + Exception e = expectThrows( + VerificationException.class, + () -> analyze("from test | EVAL x = languages, f = map_count(" + arg + ")") + ); + assertThat( + e.getMessage(), + containsString("line 1:37: argument of [map_count(" + arg + ")] must be a map expression, received [" + arg + "]") + ); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index 6c06f35ce7b01..b5b8cf2a5c4d3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -1083,7 +1083,7 @@ private static void renderDocsForOperators(String name) throws IOException { MapParam mapParamInfo = params[i].getAnnotation(MapParam.class); if (mapParamInfo != null) { String paramName = mapParamInfo.name(); - String[] valueType = mapParamInfo.valueType(); + String[] valueType = mapParamInfo.type(); String desc = mapParamInfo.description().replace('\n', ' '); boolean optional = mapParamInfo.optional(); args.add(new EsqlFunctionRegistry.ArgSignature(paramName, valueType, desc, optional)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index c4d7b30115c2d..e594213f95b0d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -30,10 +30,12 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; +import org.elasticsearch.xpack.esql.core.expression.EntryExpression; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -65,6 +67,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString; +import org.elasticsearch.xpack.esql.expression.function.scalar.map.LogWithBaseInMap; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount; @@ -6722,4 +6725,34 @@ public void testMatchFunctionIsNotNullable() { containsString("[MATCH] function cannot operate on [text::keyword], which is not a field from an index mapping") ); } + + public void testMapExpressionAsFunctionArgument() { + assumeTrue("MapExpression require snapshot build", EsqlCapabilities.Cap.OPTIONAL_NAMED_ARGUMENT_MAP_FOR_FUNCTION.isEnabled()); + var query = """ + from test + | EVAL l = log_with_base_in_map(languages, {"base":2.0}) + | KEEP l + """; + var plan = optimizedPlan(query); + Project proj = as(plan, EsqlProject.class); + List fields = proj.projections(); + assertEquals(1, fields.size()); + ReferenceAttribute ra = as(fields.get(0), ReferenceAttribute.class); + assertEquals("l", ra.name()); + assertEquals(DataType.DOUBLE, ra.dataType()); + Eval eval = as(proj.child(), Eval.class); + assertEquals(1, eval.fields().size()); + Alias a = as(eval.fields().get(0), Alias.class); + LogWithBaseInMap l = as(a.child(), LogWithBaseInMap.class); + MapExpression me = as(l.map(), MapExpression.class); + assertEquals(1, me.entries().size()); + EntryExpression ee = as(me.entries().get(0), EntryExpression.class); + BytesRef key = as(ee.key().fold(), BytesRef.class); + assertEquals("base", key.utf8ToString()); + assertEquals(new Literal(EMPTY, 2.0, DataType.DOUBLE), ee.value()); + assertEquals(DataType.DOUBLE, ee.dataType()); + Limit limit = as(eval.child(), Limit.class); + EsRelation esRelation = as(limit.child(), EsRelation.class); + assertEquals(esRelation.index().name(), "test"); + } }