From 2c8e9aeb2a138c71fddd3998fa321f7bdfbe1d68 Mon Sep 17 00:00:00 2001 From: Stuart Tettemer Date: Thu, 20 Jun 2019 13:09:34 -0600 Subject: [PATCH] Add painless method getByPath, get value from nested collections with dotted path (#43170) Given a nested structure composed of Lists and Maps, getByPath will return the value keyed by path. getByPath is a method on Lists and Maps. The path is string Map keys and integer List indices separated by dot. An optional third argument returns a default value if the path lookup fails due to a missing value. Eg. ['key0': ['a', 'b'], 'key1': ['c', 'd']].getByPath('key1') = ['c', 'd'] ['key0': ['a', 'b'], 'key1': ['c', 'd']].getByPath('key1.0') = 'c' ['key0': ['a', 'b'], 'key1': ['c', 'd']].getByPath('key2', 'x') = 'x' [['key0': 'value0'], ['key1': 'value1']].getByPath('1.key1') = 'value1' Throws IllegalArgumentException if an item cannot be found and a default is not given. Throws NumberFormatException if a path element operating on a List is not an integer. Fixes #42769 --- .../packages.asciidoc | 6 + .../packages.asciidoc | 48 ++++ .../painless/api/Augmentation.java | 113 ++++++++ .../elasticsearch/painless/spi/java.util.txt | 4 + .../painless/AugmentationTests.java | 1 - .../painless/GetByPathAugmentationTests.java | 259 ++++++++++++++++++ 6 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 modules/lang-painless/src/test/java/org/elasticsearch/painless/GetByPathAugmentationTests.java diff --git a/docs/painless/painless-api-reference/painless-api-reference-score/packages.asciidoc b/docs/painless/painless-api-reference/painless-api-reference-score/packages.asciidoc index 10f0f1b6daeab..a1beaeb5bc520 100644 --- a/docs/painless/painless-api-reference/painless-api-reference-score/packages.asciidoc +++ b/docs/painless/painless-api-reference/painless-api-reference-score/packages.asciidoc @@ -30,6 +30,8 @@ See the <> for a high-level overview of * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -84,6 +86,8 @@ See the <> for a high-level overview of * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -138,6 +142,8 @@ See the <> for a high-level overview of * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() diff --git a/docs/painless/painless-api-reference/painless-api-reference-shared/packages.asciidoc b/docs/painless/painless-api-reference/painless-api-reference-shared/packages.asciidoc index 75ad21ddc93f2..dd8141c2e96d0 100644 --- a/docs/painless/painless-api-reference/painless-api-reference-shared/packages.asciidoc +++ b/docs/painless/painless-api-reference/painless-api-reference-shared/packages.asciidoc @@ -4335,6 +4335,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -4386,6 +4388,8 @@ See the <> for a high-level overview * List findResults(BiFunction) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -4500,6 +4504,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -4666,6 +4672,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -5367,6 +5375,8 @@ See the <> for a high-level overview * List findResults(BiFunction) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -5457,6 +5467,8 @@ See the <> for a high-level overview * List findResults(BiFunction) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -5502,6 +5514,8 @@ See the <> for a high-level overview * List findResults(BiFunction) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -5668,6 +5682,8 @@ See the <> for a high-level overview * List findResults(BiFunction) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -5764,6 +5780,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Deque.html#getFirst()[getFirst]() * def {java11-javadoc}/java.base/java/util/Deque.html#getLast()[getLast]() * int getLength() @@ -5836,6 +5854,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -6056,6 +6076,8 @@ See the <> for a high-level overview * List findResults(BiFunction) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -6157,6 +6179,8 @@ See the <> for a high-level overview * def {java11-javadoc}/java.base/java/util/NavigableMap.html#floorKey(java.lang.Object)[floorKey](def) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -6642,6 +6666,8 @@ See the <> for a high-level overview * def {java11-javadoc}/java.base/java/util/SortedMap.html#firstKey()[firstKey]() * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -6844,6 +6870,8 @@ See the <> for a high-level overview * def {java11-javadoc}/java.base/java/util/Vector.html#firstElement()[firstElement]() * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -6988,6 +7016,8 @@ See the <> for a high-level overview * def {java11-javadoc}/java.base/java/util/NavigableMap.html#floorKey(java.lang.Object)[floorKey](def) * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer) * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def) +* Object getByPath(String) +* Object getByPath(String, Object) * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def) * Map groupBy(BiFunction) * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]() @@ -7158,6 +7188,8 @@ See the <> for a high-level overview * def {java11-javadoc}/java.base/java/util/Vector.html#firstElement()[firstElement]() * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * Map groupBy(Function) * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]() @@ -8016,6 +8048,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * Boolean get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * boolean getValue() * Map groupBy(Function) @@ -8071,6 +8105,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * BytesRef get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * BytesRef getValue() * Map groupBy(Function) @@ -8126,6 +8162,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * JodaCompatibleZonedDateTime get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * JodaCompatibleZonedDateTime getValue() * Map groupBy(Function) @@ -8181,6 +8219,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * Double get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * double getValue() * Map groupBy(Function) @@ -8240,6 +8280,8 @@ See the <> for a high-level overview * double geohashDistance(String) * double geohashDistanceWithDefault(String, double) * GeoPoint get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * double getLat() * double[] getLats() * int getLength() @@ -8301,6 +8343,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * Long get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * long getValue() * Map groupBy(Function) @@ -8356,6 +8400,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * String get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * String getValue() * Map groupBy(Function) @@ -8415,6 +8461,8 @@ See the <> for a high-level overview * List findResults(Function) * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer) * String get(int) +* Object getByPath(String) +* Object getByPath(String, Object) * int getLength() * String getValue() * Map groupBy(Function) diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/api/Augmentation.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/api/Augmentation.java index bbbbc3dfc37cf..d0745dc982c36 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/api/Augmentation.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/api/Augmentation.java @@ -25,6 +25,7 @@ import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.TreeMap; import java.util.function.BiConsumer; @@ -34,6 +35,7 @@ import java.util.function.Function; import java.util.function.ObjIntConsumer; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.function.ToDoubleFunction; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -552,4 +554,115 @@ public static String[] splitOnToken(String receiver, String token, int limit) { // O(N) or faster depending on implementation return result.toArray(new String[0]); } + + /** + * Access values in nested containers with a dot separated path. Path elements are treated + * as strings for Maps and integers for Lists. + * @throws IllegalArgumentException if any of the following: + * - path is empty + * - path contains a trailing '.' or a repeated '.' + * - an element of the path does not exist, ie key or index not present + * - there is a non-container type at a non-terminal path element + * - a path element for a List is not an integer + * @return object at path + */ + public static Object getByPath(List receiver, String path) { + return getByPathDispatch(receiver, splitPath(path), 0, throwCantFindValue(path)); + } + + /** + * Same as {@link #getByPath(List, String)}, but for Map. + */ + public static Object getByPath(Map receiver, String path) { + return getByPathDispatch(receiver, splitPath(path), 0, throwCantFindValue(path)); + } + + /** + * Same as {@link #getByPath(List, String)}, but with a default value. + * @return element at path or {@code defaultValue} if the terminal path element does not exist. + */ + public static Object getByPath(List receiver, String path, Object defaultValue) { + return getByPathDispatch(receiver, splitPath(path), 0, () -> defaultValue); + } + + /** + * Same as {@link #getByPath(List, String, Object)}, but for Map. + */ + public static Object getByPath(Map receiver, String path, Object defaultValue) { + return getByPathDispatch(receiver, splitPath(path), 0, () -> defaultValue); + } + + // Dispatches to getByPathMap, getByPathList or returns obj if done. See handleMissing for dealing with missing + // elements. + private static Object getByPathDispatch(Object obj, String[] elements, int i, Supplier defaultSupplier) { + if (i > elements.length - 1) { + return obj; + } else if (elements[i].length() == 0 ) { + String format = "Extra '.' in path [%s] at index [%d]"; + throw new IllegalArgumentException(String.format(Locale.ROOT, format, String.join(".", elements), i)); + } else if (obj instanceof Map) { + return getByPathMap((Map) obj, elements, i, defaultSupplier); + } else if (obj instanceof List) { + return getByPathList((List) obj, elements, i, defaultSupplier); + } + return handleMissing(obj, elements, i, defaultSupplier); + } + + // lookup existing key in map, call back to dispatch. + private static Object getByPathMap(Map map, String[] elements, int i, Supplier defaultSupplier) { + String element = elements[i]; + if (map.containsKey(element)) { + return getByPathDispatch(map.get(element), elements, i + 1, defaultSupplier); + } + return handleMissing(map, elements, i, defaultSupplier); + } + + // lookup existing index in list, call back to dispatch. Throws IllegalArgumentException with NumberFormatException + // if index can't be parsed as an int. + private static Object getByPathList(List list, String[] elements, int i, Supplier defaultSupplier) { + String element = elements[i]; + try { + int elemInt = Integer.parseInt(element); + if (list.size() >= elemInt) { + return getByPathDispatch(list.get(elemInt), elements, i + 1, defaultSupplier); + } + } catch (NumberFormatException e) { + String format = "Could not parse [%s] as a int index into list at path [%s] and index [%d]"; + throw new IllegalArgumentException(String.format(Locale.ROOT, format, element, String.join(".", elements), i), e); + } + return handleMissing(list, elements, i, defaultSupplier); + } + + // Split path on '.', throws IllegalArgumentException for empty paths and paths ending in '.' + private static String[] splitPath(String path) { + if (path.length() == 0) { + throw new IllegalArgumentException("Missing path"); + } + if (path.endsWith(".")) { + String format = "Trailing '.' in path [%s]"; + throw new IllegalArgumentException(String.format(Locale.ROOT, format, path)); + } + return path.split("\\."); + } + + // A supplier that throws IllegalArgumentException + private static Supplier throwCantFindValue(String path) { + return () -> { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Could not find value at path [%s]", path)); + }; + } + + // Use defaultSupplier if at last path element, otherwise throw IllegalArgumentException + private static Object handleMissing(Object obj, String[] elements, int i, Supplier defaultSupplier) { + if (obj instanceof List || obj instanceof Map) { + if (elements.length - 1 == i) { + return defaultSupplier.get(); + } + String format = "Container does not have [%s], for non-terminal index [%d] in path [%s]"; + throw new IllegalArgumentException(String.format(Locale.ROOT, format, elements[i], i, String.join(".", elements))); + } + String format = "Non-container [%s] at [%s], index [%d] in path [%s]"; + throw new IllegalArgumentException( + String.format(Locale.ROOT, format, obj.getClass().getName(), elements[i], i, String.join(".", elements))); + } } diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/java.util.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/java.util.txt index 94f302a891d48..958ac927a66dd 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/java.util.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/java.util.txt @@ -126,6 +126,8 @@ class java.util.List { int org.elasticsearch.painless.api.Augmentation getLength() void sort(Comparator) List subList(int,int) + Object org.elasticsearch.painless.api.Augmentation getByPath(String) + Object org.elasticsearch.painless.api.Augmentation getByPath(String, Object) } class java.util.ListIterator { @@ -161,6 +163,8 @@ class java.util.Map { void replaceAll(BiFunction) int size() Collection values() + Object org.elasticsearch.painless.api.Augmentation getByPath(String) + Object org.elasticsearch.painless.api.Augmentation getByPath(String, Object) # some adaptations of groovy methods List org.elasticsearch.painless.api.Augmentation collect(BiFunction) diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java index 70fbb733e2f8f..e462997444165 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java @@ -232,7 +232,6 @@ public void testString_SplitOnToken() { new SplitCase("1\n1.1.\r\n1\r\n111", "\r\n"), }; for (SplitCase split : cases) { - //System.out.println(String.format("Splitting '%s' by '%s' %d times", split.input, split.token, split.count)); assertArrayEquals( split.input.split(Pattern.quote(split.token), split.count), (String[])exec("return \""+split.input+"\".splitOnToken(\""+split.token+"\", "+split.count+");") diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/GetByPathAugmentationTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/GetByPathAugmentationTests.java new file mode 100644 index 0000000000000..603ab7fd0e60c --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/GetByPathAugmentationTests.java @@ -0,0 +1,259 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.painless; + + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public class GetByPathAugmentationTests extends ScriptTestCase { + + private final String k001Key = "k011"; + private final String k001Value = "b"; + private final Map k001Obj = new HashMap<>(); + private final String k001MapStr = "['" + k001Key + "': '" + k001Value + "']"; + private final String mapMapList = "['k0': ['k01': [['k010': 'a'], " + k001MapStr + "]], 'k1': ['q']]"; + + private final String l2m2l1Index0 = "ll0"; + private final String l2m2l1Index1 = "ll1"; + private final List l2m2l1Obj = new ArrayList<>(); + private final String l2m2l1Str = "['" + l2m2l1Index0 + "', '" + l2m2l1Index1 + "']"; + private final String listMapListList = "[['m0':'v0'],['m1':'v1'],['m2':['l0','l1', " + l2m2l1Str + "]]]"; + + private final String mapList = "['key0': ['a', 'b'], 'key1': ['c', 'd']]"; + private final String mapMap = "['a': ['b': 'c']]"; + + public GetByPathAugmentationTests() { + l2m2l1Obj.add(l2m2l1Index0); + l2m2l1Obj.add(l2m2l1Index1); + k001Obj.put(k001Key, k001Value); + } + + private String toScript(String collection, String key) { + return String.format(Locale.ROOT, "return %s.getByPath('%s')", collection, key); + } + + private String toScript(String collection, String key, String defaultValue) { + return String.format(Locale.ROOT, "return %s.getByPath('%s', %s)", collection, key, defaultValue); + } + + private String numberFormat(String unparsable, String path, int i) { + String format = "Could not parse [%s] as a int index into list at path [%s] and index [%d]"; + return String.format(Locale.ROOT, format, unparsable, path, i); + } + + private String missingValue(String path) { + return String.format(Locale.ROOT, "Could not find value at path [%s]", path); + } + + private void assertPathValue(String collection, String key, Object value) { + assertEquals(value, exec(toScript(collection, key))); + } + + private void assertPathDefaultValue(String collection, String key, Object value, String defaultValue) { + assertEquals(value, exec(toScript(collection, key, defaultValue))); + } + + private IllegalArgumentException assertPathError(String collection, String key, String message) { + return assertPathError(toScript(collection, key), message); + } + + private IllegalArgumentException assertPathError(String collection, String key, String defaultValue, String message) { + return assertPathError(toScript(collection, key, defaultValue), message); + } + + private IllegalArgumentException assertPathError(String script, String message) { + IllegalArgumentException illegal = expectScriptThrows( + IllegalArgumentException.class, + () -> exec(script) + ); + assertEquals(message, illegal.getMessage()); + return illegal; + } + + public void testOneLevelMap() { + assertPathValue("['k0':'v0']", "k0", "v0"); + } + + public void testOneLevelList() { + assertPathValue("['a','b','c','d']", "2", "c"); + } + + public void testTwoLevelMapList() { + assertPathValue("['key0': ['a', 'b'], 'key1': ['c', 'd']]", "key1.0", "c"); + } + + public void testMapDiffSizeList() { + assertPathValue("['k0': ['a','b','c','d'], 'k1': ['q']]", "k0.3", "d"); + } + + public void testBiMapList() { + assertPathValue(mapMapList, "k0.k01.1.k011", k001Value); + } + + public void testBiMapListObject() { + assertPathValue(mapMapList, "k0.k01.1", k001Obj); + } + + public void testListMap() { + assertPathValue("[['key0': 'value0'], ['key1': 'value1']]", "1.key1", "value1"); + } + + public void testTriList() { + assertPathValue("[['a','b'],['c','d'],[['e','f'],['g','h']]]", "2.1.1", "h"); + } + + public void testMapBiListObject() { + assertPathValue(listMapListList, "2.m2.2", l2m2l1Obj); + } + + public void testMapBiList() { + assertPathValue(listMapListList, "2.m2.2.1", l2m2l1Index1); + } + + public void testGetCollection() { + List k1List = new ArrayList<>(); + k1List.add("c"); + k1List.add("d"); + assertPathValue("['key0': ['a', 'b'], 'key1': ['c', 'd']]", "key1", k1List); + } + + public void testMapListDefaultOneLevel() { + assertPathDefaultValue(mapList, "key2", "x", "'x'"); + } + + public void testMapListDefaultTwoLevel() { + assertPathDefaultValue(mapList, "key1.1", "d", "'x'"); + } + + public void testBiMapListDefault() { + assertPathDefaultValue(mapMapList, "k0.k01.1.k012", "foo", "'foo'"); + } + + public void testBiMapListDefaultExists() { + assertPathDefaultValue(mapMapList, "k0.k01.1.k011", "b", "'foo'"); + } + + public void testBiMapListDefaultObjectExists() { + assertPathDefaultValue(mapMapList, "k0.k01.1", k001Obj, "'foo'"); + } + + public void testBiMapListDefaultObject() { + assertPathDefaultValue(mapMapList, "k0.k01.9", k001Obj, k001MapStr); + } + + public void testListMapBiListDefaultExists() { + assertPathDefaultValue(listMapListList, "2.m2.2", l2m2l1Obj, "'foo'"); + } + + public void testListMapBiListDefaultObject() { + assertPathDefaultValue(listMapListList, "2.m2.9", l2m2l1Obj, l2m2l1Str); + } + + public void testBiListBadIndex() { + String path = "1.k0"; + IllegalArgumentException err = assertPathError("[['a','b'],['c','d']]", path, numberFormat("k0", path, 1)); + assertEquals(err.getCause().getClass(), NumberFormatException.class); + } + + public void testBiMapListMissingLast() { + String path = "k0.k01.1.k012"; + assertPathError(mapMapList, path, missingValue(path)); + } + + public void testBiMapListBadIndex() { + String path = "k0.k01.k012"; + IllegalArgumentException err = assertPathError(mapMapList, path, numberFormat("k012", path, 2)); + assertEquals(err.getCause().getClass(), NumberFormatException.class); + } + + public void testListMapBiListMissingObject() { + String path = "2.m2.12"; + assertPathError(listMapListList, path, missingValue(path)); + } + + public void testListMapBiListBadIndexAtObject() { + String path = "2.m2.a8"; + IllegalArgumentException err = assertPathError(listMapListList, path, numberFormat("a8", path, 2)); + assertEquals(err.getCause().getClass(), NumberFormatException.class); + } + + public void testNonContainer() { + assertPathError(mapMap, "a.b.c", "Non-container [java.lang.String] at [c], index [2] in path [a.b.c]"); + } + + public void testMissingPath() { + assertPathError(mapMap, "", "Missing path"); + } + + public void testDoubleDot() { + assertPathError(mapMap, "a..b", "Extra '.' in path [a..b] at index [1]"); + } + + public void testTrailingDot() { + assertPathError(mapMap, "a.b.", "Trailing '.' in path [a.b.]"); + } + + public void testBiListDefaultBadIndex() { + String path = "1.k0"; + IllegalArgumentException err = assertPathError( + "[['a','b'],['c','d']]", + path, + "'foo'", + numberFormat("k0", path, 1)); + assertEquals(err.getCause().getClass(), NumberFormatException.class); + } + + public void testBiMapListDefaultBadIndex() { + String path = "k0.k01.k012"; + IllegalArgumentException err = assertPathError( + mapMapList, + path, + "'foo'", + numberFormat("k012", path, 2)); + assertEquals(err.getCause().getClass(), NumberFormatException.class); + } + + public void testListMapBiListObjectDefaultBadIndex() { + String path = "2.m2.a8"; + IllegalArgumentException err = assertPathError( + listMapListList, + path, + "'foo'", + numberFormat("a8", path, 2)); + assertEquals(err.getCause().getClass(), NumberFormatException.class); + } + + public void testNonContainerDefaultBadIndex() { + assertPathError(mapMap, "a.b.c", "'foo'", + "Non-container [java.lang.String] at [c], index [2] in path [a.b.c]"); + } + + public void testDoubleDotDefault() { + assertPathError(mapMap, "a..b", "'foo'", "Extra '.' in path [a..b] at index [1]"); + } + + public void testTrailingDotDefault() { + assertPathError(mapMap, "a.b.", "'foo'", "Trailing '.' in path [a.b.]"); + } +}