diff --git a/docs/changelog/85872.yaml b/docs/changelog/85872.yaml new file mode 100644 index 0000000000000..225d68bd72ca2 --- /dev/null +++ b/docs/changelog/85872.yaml @@ -0,0 +1,5 @@ +pr: 85872 +summary: Replace the implementation of the `categorize_text` aggregation +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 9ee97f120e6f7..9aba9b417f63d 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -17,56 +17,14 @@ NOTE: If you have considerable memory allocated to your JVM but are receiving ci <>, or <> to explore the created categories. +NOTE: The algorithm used for categorization was completely changed in version 8.3.0. As a result, this aggregation + will not work in a mixed version cluster where some nodes are on version 8.3.0 or higher and others are + on a version older than 8.3.0. Upgrade all nodes in your cluster to the same version if you experience + an error related to this change. + [[bucket-categorize-text-agg-syntax]] ==== Parameters -`field`:: -(Required, string) -The semi-structured text field to categorize. - -`max_unique_tokens`:: -(Optional, integer, default: `50`) -The maximum number of unique tokens at any position up to `max_matched_tokens`. -Must be larger than 1. Smaller values use less memory and create fewer categories. -Larger values will use more memory and create narrower categories. -Max allowed value is `100`. - -`max_matched_tokens`:: -(Optional, integer, default: `5`) -The maximum number of token positions to match on before attempting to merge categories. -Larger values will use more memory and create narrower categories. -Max allowed value is `100`. - -Example: -`max_matched_tokens` of 2 would disallow merging of the categories -[`foo` `bar` `baz`] -[`foo` `baz` `bozo`] -As the first 2 tokens are required to match for the category. - -NOTE: Once `max_unique_tokens` is reached at a given position, a new `*` token is -added and all new tokens at that position are matched by the `*` token. - -`similarity_threshold`:: -(Optional, integer, default: `50`) -The minimum percentage of tokens that must match for text to be added to the -category bucket. -Must be between 1 and 100. The larger the value the narrower the categories. -Larger values will increase memory usage and create narrower categories. - -`categorization_filters`:: -(Optional, array of strings) -This property expects an array of regular expressions. The expressions -are used to filter out matching sequences from the categorization field values. -You can use this functionality to fine tune the categorization by excluding -sequences from consideration when categories are defined. For example, you can -exclude SQL statements that appear in your log files. This -property cannot be used at the same time as `categorization_analyzer`. If you -only want to define simple regular expression filters that are applied prior to -tokenization, setting this property is the easiest method. If you also want to -customize the tokenizer or post-tokenization filtering, use the -`categorization_analyzer` property instead and include the filters as -`pattern_replace` character filters.
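For reference, the tuning knobs that survive this change can also be set through the Java API; a minimal sketch, assuming hypothetical `categories`, `message` and `time` names, using only the chained setters visible on `CategorizeTextAggregationBuilder` later in this diff:

[source,java]
--------------------------------------------------
import java.util.List;

import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder;

public class CategorizeTextFiltersExample {

    // Sketch only: "categories", "message" and "time" are hypothetical names.
    public static CategorizeTextAggregationBuilder filteredAgg() {
        return new CategorizeTextAggregationBuilder("categories", "message")
            // Strip highly-variable tokens such as "bar_123" before tokenization,
            // using the same pattern as the docs examples below.
            .setCategorizationFilters(List.of("\\w+\\_\\d{3}"))
            // 70 is the default from 8.3 onwards; set explicitly for clarity.
            .setSimilarityThreshold(70)
            .subAggregation(AggregationBuilders.max("max_time").field("time"));
    }
}
--------------------------------------------------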
- `categorization_analyzer`:: (Optional, object or string) The categorization analyzer specifies how the text is analyzed and tokenized before @@ -95,14 +53,33 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tokenizer] include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=filter] ===== -`shard_size`:: +`categorization_filters`:: +(Optional, array of strings) +This property expects an array of regular expressions. The expressions +are used to filter out matching sequences from the categorization field values. +You can use this functionality to fine tune the categorization by excluding +sequences from consideration when categories are defined. For example, you can +exclude SQL statements that appear in your log files. This +property cannot be used at the same time as `categorization_analyzer`. If you +only want to define simple regular expression filters that are applied prior to +tokenization, setting this property is the easiest method. If you also want to +customize the tokenizer or post-tokenization filtering, use the +`categorization_analyzer` property instead and include the filters as +`pattern_replace` character filters. + +`field`:: +(Required, string) +The semi-structured text field to categorize. + +`max_matched_tokens`:: (Optional, integer) -The number of categorization buckets to return from each shard before merging -all the results. +This parameter does nothing now, but is permitted for compatibility with the original +pre-8.3.0 implementation. -`size`:: -(Optional, integer, default: `10`) -The number of buckets to return. +`max_unique_tokens`:: +(Optional, integer) +This parameter does nothing now, but is permitted for compatibility with the original +pre-8.3.0 implementation. `min_doc_count`:: (Optional, integer) @@ -113,8 +90,23 @@ The minimum number of documents for a bucket to be returned to the results. The minimum number of documents for a bucket to be returned from the shard before merging. -==== Basic use +`shard_size`:: +(Optional, integer) +The number of categorization buckets to return from each shard before merging +all the results. + +`similarity_threshold`:: +(Optional, integer, default: `70`) +The minimum percentage of token weight that must match for text to be added to the +category bucket. +Must be between 1 and 100. The larger the value the narrower the categories. +Larger values will increase memory usage and create narrower categories. +`size`:: +(Optional, integer, default: `10`) +The number of buckets to return. + +==== Basic use WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. This aggregation should be used in conjunction with <>. 
Additionally, you may consider @@ -149,19 +141,23 @@ Response: "buckets" : [ { "doc_count" : 3, - "key" : "Node shutting down" + "key" : "Node shutting down", + "max_matching_length" : 49 }, { "doc_count" : 1, - "key" : "Node starting up" + "key" : "Node starting up", + "max_matching_length" : 47 }, { "doc_count" : 1, - "key" : "User foo_325 logging on" + "key" : "User foo_325 logging on", + "max_matching_length" : 52 }, { "doc_count" : 1, - "key" : "User foo_864 logged off" + "key" : "User foo_864 logged off", + "max_matching_length" : 52 } ] } @@ -169,7 +165,6 @@ Response: } -------------------------------------------------- - Here is an example using `categorization_filters` [source,console] @@ -202,19 +197,23 @@ category results "buckets" : [ { "doc_count" : 3, - "key" : "Node shutting down" + "key" : "Node shutting down", + "max_matching_length" : 49 }, { "doc_count" : 1, - "key" : "Node starting up" + "key" : "Node starting up", + "max_matching_length" : 47 }, { "doc_count" : 1, - "key" : "User logged off" + "key" : "User logged off", + "max_matching_length" : 52 }, { "doc_count" : 1, - "key" : "User logging on" + "key" : "User logging on", + "max_matching_length" : 52 } ] } @@ -223,11 +222,15 @@ category results -------------------------------------------------- Here is an example using `categorization_filters`. -The default analyzer is a whitespace analyzer with a custom token filter -which filters out tokens that start with any number. +The default analyzer uses the `ml_standard` tokenizer which is similar to a whitespace tokenizer +but filters out tokens that could be interpreted as hexadecimal numbers. The default analyzer +also uses the `first_line_with_letters` character filter, so that only the first meaningful line +of multi-line messages is considered. But, it may be that a token is a known highly-variable token (formatted usernames, emails, etc.). In that case, it is good to supply -custom `categorization_filters` to filter out those tokens for better categories. These filters will also reduce memory usage as fewer -tokens are held in memory for the categories. +custom `categorization_filters` to filter out those tokens for better categories. These filters may also reduce memory usage as fewer +tokens are held in memory for the categories. (If there are sufficient examples of different usernames, emails, etc., then +categories will form that naturally discard them as variables, but for small input data where only one example exists this won't +happen.) [source,console] -------------------------------------------------- @@ -238,8 +241,7 @@ POST log-messages/_search?filter_path=aggregations "categorize_text": { "field": "message", "categorization_filters": ["\\w+\\_\\d{3}"], <1> - "max_matched_tokens": 2, <2> - "similarity_threshold": 30 <3> + "similarity_threshold": 11 <2> } } } @@ -248,12 +250,12 @@ POST log-messages/_search?filter_path=aggregations // TEST[setup:categorize_text] <1> The filters to apply to the analyzed tokens. It filters out tokens like `bar_123`. -<2> Require at least 2 tokens before the log categories attempt to merge together -<3> Require 30% of the tokens to match before expanding a log categories - to add a new log entry +<2> Require 11% of token weight to match before adding a message to an + existing category rather than creating a new one. -The resulting categories are now broad, matching the first token -and merging the log groups. +The resulting categories are now very broad, merging the log groups. 
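The same broad grouping is asserted in `CategorizeTextAggregationIT` later in this diff; a condensed, self-contained sketch of that aggregation setup (the `msg` and `time` field names come from the test):

[source,java]
--------------------------------------------------
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder;

public class BroadCategoriesExample {

    // Overriding the similarity threshold to just 11% (default is 70%) makes
    // the "Node started" and "Node stopped" messages land in one broad
    // "Node" category.
    public static CategorizeTextAggregationBuilder broadCategories() {
        return new CategorizeTextAggregationBuilder("categorize", "msg")
            .setSimilarityThreshold(11)
            .subAggregation(AggregationBuilders.max("max").field("time"))
            .subAggregation(AggregationBuilders.min("min").field("time"));
    }
}
--------------------------------------------------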
+(A `similarity_threshold` of 11% is generally too low. Settings over +50% are usually better.) [source,console-result] -------------------------------------------------- @@ -263,11 +265,13 @@ and merging the log groups. "buckets" : [ { "doc_count" : 4, - "key" : "Node *" + "key" : "Node", + "max_matching_length" : 49 }, { "doc_count" : 2, - "key" : "User *" + "key" : "User", + "max_matching_length" : 52 } ] } @@ -326,6 +330,7 @@ POST log-messages/_search?filter_path=aggregations { "doc_count" : 2, "key" : "Node shutting down", + "max_matching_length" : 49, "hit" : { "hits" : { "total" : { @@ -352,6 +357,7 @@ POST log-messages/_search?filter_path=aggregations { "doc_count" : 1, "key" : "Node starting up", + "max_matching_length" : 47, "hit" : { "hits" : { "total" : { @@ -387,6 +393,7 @@ POST log-messages/_search?filter_path=aggregations { "doc_count" : 1, "key" : "Node shutting down", + "max_matching_length" : 49, "hit" : { "hits" : { "total" : { @@ -413,6 +420,7 @@ POST log-messages/_search?filter_path=aggregations { "doc_count" : 1, "key" : "User logged off", + "max_matching_length" : 52, "hit" : { "hits" : { "total" : { @@ -439,6 +447,7 @@ POST log-messages/_search?filter_path=aggregations { "doc_count" : 1, "key" : "User logging on", + "max_matching_length" : 52, "hit" : { "hits" : { "total" : { diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index 0f90a6c391e64..8c2f8f2476356 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -118,16 +118,28 @@ tasks.named("yamlRestTestV7CompatTransform").configure { task -> "ml/datafeeds_crud/Test update datafeed to point to job already attached to another datafeed", "behaviour change #44752 - not allowing to update datafeed job_id" ) + task.skipTest( + "ml/trained_model_cat_apis/Test cat trained models", + "A type field was added to cat.ml_trained_models #73660, this is a backwards compatible change. Still this is a cat api, and we don't support them with rest api compatibility. (the test would be very hard to transform too)" + ) + task.skipTest( + "ml/categorization_agg/Test categorization agg simple", + "categorize_text was changed in 8.3, but experimental prior to the change" + ) + task.skipTest( + "ml/categorization_agg/Test categorization aggregation against unsupported field", + "categorize_text was changed in 8.3, but experimental prior to the change" + ) + task.skipTest( + "ml/categorization_agg/Test categorization aggregation with poor settings", + "categorize_text was changed in 8.3, but experimental prior to the change" + ) task.skipTest("rollup/delete_job/Test basic delete_job", "rollup was an experimental feature, also see #41227") task.skipTest("rollup/delete_job/Test delete job twice", "rollup was an experimental feature, also see #41227") task.skipTest("rollup/delete_job/Test delete running job", "rollup was an experimental feature, also see #41227") task.skipTest("rollup/get_jobs/Test basic get_jobs", "rollup was an experimental feature, also see #41227") task.skipTest("rollup/put_job/Test basic put_job", "rollup was an experimental feature, also see #41227") task.skipTest("rollup/start_job/Test start job twice", "rollup was an experimental feature, also see #41227") - task.skipTest( - "ml/trained_model_cat_apis/Test cat trained models", - "A type field was added to cat.ml_trained_models #73660, this is a backwards compatible change. Still this is a cat api, and we don't support them with rest api compatibility. 
(the test would be very hard to transform too)" - ) task.skipTest("indices.freeze/30_usage/Usage stats on frozen indices", "#70192 -- the freeze index API is removed from 8.0") task.skipTest("indices.freeze/20_stats/Translog stats on frozen indices", "#70192 -- the freeze index API is removed from 8.0") task.skipTest("indices.freeze/10_basic/Basic", "#70192 -- the freeze index API is removed from 8.0") diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextAggregationIT.java similarity index 96% rename from x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java rename to x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextAggregationIT.java index e6d680328a204..309ca2211b1c7 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextAggregationIT.java @@ -27,7 +27,7 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notANumber; -public class CategorizationAggregationIT extends BaseMlIntegTestCase { +public class CategorizeTextAggregationIT extends BaseMlIntegTestCase { private static final String DATA_INDEX = "categorization-agg-data"; @@ -77,9 +77,9 @@ public void testAggregationWithBroadCategories() { .setSize(0) .setTrackTotalHits(false) .addAggregation( + // Overriding the similarity threshold to just 11% (default is 70%) results in the + // "Node started" and "Node stopped" messages being grouped in the same category new CategorizeTextAggregationBuilder("categorize", "msg").setSimilarityThreshold(11) - .setMaxUniqueTokens(2) - .setMaxMatchedTokens(1) .subAggregation(AggregationBuilders.max("max").field("time")) .subAggregation(AggregationBuilders.min("min").field("time")) ) @@ -87,7 +87,7 @@ public void testAggregationWithBroadCategories() { InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); assertThat(agg.getBuckets(), hasSize(2)); - assertCategorizationBucket(agg.getBuckets().get(0), "Node *", 4); + assertCategorizationBucket(agg.getBuckets().get(0), "Node", 4); assertCategorizationBucket(agg.getBuckets().get(1), "Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception", 2); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextDistributedIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextDistributedIT.java index eed00010ac3cf..6f32acd5f08d8 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextDistributedIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextDistributedIT.java @@ -16,8 +16,8 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder; -import org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation; +import 
org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import java.util.Arrays; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 3d38e6d817f2e..ab28ae4a432d1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -1431,16 +1431,7 @@ public List getAggregations() { CategorizeTextAggregationBuilder::new, CategorizeTextAggregationBuilder.PARSER ).addResultReader(InternalCategorizationAggregation::new) - .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)), - // TODO: in the long term only keep one or other of these categorization aggregations - new AggregationSpec( - org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.NAME, - org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder::new, - org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.PARSER - ).addResultReader(org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation::new) - .setAggregatorRegistrar( - s -> s.registerUsage(org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.NAME) - ) + .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java index 6246683ddfe6e..23d95b4ba0f7f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -15,14 +15,6 @@ class CategorizationBytesRefHash implements Releasable { - /** - * Our special wild card value. - */ - static final BytesRef WILD_CARD_REF = new BytesRef("*"); - /** - * For all WILD_CARD references, the token ID is always -1 - */ - static final int WILD_CARD_ID = -1; private final BytesRefHash bytesRefHash; CategorizationBytesRefHash(BytesRefHash bytesRefHash) { @@ -46,34 +38,28 @@ BytesRef[] getDeeps(int[] ids) { } BytesRef getDeep(long id) { - if (id == WILD_CARD_ID) { - return WILD_CARD_REF; - } BytesRef shallow = bytesRefHash.get(id, new BytesRef()); return BytesRef.deepCopyOf(shallow); } int put(BytesRef bytesRef) { - if (WILD_CARD_REF.equals(bytesRef)) { - return WILD_CARD_ID; - } long hash = bytesRefHash.add(bytesRef); if (hash < 0) { + // BytesRefHash returns -1 - hash if the entry already existed, but we just want to return the hash return (int) (-1L - hash); - } else { - if (hash > Integer.MAX_VALUE) { - throw new AggregationExecutionException( - LoggerMessageFormat.format( - "more than [{}] unique terms encountered. 
" - + "Consider restricting the documents queried or adding [{}] in the {} configuration", - Integer.MAX_VALUE, - CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(), - CategorizeTextAggregationBuilder.NAME - ) - ); - } - return (int) hash; } + if (hash > Integer.MAX_VALUE) { + throw new AggregationExecutionException( + LoggerMessageFormat.format( + "more than [{}] unique terms encountered. " + + "Consider restricting the documents queried or adding [{}] in the {} configuration", + Integer.MAX_VALUE, + CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(), + CategorizeTextAggregationBuilder.NAME + ) + ); + } + return (int) hash; } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationPartOfSpeechDictionary.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationPartOfSpeechDictionary.java similarity index 98% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationPartOfSpeechDictionary.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationPartOfSpeechDictionary.java index 1a9d23f566af2..09a6846ead344 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationPartOfSpeechDictionary.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationPartOfSpeechDictionary.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.aggs.categorization2; +package org.elasticsearch.xpack.ml.aggs.categorization; import java.io.BufferedReader; import java.io.IOException; @@ -25,7 +25,7 @@ */ public class CategorizationPartOfSpeechDictionary { - static final String DICTIONARY_FILE_PATH = "/org/elasticsearch/xpack/ml/aggs/categorization2/ml-en.dict"; + static final String DICTIONARY_FILE_PATH = "/org/elasticsearch/xpack/ml/aggs/categorization/ml-en.dict"; static final String PART_OF_SPEECH_SEPARATOR = "@"; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java deleted file mode 100644 index 75560ec70555d..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.aggs.categorization; - -import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.search.aggregations.InternalAggregations; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * Categorized semi-structured text utilizing the drain algorithm: https://arxiv.org/pdf/1806.04356.pdf - * With the following key differences - * - This structure keeps track of the "smallest" sub-tree. So, instead of naively adding a new "*" node, the smallest sub-tree - * is transformed if the incoming token has a higher doc_count. 
- * - Additionally, similarities are weighted, which allows for nicer merging of existing categories - * - An optional tree reduction step is available to collapse together tiny sub-trees - * - * - * The main implementation is a fixed-sized prefix tree. - * Consequently, this assumes that splits that give us more information come earlier in the text. - * - * Examples: - * - * Given token values: - * - * Node is online - * Node is offline - * - * With a fixed tree depth of 2 we would get the following splits - * 3 // initial root is the number of tokens - * | - * "Node" // first prefix node of value "Node" - * | - * "is" - * / \ - * [Node is online] [Node is offline] //the individual categories for this simple case - * - * If the similarityThreshold was less than 0.6, the result would be a single category [Node is *] - * - */ -public class CategorizationTokenTree implements Accountable { - - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CategorizationTokenTree.class); - private final int maxMatchTokens; - private final int maxUniqueTokens; - private final int similarityThreshold; - private long idGenerator; - private final Map root = new HashMap<>(); - private long sizeInBytes; - - public CategorizationTokenTree(int maxUniqueTokens, int maxMatchTokens, int similarityThreshold) { - assert maxUniqueTokens > 0 && maxMatchTokens >= 0; - this.maxUniqueTokens = maxUniqueTokens; - this.maxMatchTokens = maxMatchTokens; - this.similarityThreshold = similarityThreshold; - this.sizeInBytes = SHALLOW_SIZE; - } - - public List toIntermediateBuckets(CategorizationBytesRefHash hash) { - return root.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).map(lg -> { - int[] categoryTokenIds = lg.getCategorization(); - BytesRef[] bytesRefs = new BytesRef[categoryTokenIds.length]; - for (int i = 0; i < categoryTokenIds.length; i++) { - bytesRefs[i] = hash.getDeep(categoryTokenIds[i]); - } - InternalCategorizationAggregation.Bucket bucket = new InternalCategorizationAggregation.Bucket( - new InternalCategorizationAggregation.BucketKey(bytesRefs), - lg.getCount(), - InternalAggregations.EMPTY - ); - bucket.bucketOrd = lg.bucketOrd; - return bucket; - }).collect(Collectors.toList()); - } - - void mergeSmallestChildren() { - root.values().forEach(TreeNode::collapseTinyChildren); - } - - /** - * This method does not mutate the underlying structure. Meaning, if a matching categories isn't found, it may return empty. - * - * @param tokenIds The tokens to categorize - * @return The category or `Optional.empty()` if one doesn't exist - */ - public Optional parseTokensConst(final int[] tokenIds) { - TreeNode currentNode = this.root.get(tokenIds.length); - if (currentNode == null) { // we are missing an entire sub tree. New token length found - return Optional.empty(); - } - return Optional.ofNullable(currentNode.getCategorization(tokenIds)); - } - - /** - * This categorizes the passed tokens, potentially mutating the structure by expanding an existing category or adding a new one. - * @param tokenIds The tokens to categorize - * @param docCount The count of docs for the given tokens - * @return An existing categorization or a newly created one - */ - public TextCategorization parseTokens(final int[] tokenIds, long docCount) { - TreeNode currentNode = this.root.get(tokenIds.length); - if (currentNode == null) { // we are missing an entire sub tree. 
New token length found - currentNode = newNode(docCount, 0, tokenIds); - incSize(currentNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF); - this.root.put(tokenIds.length, currentNode); - } else { - currentNode.incCount(docCount); - } - return currentNode.addText(tokenIds, docCount, this); - } - - TreeNode newNode(long docCount, int tokenPos, int[] tokenIds) { - return tokenPos < maxMatchTokens - 1 && tokenPos < tokenIds.length - ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxUniqueTokens) - : new TreeNode.LeafTreeNode(docCount, similarityThreshold); - } - - TextCategorization newCategorization(long docCount, int[] tokenIds) { - return new TextCategorization(tokenIds, docCount, idGenerator++); - } - - void incSize(long size) { - sizeInBytes += size; - } - - @Override - public long ramBytesUsed() { - return sizeInBytes; - } - -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java index c5fe223bf09b2..fa1b8313d5bbd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.aggs.categorization; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -44,14 +45,22 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder -1 ); - static final int MAX_MAX_UNIQUE_TOKENS = 100; - static final int MAX_MAX_MATCHED_TOKENS = 100; public static final String NAME = "categorize_text"; + // In 8.3 the algorithm used by this aggregation was completely changed. + // Prior to 8.3 the Drain algorithm was used. From 8.3 the same algorithm + // we use in our C++ categorization code was used. As a result of this + // the aggregation will not perform well in mixed version clusters where + // some nodes are pre-8.3 and others are newer, so we throw an error in + // this situation. The aggregation was experimental at the time this change + // was made, so this is acceptable. 
+ public static final Version ALGORITHM_CHANGED_VERSION = Version.V_8_3_0; + static final ParseField FIELD_NAME = new ParseField("field"); - static final ParseField MAX_UNIQUE_TOKENS = new ParseField("max_unique_tokens"); static final ParseField SIMILARITY_THRESHOLD = new ParseField("similarity_threshold"); - static final ParseField MAX_MATCHED_TOKENS = new ParseField("max_matched_tokens"); + // The next two are unused, but accepted and ignored to avoid breaking client code + static final ParseField MAX_UNIQUE_TOKENS = new ParseField("max_unique_tokens").withAllDeprecated(); + static final ParseField MAX_MATCHED_TOKENS = new ParseField("max_matched_tokens").withAllDeprecated(); static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters"); static final ParseField CATEGORIZATION_ANALYZER = new ParseField("categorization_analyzer"); @@ -61,9 +70,10 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder ); static { PARSER.declareString(CategorizeTextAggregationBuilder::setFieldName, FIELD_NAME); - PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxUniqueTokens, MAX_UNIQUE_TOKENS); - PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxMatchedTokens, MAX_MATCHED_TOKENS); PARSER.declareInt(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); + // The next two are unused, but accepted and ignored to avoid breaking client code + PARSER.declareInt((p, c) -> {}, MAX_UNIQUE_TOKENS); + PARSER.declareInt((p, c) -> {}, MAX_MATCHED_TOKENS); PARSER.declareField( CategorizeTextAggregationBuilder::setCategorizationAnalyzerConfig, (p, c) -> CategorizationAnalyzerConfig.buildFromXContentFragment(p, false), @@ -86,9 +96,8 @@ public static CategorizeTextAggregationBuilder parse(String aggregationName, XCo ); private CategorizationAnalyzerConfig categorizationAnalyzerConfig; private String fieldName; - private int maxUniqueTokens = 50; - private int similarityThreshold = 50; - private int maxMatchedTokens = 5; + // Default of 70% matches the C++ code + private int similarityThreshold = 70; private CategorizeTextAggregationBuilder(String name) { super(name); @@ -99,6 +108,11 @@ public CategorizeTextAggregationBuilder(String name, String fieldName) { this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); } + @Override + public boolean supportsSampling() { + return true; + } + public String getFieldName() { return fieldName; } @@ -110,37 +124,22 @@ public CategorizeTextAggregationBuilder setFieldName(String fieldName) { public CategorizeTextAggregationBuilder(StreamInput in) throws IOException { super(in); + // Disallow this aggregation in mixed version clusters that cross the algorithm change boundary. 
+ if (in.getVersion().before(ALGORITHM_CHANGED_VERSION)) { + throw new ElasticsearchException( + "[" + + NAME + + "] aggregation cannot be used in a cluster where some nodes have version [" + + ALGORITHM_CHANGED_VERSION + + "] or higher and others have a version before this" + ); + } this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in); this.fieldName = in.readString(); - this.maxUniqueTokens = in.readVInt(); - this.maxMatchedTokens = in.readVInt(); this.similarityThreshold = in.readVInt(); this.categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new); } - @Override - public boolean supportsSampling() { - return true; - } - - public int getMaxUniqueTokens() { - return maxUniqueTokens; - } - - public CategorizeTextAggregationBuilder setMaxUniqueTokens(int maxUniqueTokens) { - this.maxUniqueTokens = maxUniqueTokens; - if (maxUniqueTokens <= 0 || maxUniqueTokens > MAX_MAX_UNIQUE_TOKENS) { - throw ExceptionsHelper.badRequestException( - "[{}] must be greater than 0 and less than or equal [{}]. Found [{}] in [{}]", - MAX_UNIQUE_TOKENS.getPreferredName(), - MAX_MAX_UNIQUE_TOKENS, - maxUniqueTokens, - name - ); - } - return this; - } - public double getSimilarityThreshold() { return similarityThreshold; } @@ -198,24 +197,6 @@ public CategorizeTextAggregationBuilder setCategorizationFilters(List ca return this; } - public int getMaxMatchedTokens() { - return maxMatchedTokens; - } - - public CategorizeTextAggregationBuilder setMaxMatchedTokens(int maxMatchedTokens) { - this.maxMatchedTokens = maxMatchedTokens; - if (maxMatchedTokens <= 0 || maxMatchedTokens > MAX_MAX_MATCHED_TOKENS) { - throw ExceptionsHelper.badRequestException( - "[{}] must be greater than 0 and less than or equal [{}]. Found [{}] in [{}]", - MAX_MATCHED_TOKENS.getPreferredName(), - MAX_MAX_MATCHED_TOKENS, - maxMatchedTokens, - name - ); - } - return this; - } - /** * @param size indicating how many buckets should be returned */ @@ -233,7 +214,7 @@ public CategorizeTextAggregationBuilder size(int size) { } /** - * @param shardSize - indicating the number of buckets each shard + * @param shardSize - indicating the number of buckets each shard * will return to the coordinating node (the node that coordinates the * search execution). The higher the shard size is, the more accurate the * results are. @@ -293,18 +274,24 @@ protected CategorizeTextAggregationBuilder( super(clone, factoriesBuilder, metadata); this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(clone.bucketCountThresholds); this.fieldName = clone.fieldName; - this.maxUniqueTokens = clone.maxUniqueTokens; - this.maxMatchedTokens = clone.maxMatchedTokens; this.similarityThreshold = clone.similarityThreshold; this.categorizationAnalyzerConfig = clone.categorizationAnalyzerConfig; } @Override protected void doWriteTo(StreamOutput out) throws IOException { + // Disallow this aggregation in mixed version clusters that cross the algorithm change boundary. 
+ if (out.getVersion().before(ALGORITHM_CHANGED_VERSION)) { + throw new ElasticsearchException( + "[" + + NAME + + "] aggregation cannot be used in a cluster where some nodes have version [" + + ALGORITHM_CHANGED_VERSION + + "] or higher and others have a version before this" + ); + } bucketCountThresholds.writeTo(out); out.writeString(fieldName); - out.writeVInt(maxUniqueTokens); - out.writeVInt(maxMatchedTokens); out.writeVInt(similarityThreshold); out.writeOptionalWriteable(categorizationAnalyzerConfig); } @@ -318,8 +305,6 @@ protected AggregatorFactory doBuild( return new CategorizeTextAggregatorFactory( name, fieldName, - maxUniqueTokens, - maxMatchedTokens, similarityThreshold, bucketCountThresholds, categorizationAnalyzerConfig, @@ -335,8 +320,6 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param builder.startObject(); bucketCountThresholds.toXContent(builder, params); builder.field(FIELD_NAME.getPreferredName(), fieldName); - builder.field(MAX_UNIQUE_TOKENS.getPreferredName(), maxUniqueTokens); - builder.field(MAX_MATCHED_TOKENS.getPreferredName(), maxMatchedTokens); builder.field(SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold); if (categorizationAnalyzerConfig != null) { categorizationAnalyzerConfig.toXContent(builder, params); @@ -362,6 +345,10 @@ public String getType() { @Override public Version getMinimalSupportedVersion() { - return Version.V_7_16_0; + // This isn't strictly true, as the categorize_text aggregation has existed since 7.16. + // However, the implementation completely changed in 8.3, so it's best that if the + // coordinating node is on 8.3 or above then it should refuse to use this aggregation + // until the older nodes are upgraded. + return ALGORITHM_CHANGED_VERSION; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index 39d3a0f83f77d..6bc3454c332c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -9,10 +9,8 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.PriorityQueue; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.core.Releasables; @@ -29,19 +27,16 @@ import org.elasticsearch.search.aggregations.support.AggregationContext; import org.elasticsearch.search.lookup.SourceLookup; import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation.Bucket; import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Optional; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder.MAX_MAX_MATCHED_TOKENS; - public class CategorizeTextAggregator extends 
DeferableBucketAggregator { private final TermsAggregator.BucketCountThresholds bucketCountThresholds; @@ -49,12 +44,11 @@ public class CategorizeTextAggregator extends DeferableBucketAggregator { private final MappedFieldType fieldType; private final CategorizationAnalyzer analyzer; private final String sourceFieldName; - private ObjectArray categorizers; - private final int maxUniqueTokens; - private final int maxMatchTokens; + private ObjectArray categorizers; private final int similarityThreshold; private final LongKeyedBucketOrds bucketOrds; private final CategorizationBytesRefHash bytesRefHash; + private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary; protected CategorizeTextAggregator( String name, @@ -64,8 +58,6 @@ protected CategorizeTextAggregator( String sourceFieldName, MappedFieldType fieldType, TermsAggregator.BucketCountThresholds bucketCountThresholds, - int maxUniqueTokens, - int maxMatchTokens, int similarityThreshold, CategorizationAnalyzerConfig categorizationAnalyzerConfig, Map metadata @@ -75,7 +67,7 @@ protected CategorizeTextAggregator( this.sourceFieldName = sourceFieldName; this.fieldType = fieldType; CategorizationAnalyzerConfig analyzerConfig = Optional.ofNullable(categorizationAnalyzerConfig) - .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(Collections.emptyList())); + .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(List.of())); final String analyzerName = analyzerConfig.getAnalyzer(); if (analyzerName != null) { Analyzer globalAnalyzer = context.getNamedAnalyzer(analyzerName); @@ -96,12 +88,12 @@ protected CategorizeTextAggregator( ); } this.categorizers = bigArrays().newObjectArray(1); - this.maxUniqueTokens = maxUniqueTokens; - this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); this.bucketCountThresholds = bucketCountThresholds; this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, bigArrays())); + // TODO: make it possible to choose a language instead of or as well as English for the part-of-speech dictionary + this.partOfSpeechDictionary = CategorizationPartOfSpeechDictionary.getInstance(); } @Override @@ -112,42 +104,26 @@ protected void doClose() { @Override public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOException { - InternalCategorizationAggregation.Bucket[][] topBucketsPerOrd = - new InternalCategorizationAggregation.Bucket[ordsToCollect.length][]; + Bucket[][] topBucketsPerOrd = new Bucket[ordsToCollect.length][]; for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { - final CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); - if (categorizationTokenTree == null) { - topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[0]; + final TokenListCategorizer categorizer = categorizers.get(ordsToCollect[ordIdx]); + if (categorizer == null) { + topBucketsPerOrd[ordIdx] = new Bucket[0]; continue; } int size = (int) Math.min(bucketOrds.bucketsInOrd(ordIdx), bucketCountThresholds.getShardSize()); - PriorityQueue ordered = - new InternalCategorizationAggregation.BucketCountPriorityQueue(size); - for (InternalCategorizationAggregation.Bucket bucket : categorizationTokenTree.toIntermediateBuckets(bytesRefHash)) { - if (bucket.docCount < bucketCountThresholds.getShardMinDocCount()) { - continue; - } - ordered.insertWithOverflow(bucket); - } - 
topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[ordered.size()]; - for (int i = ordered.size() - 1; i >= 0; --i) { - topBucketsPerOrd[ordIdx][i] = ordered.pop(); - } + topBucketsPerOrd[ordIdx] = categorizer.toOrderedBuckets(size); } - buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, a) -> b.aggregations = a); + buildSubAggsForAllBuckets(topBucketsPerOrd, Bucket::getBucketOrd, Bucket::setAggregations); InternalAggregation[] results = new InternalAggregation[ordsToCollect.length]; for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { - InternalCategorizationAggregation.Bucket[] bucketArray = topBucketsPerOrd[ordIdx]; - Arrays.sort(bucketArray, Comparator.naturalOrder()); results[ordIdx] = new InternalCategorizationAggregation( name, bucketCountThresholds.getRequiredSize(), bucketCountThresholds.getMinDocCount(), - maxUniqueTokens, - maxMatchTokens, similarityThreshold, metadata(), - Arrays.asList(bucketArray) + Arrays.asList(topBucketsPerOrd[ordIdx]) ); } return results; @@ -159,8 +135,6 @@ public InternalAggregation buildEmptyAggregation() { name, bucketCountThresholds.getRequiredSize(), bucketCountThresholds.getMinDocCount(), - maxUniqueTokens, - maxMatchTokens, similarityThreshold, metadata() ); @@ -172,66 +146,50 @@ protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucket @Override public void collect(int doc, long owningBucketOrd) throws IOException { categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); - CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); + TokenListCategorizer categorizer = categorizers.get(owningBucketOrd); if (categorizer == null) { - categorizer = new CategorizationTokenTree(maxUniqueTokens, maxMatchTokens, similarityThreshold); + categorizer = new TokenListCategorizer(bytesRefHash, partOfSpeechDictionary, (float) similarityThreshold / 100.0f); addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); categorizers.set(owningBucketOrd, categorizer); } collectFromSource(doc, owningBucketOrd, categorizer); } - private void collectFromSource(int doc, long owningBucketOrd, CategorizationTokenTree categorizer) throws IOException { + private void collectFromSource(int doc, long owningBucketOrd, TokenListCategorizer categorizer) throws IOException { sourceLookup.setSegmentAndDocument(ctx, doc); Iterator itr = sourceLookup.extractRawValuesWithoutCaching(sourceFieldName).stream().map(obj -> { - if (obj == null) { - return null; - } if (obj instanceof BytesRef) { return fieldType.valueForDisplay(obj).toString(); } - return obj.toString(); + return (obj == null) ? 
null : obj.toString(); }).iterator(); while (itr.hasNext()) { - TokenStream ts = analyzer.tokenStream(fieldType.name(), itr.next()); - processTokenStream(owningBucketOrd, ts, doc, categorizer); + String string = itr.next(); + try (TokenStream ts = analyzer.tokenStream(fieldType.name(), string)) { + processTokenStream(owningBucketOrd, ts, string.length(), doc, categorizer); + } } } - private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc, CategorizationTokenTree categorizer) - throws IOException { - ArrayList tokens = new ArrayList<>(); - try (ts) { - CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); - ts.reset(); - int numTokens = 0; - // Only categorize the first MAX_MAX_MATCHED_TOKENS tokens - while (ts.incrementToken() && numTokens < MAX_MAX_MATCHED_TOKENS) { - if (termAtt.length() > 0) { - tokens.add(bytesRefHash.put(new BytesRef(termAtt))); - numTokens++; - } - } - if (tokens.isEmpty()) { - return; - } - } + private void processTokenStream( + long owningBucketOrd, + TokenStream ts, + int unfilteredLength, + int doc, + TokenListCategorizer categorizer + ) throws IOException { long previousSize = categorizer.ramBytesUsed(); - TextCategorization lg = categorizer.parseTokens( - tokens.stream().mapToInt(Integer::valueOf).toArray(), - docCountProvider.getDocCount(doc) - ); - long newSize = categorizer.ramBytesUsed(); - if (newSize - previousSize > 0) { - addRequestCircuitBreakerBytes(newSize - previousSize); + TokenListCategory category = categorizer.computeCategory(ts, unfilteredLength, docCountProvider.getDocCount(doc)); + if (category == null) { + return; } - - long bucketOrd = bucketOrds.add(owningBucketOrd, lg.getId()); + long sizeDiff = categorizer.ramBytesUsed() - previousSize; + addRequestCircuitBreakerBytes(sizeDiff); + long bucketOrd = bucketOrds.add(owningBucketOrd, category.getId()); if (bucketOrd < 0) { // already seen - bucketOrd = -1 - bucketOrd; - collectExistingBucket(sub, doc, bucketOrd); + collectExistingBucket(sub, doc, -1 - bucketOrd); } else { - lg.bucketOrd = bucketOrd; + category.setBucketOrd(bucketOrd); collectBucket(sub, doc, bucketOrd); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java index ca56741796ae3..2d67842001bd1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.aggs.categorization; +import org.elasticsearch.index.mapper.KeywordScriptFieldType; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.search.aggregations.Aggregator; @@ -26,9 +27,6 @@ public class CategorizeTextAggregatorFactory extends AggregatorFactory { private final MappedFieldType fieldType; - private final String indexedFieldName; - private final int maxUniqueTokens; - private final int maxMatchTokens; private final int similarityThreshold; private final CategorizationAnalyzerConfig categorizationAnalyzerConfig; private final TermsAggregator.BucketCountThresholds bucketCountThresholds; @@ -36,8 +34,6 @@ public class CategorizeTextAggregatorFactory extends AggregatorFactory { public CategorizeTextAggregatorFactory( String name, 
String fieldName, - int maxUniqueTokens, - int maxMatchTokens, int similarityThreshold, TermsAggregator.BucketCountThresholds bucketCountThresholds, CategorizationAnalyzerConfig categorizationAnalyzerConfig, @@ -48,13 +44,6 @@ public CategorizeTextAggregatorFactory( ) throws IOException { super(name, context, parent, subFactoriesBuilder, metadata); this.fieldType = context.getFieldType(fieldName); - if (fieldType != null) { - this.indexedFieldName = fieldType.name(); - } else { - this.indexedFieldName = null; - } - this.maxUniqueTokens = maxUniqueTokens; - this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; this.bucketCountThresholds = bucketCountThresholds; @@ -65,8 +54,6 @@ protected Aggregator createUnmapped(Aggregator parent, Map metad name, bucketCountThresholds.getRequiredSize(), bucketCountThresholds.getMinDocCount(), - maxUniqueTokens, - maxMatchTokens, similarityThreshold, metadata ); @@ -84,43 +71,47 @@ protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound car if (fieldType == null) { return createUnmapped(parent, metadata); } - // TODO add support for Keyword && KeywordScriptFieldType + // Most of the text and keyword family of fields use a bespoke TextSearchInfo that doesn't match any + // of the static final ones created in the TextSearchInfo class definition. KeywordScriptFieldType is + // the exception that we do want to support, so we need to check for that separately. (It's not a + // complete disaster if we end up analyzing an inappropriate field, for example if the user has added + // a new field type via a plugin that also creates a bespoke TextSearchInfo member - it will just get + // converted to a string and then likely the analyzer won't create any tokens, so the categorizer + // will see an empty token list.) if (fieldType.getTextSearchInfo() == TextSearchInfo.NONE - || fieldType.getTextSearchInfo() == TextSearchInfo.SIMPLE_MATCH_WITHOUT_TERMS) { + || (fieldType.getTextSearchInfo() == TextSearchInfo.SIMPLE_MATCH_WITHOUT_TERMS + && fieldType instanceof KeywordScriptFieldType == false)) { throw new IllegalArgumentException( "categorize_text agg [" + name - + "] only works on analyzable text fields. Cannot aggregate field type [" + + "] only works on text and keyword fields. Cannot aggregate field type [" + fieldType.name() + "] via [" + fieldType.getClass().getSimpleName() + "]" ); } - TermsAggregator.BucketCountThresholds thresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); - if (thresholds.getShardSize() == CategorizeTextAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { + TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); + if (bucketCountThresholds.getShardSize() == CategorizeTextAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { // The user has not made a shardSize selection. Use default // heuristic to avoid any wrong-ranking caused by distributed // counting // TODO significant text does a 2x here, should we as well? 
- thresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(thresholds.getRequiredSize())); + bucketCountThresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(bucketCountThresholds.getRequiredSize())); } - thresholds.ensureValidity(); + bucketCountThresholds.ensureValidity(); return new CategorizeTextAggregator( name, factories, context, parent, - indexedFieldName, + fieldType.name(), fieldType, - thresholds, - maxUniqueTokens, - maxMatchTokens, + bucketCountThresholds, similarityThreshold, categorizationAnalyzerConfig, metadata ); } - } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index 4f6b33c6ba803..2bc9a7fca5949 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -8,12 +8,10 @@ package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.BytesRefHash; -import org.elasticsearch.search.aggregations.AggregationExecutionException; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.InternalAggregation; @@ -23,143 +21,58 @@ import org.elasticsearch.search.aggregations.support.SamplingContext; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.job.results.CategoryDefinition; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_REF; - public class InternalCategorizationAggregation extends InternalMultiBucketAggregation< InternalCategorizationAggregation, InternalCategorizationAggregation.Bucket> { - // Carries state allowing for delayed reduction of the bucket - // This allows us to keep from accidentally calling "reduce" on the sub-aggs more than once - private static class DelayedCategorizationBucket { - private final BucketKey key; - private long docCount; - private final List toReduce; - - DelayedCategorizationBucket(BucketKey key, List toReduce, long docCount) { - this.key = key; - this.toReduce = new ArrayList<>(toReduce); - this.docCount = docCount; - } - - public long getDocCount() { - return docCount; - } - - public Bucket reduce(BucketKey bucketKey, AggregationReduceContext reduceContext) { - List innerAggs = new ArrayList<>(toReduce.size()); - long totalDocCount = 0; - for (Bucket bucket : toReduce) { - innerAggs.add(bucket.aggregations); - totalDocCount += bucket.docCount; - } - return new Bucket(bucketKey, totalDocCount, InternalAggregations.reduce(innerAggs, reduceContext)); - } - - public DelayedCategorizationBucket add(Bucket bucket) { - 
this.docCount += bucket.docCount; - this.toReduce.add(bucket); - return this; - } - - public DelayedCategorizationBucket add(DelayedCategorizationBucket bucket) { - this.docCount += bucket.docCount; - this.toReduce.addAll(bucket.toReduce); - return this; - } - } - - static class BucketCountPriorityQueue extends PriorityQueue { - BucketCountPriorityQueue(int size) { - super(size); - } - - @Override - protected boolean lessThan(Bucket a, Bucket b) { - return a.docCount < b.docCount; - } - } - - static class BucketKey implements ToXContentFragment, Writeable, Comparable { + static class BucketKey implements ToXContentFragment, Comparable { private final BytesRef[] key; - static BucketKey withCollapsedWildcards(BytesRef[] key) { - if (key.length <= 1) { - return new BucketKey(key); - } - List collapsedWildCards = new ArrayList<>(); - boolean previousTokenWildCard = false; - for (BytesRef token : key) { - if (token.equals(WILD_CARD_REF)) { - if (previousTokenWildCard == false) { - previousTokenWildCard = true; - collapsedWildCards.add(WILD_CARD_REF); - } - } else { - previousTokenWildCard = false; - collapsedWildCards.add(token); - } - } - if (collapsedWildCards.size() == key.length) { - return new BucketKey(key); - } - return new BucketKey(collapsedWildCards.toArray(BytesRef[]::new)); + BucketKey(SerializableTokenListCategory serializableCategory) { + this.key = serializableCategory.getKeyTokens(); } BucketKey(BytesRef[] key) { this.key = key; } - BucketKey(StreamInput in) throws IOException { - key = in.readArray(StreamInput::readBytesRef, BytesRef[]::new); - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.value(asString()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeArray(StreamOutput::writeBytesRef, key); - } - - public String asString() { - StringBuilder builder = new StringBuilder(); - for (int i = 0; i < key.length - 1; i++) { - builder.append(key[i].utf8ToString()).append(" "); - } - builder.append(key[key.length - 1].utf8ToString()); - return builder.toString(); + return builder.value(toString()); } @Override public String toString() { - return asString(); + return Arrays.stream(key).map(BytesRef::utf8ToString).collect(Collectors.joining(" ")); } @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BucketKey bucketKey = (BucketKey) o; - return Arrays.equals(key, bucketKey.key); + public int hashCode() { + return Arrays.hashCode(key); } @Override - public int hashCode() { - return Arrays.hashCode(key); + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + BucketKey that = (BucketKey) other; + return Arrays.equals(this.key, that.key); } public BytesRef[] keyAsTokens() { @@ -170,45 +83,69 @@ public BytesRef[] keyAsTokens() { public int compareTo(BucketKey o) { return Arrays.compare(key, o.key); } - } public static class Bucket extends InternalMultiBucketAggregation.InternalBucket implements MultiBucketsAggregation.Bucket, Comparable { - // Used on the shard level to keep track of sub aggregations - long bucketOrd; - final BucketKey key; - final long docCount; - InternalAggregations aggregations; + private final SerializableTokenListCategory serializableCategory; + private final BucketKey key; + private long bucketOrd; + private InternalAggregations aggregations; + + public 
Bucket(SerializableTokenListCategory serializableCategory, long bucketOrd) { + this(serializableCategory, bucketOrd, InternalAggregations.EMPTY); + } - public Bucket(BucketKey key, long docCount, InternalAggregations aggregations) { - this.key = key; - this.docCount = docCount; - this.aggregations = aggregations; + public Bucket(SerializableTokenListCategory serializableCategory, long bucketOrd, InternalAggregations aggregations) { + this.serializableCategory = serializableCategory; + this.key = new BucketKey(serializableCategory); + this.bucketOrd = bucketOrd; + this.aggregations = Objects.requireNonNull(aggregations); } public Bucket(StreamInput in) throws IOException { - key = new BucketKey(in); - docCount = in.readVLong(); + // Disallow this aggregation in mixed version clusters that cross the algorithm change boundary. + if (in.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) { + throw new ElasticsearchException( + "[" + + CategorizeTextAggregationBuilder.NAME + + "] aggregation cannot be used in a cluster where some nodes have version [" + + CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION + + "] or higher and others have a version before this" + ); + } + serializableCategory = new SerializableTokenListCategory(in); + key = new BucketKey(serializableCategory); + bucketOrd = -1; aggregations = InternalAggregations.readFrom(in); } @Override public void writeTo(StreamOutput out) throws IOException { - key.writeTo(out); - out.writeVLong(getDocCount()); + // Disallow this aggregation in mixed version clusters that cross the algorithm change boundary. + if (out.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) { + throw new ElasticsearchException( + "[" + + CategorizeTextAggregationBuilder.NAME + + "] aggregation cannot be used in a cluster where some nodes have version [" + + CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION + + "] or higher and others have a version before this" + ); + } + serializableCategory.writeTo(out); aggregations.writeTo(out); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CommonFields.DOC_COUNT.getPreferredName(), docCount); + builder.field(CommonFields.DOC_COUNT.getPreferredName(), serializableCategory.getNumMatches()); builder.field(CommonFields.KEY.getPreferredName()); key.toXContent(builder, params); + builder.field(CategoryDefinition.MAX_MATCHING_LENGTH.getPreferredName(), serializableCategory.maxMatchingStringLen()); aggregations.toXContentInternal(builder, params); builder.endObject(); return builder; @@ -225,12 +162,12 @@ public Object getKey() { @Override public String getKeyAsString() { - return key.asString(); + return key.toString(); } @Override public long getDocCount() { - return docCount; + return serializableCategory.getNumMatches(); } @Override @@ -238,22 +175,37 @@ public Aggregations getAggregations() { return aggregations; } + void setAggregations(InternalAggregations aggregations) { + this.aggregations = aggregations; + } + + long getBucketOrd() { + return bucketOrd; + } + + SerializableTokenListCategory getSerializableCategory() { + return serializableCategory; + } + @Override public String toString() { - return "Bucket{" + "key=" + getKeyAsString() + ", docCount=" + docCount + ", aggregations=" + aggregations.asMap() + "}\n"; + return "Bucket{key=" + + getKeyAsString() + + ", docCount=" + + serializableCategory.getNumMatches() + + ", aggregations=" + + 
aggregations.asMap() + + "}\n"; } @Override - public int compareTo(Bucket o) { - return key.compareTo(o.key); + public int compareTo(Bucket other) { + return Long.signum(this.serializableCategory.getNumMatches() - other.serializableCategory.getNumMatches()); } - } private final List buckets; - private final int maxUniqueTokens; private final int similarityThreshold; - private final int maxMatchTokens; private final int requiredSize; private final long minDocCount; @@ -261,28 +213,22 @@ protected InternalCategorizationAggregation( String name, int requiredSize, long minDocCount, - int maxUniqueTokens, - int maxMatchTokens, int similarityThreshold, Map metadata ) { - this(name, requiredSize, minDocCount, maxUniqueTokens, maxMatchTokens, similarityThreshold, metadata, new ArrayList<>()); + this(name, requiredSize, minDocCount, similarityThreshold, metadata, new ArrayList<>()); } protected InternalCategorizationAggregation( String name, int requiredSize, long minDocCount, - int maxUniqueTokens, - int maxMatchTokens, int similarityThreshold, Map metadata, List buckets ) { super(name, metadata); this.buckets = buckets; - this.maxUniqueTokens = maxUniqueTokens; - this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; this.minDocCount = minDocCount; this.requiredSize = requiredSize; @@ -290,8 +236,16 @@ protected InternalCategorizationAggregation( public InternalCategorizationAggregation(StreamInput in) throws IOException { super(in); - this.maxUniqueTokens = in.readVInt(); - this.maxMatchTokens = in.readVInt(); + // Disallow this aggregation in mixed version clusters that cross the algorithm change boundary. + if (in.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) { + throw new ElasticsearchException( + "[" + + CategorizeTextAggregationBuilder.NAME + + "] aggregation cannot be used in a cluster where some nodes have version [" + + CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION + + "] or higher and others have a version before this" + ); + } this.similarityThreshold = in.readVInt(); this.buckets = in.readList(Bucket::new); this.requiredSize = readSize(in); @@ -300,8 +254,16 @@ public InternalCategorizationAggregation(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeVInt(maxUniqueTokens); - out.writeVInt(maxMatchTokens); + // Disallow this aggregation in mixed version clusters that cross the algorithm change boundary. 
+ if (out.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) { + throw new ElasticsearchException( + "[" + + CategorizeTextAggregationBuilder.NAME + + "] aggregation cannot be used in a cluster where some nodes have version [" + + CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION + + "] or higher and others have a version before this" + ); + } out.writeVInt(similarityThreshold); out.writeList(buckets); writeSize(requiredSize, out); @@ -319,27 +281,18 @@ public XContentBuilder doXContentBody(XContentBuilder builder, Params params) th } @Override - public InternalCategorizationAggregation create(List bucketList) { - return new InternalCategorizationAggregation( - name, - requiredSize, - minDocCount, - maxUniqueTokens, - maxMatchTokens, - similarityThreshold, - super.metadata, - bucketList - ); + public InternalCategorizationAggregation create(List buckets) { + return new InternalCategorizationAggregation(name, requiredSize, minDocCount, similarityThreshold, super.metadata, buckets); } @Override public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { - return new Bucket(prototype.key, prototype.docCount, aggregations); + return new Bucket(prototype.serializableCategory, prototype.bucketOrd, aggregations); } @Override protected Bucket reduceBucket(List buckets, AggregationReduceContext context) { - throw new IllegalArgumentException("For optimization purposes, typical bucket path is not supported"); + throw new UnsupportedOperationException("For optimization purposes, typical bucket path is not supported"); } @Override @@ -355,79 +308,38 @@ public String getWriteableName() { @Override public InternalAggregation reduce(List aggregations, AggregationReduceContext reduceContext) { try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays()))) { - CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree( - maxUniqueTokens, - maxMatchTokens, - similarityThreshold + TokenListCategorizer categorizer = new TokenListCategorizer( + hash, + null, // part-of-speech dictionary is not needed for the reduce phase as weights are already decided + (float) similarityThreshold / 100.0f ); - // TODO: Could we do a merge sort similar to terms? 
- // It would require us returning partial reductions sorted by key, not by doc_count - // First, make sure we have all the counts for equal categorizations - Map reduced = new HashMap<>(); + // Merge all the categories into the newly created empty categorizer to combine them for (InternalAggregation aggregation : aggregations) { InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; for (Bucket bucket : categorizationAggregation.buckets) { - reduced.computeIfAbsent(bucket.key, key -> new DelayedCategorizationBucket(key, new ArrayList<>(1), 0L)).add(bucket); - } - } - - reduced.values().stream().sorted(Comparator.comparing(DelayedCategorizationBucket::getDocCount).reversed()).forEach(bucket -> - // Parse tokens takes document count into account and merging on smallest groups - categorizationTokenTree.parseTokens(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount)); - categorizationTokenTree.mergeSmallestChildren(); - Map mergedBuckets = new HashMap<>(); - for (DelayedCategorizationBucket delayedBucket : reduced.values()) { - TextCategorization group = categorizationTokenTree.parseTokensConst(hash.getIds(delayedBucket.key.keyAsTokens())) - .orElseThrow( - () -> new AggregationExecutionException( - "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" - ) - ); - BytesRef[] categoryTokens = hash.getDeeps(group.getCategorization()); - - BucketKey key = reduceContext.isFinalReduce() - ? BucketKey.withCollapsedWildcards(categoryTokens) - : new BucketKey(categoryTokens); - mergedBuckets.computeIfAbsent( - key, - k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L) - ).add(delayedBucket); - } - - final int size = reduceContext.isFinalReduce() == false ? mergedBuckets.size() : Math.min(requiredSize, mergedBuckets.size()); - final PriorityQueue pq = new BucketCountPriorityQueue(size); - for (Map.Entry keyAndBuckets : mergedBuckets.entrySet()) { - final BucketKey key = keyAndBuckets.getKey(); - DelayedCategorizationBucket bucket = keyAndBuckets.getValue(); - Bucket newBucket = bucket.reduce(key, reduceContext); - if ((newBucket.docCount >= minDocCount) || reduceContext.isFinalReduce() == false) { - Bucket removed = pq.insertWithOverflow(newBucket); - if (removed == null) { - reduceContext.consumeBucketsAndMaybeBreak(1); - } else { - reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed)); + categorizer.mergeWireCategory(bucket.serializableCategory).addSubAggs((InternalAggregations) bucket.getAggregations()); + if (reduceContext.isCanceled().get()) { + break; } - } else { - reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(newBucket)); } } - Bucket[] bucketList = new Bucket[pq.size()]; - for (int i = pq.size() - 1; i >= 0; i--) { - bucketList[i] = pq.pop(); - } + final int size = reduceContext.isFinalReduce() + ? Math.min(requiredSize, categorizer.getCategoryCount()) + : categorizer.getCategoryCount(); + Bucket[] mergedBuckets = categorizer.toOrderedBuckets(size, reduceContext.isFinalReduce() ? minDocCount : 0, reduceContext); + // TODO: not sure if this next line is correct - if we discarded some categories due to size or minDocCount is this handled? 
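Context for the TODO above: by this point every shard-level category has been folded into the fresh `TokenListCategorizer`, and `toOrderedBuckets` has already applied the `size` cap and, on the final reduce, the `minDocCount` floor, so the `consumeBucketsAndMaybeBreak` call that follows counts only the surviving buckets against the search's bucket limit. The selection and final ordering amount to the following self-contained model (the `Category` record and values are illustrative; the real ordering lives in `toOrderedBuckets` plus the `Arrays.sort` shown below):

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

// Toy model of the final reduce step: filter by minDocCount, order by descending
// doc count with ties broken by key, keep at most `size` buckets.
public class ReduceOrderingDemo {
    record Category(String key, long docCount) {}

    static List<Category> topCategories(List<Category> merged, int size, long minDocCount) {
        return merged.stream()
            .filter(c -> c.docCount() >= minDocCount)                 // minDocCount floor
            .sorted(Comparator.comparingLong(Category::docCount).reversed()
                .thenComparing(Category::key))                        // ties broken by key
            .limit(size)                                              // size cap
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Category> merged = List.of(
            new Category("Node started", 40),
            new Category("Failed to connect", 7),
            new Category("Disk watermark exceeded", 7),
            new Category("Shutting down", 1)
        );
        System.out.println(topCategories(merged, 3, 2));
        // prints: [Category[key=Node started, docCount=40],
        //          Category[key=Disk watermark exceeded, docCount=7],
        //          Category[key=Failed to connect, docCount=7]]
    }
}
```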
+ reduceContext.consumeBucketsAndMaybeBreak(mergedBuckets.length); // Keep the top categories top, but then sort by the key for those with duplicate counts if (reduceContext.isFinalReduce()) { - Arrays.sort(bucketList, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); + Arrays.sort(mergedBuckets, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); } return new InternalCategorizationAggregation( name, requiredSize, minDocCount, - maxUniqueTokens, - maxMatchTokens, similarityThreshold, metadata, - Arrays.asList(bucketList) + Arrays.asList(mergedBuckets) ); } } @@ -438,15 +350,13 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) { name, requiredSize, minDocCount, - maxUniqueTokens, - maxMatchTokens, similarityThreshold, metadata, buckets.stream() .map( b -> new Bucket( - b.key, - samplingContext.scaleUp(b.docCount), + new SerializableTokenListCategory(b.getSerializableCategory(), samplingContext.scaleUp(b.getDocCount())), + b.getBucketOrd(), InternalAggregations.finalizeSampling(b.aggregations, samplingContext) ) ) @@ -454,18 +364,10 @@ public InternalAggregation finalizeSampling(SamplingContext samplingContext) { ); } - public int getMaxUniqueTokens() { - return maxUniqueTokens; - } - public int getSimilarityThreshold() { return similarityThreshold; } - public int getMaxMatchTokens() { - return maxMatchTokens; - } - public int getRequiredSize() { return requiredSize; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/SerializableTokenListCategory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java similarity index 95% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/SerializableTokenListCategory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java index 3700430c8b195..f81ee2bf3efdf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/SerializableTokenListCategory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.aggs.categorization2; +package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.StreamInput; @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; /** * {@link TokenListCategory} cannot be serialized between nodes as its token IDs @@ -209,4 +210,13 @@ public boolean equals(Object other) { && this.origUniqueTokenWeight == that.origUniqueTokenWeight && this.numMatches == that.numMatches; } + + @Override + public String toString() { + return Arrays.stream(keyTokenIndexes) + .mapToObj(index -> baseTokens[index].utf8ToString()) + .collect(Collectors.joining(", ", "Category with key tokens [", "]")) + Arrays.stream(baseTokens) + .map(BytesRef::utf8ToString) + .collect(Collectors.joining(", ", " and base tokens [", "]")); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java deleted file mode 100644 index 76ec8b59487f4..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.aggs.categorization; - -import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.RamUsageEstimator; - -import java.util.Arrays; - -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; - -/** - * A text categorization group that provides methods for: - * - calculating similarity between it and a new text - * - expanding the existing categorization by adding a new array of tokens - */ -class TextCategorization implements Accountable { - - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TextCategorization.class); - private final long id; - private final int[] categorization; - private final long[] tokenCounts; - private long count; - - // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations - long bucketOrd; - - TextCategorization(int[] tokenIds, long count, long id) { - this.id = id; - this.categorization = tokenIds; - this.count = count; - this.tokenCounts = new long[tokenIds.length]; - Arrays.fill(this.tokenCounts, count); - } - - public long getId() { - return id; - } - - int[] getCategorization() { - return categorization; - } - - public long getCount() { - return count; - } - - Similarity calculateSimilarity(int[] tokenIds) { - assert tokenIds.length == this.categorization.length; - int eqParams = 0; - long tokenCount = 0; - long tokensKept = 0; - for (int i = 0; i < tokenIds.length; i++) { - if (tokenIds[i] == this.categorization[i]) { - tokensKept += tokenCounts[i]; - tokenCount += tokenCounts[i]; - } else if (this.categorization[i] == WILD_CARD_ID) { - eqParams++; - } else { - tokenCount += tokenCounts[i]; - } - } - return new Similarity((double) tokensKept / tokenCount, eqParams); - } - - void addTokens(int[] tokenIds, long docCount) { - assert tokenIds.length == this.categorization.length; - for (int i = 0; i < tokenIds.length; i++) { - if 
(tokenIds[i] != this.categorization[i]) { - this.categorization[i] = WILD_CARD_ID; - } else { - tokenCounts[i] += docCount; - } - } - this.count += docCount; - } - - @Override - public long ramBytesUsed() { - return SHALLOW_SIZE + RamUsageEstimator.sizeOf(categorization) // categorization token Ids - + RamUsageEstimator.sizeOf(tokenCounts); // tokenCounts - } - - static class Similarity implements Comparable { - private final double similarity; - private final int wildCardCount; - - private Similarity(double similarity, int wildCardCount) { - this.similarity = similarity; - this.wildCardCount = wildCardCount; - } - - @Override - public int compareTo(Similarity o) { - int d = Double.compare(similarity, o.similarity); - if (d != 0) { - return d; - } - return Integer.compare(wildCardCount, o.wildCardCount); - } - - public double getSimilarity() { - return similarity; - } - - public int getWildCardCount() { - return wildCardCount; - } - } - -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategorizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java similarity index 83% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategorizer.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java index e22ff15d23057..e1f2404ee56b5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategorizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.aggs.categorization2; +package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -16,7 +16,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.InternalAggregations; -import org.elasticsearch.xpack.ml.aggs.categorization2.TokenListCategory.TokenAndWeight; +import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -127,11 +127,15 @@ public TokenListCategory computeCategory(List weightedTokenIds, // Although this can be done using stream() and collect() with a grouping // collector, profiling shows it's faster to use a handcrafted loop. int workWeight = 0; + int minReweightedTotalWeight = 0; + int maxReweightedTotalWeight = 0; SortedMap groupingMap = new TreeMap<>(); for (TokenAndWeight weightedTokenId : weightedTokenIds) { int tokenId = weightedTokenId.getTokenId(); int weight = weightedTokenId.getWeight(); workWeight += weight; + minReweightedTotalWeight += WeightCalculator.getMinMatchingWeight(weight); + maxReweightedTotalWeight += WeightCalculator.getMaxMatchingWeight(weight); // There's a tradeoff here: the map value duplicates the map key. 
But // this means that in the case where a token only occurs once we can // reuse the original TokenAndWeight object instead of creating a new @@ -143,34 +147,58 @@ public TokenListCategory computeCategory(List weightedTokenIds, } List workTokenUniqueIds = new ArrayList<>(groupingMap.values()); - return computeCategory(weightedTokenIds, workTokenUniqueIds, workWeight, unfilteredStringLen, unfilteredStringLen, numDocs); + return computeCategory( + weightedTokenIds, + workTokenUniqueIds, + workWeight, + minReweightedTotalWeight, + maxReweightedTotalWeight, + unfilteredStringLen, + unfilteredStringLen, + numDocs + ); } public TokenListCategory mergeWireCategory(SerializableTokenListCategory serializableCategory) { + int sizeBefore = categoriesByNumMatches.size(); TokenListCategory foreignCategory = new TokenListCategory(0, serializableCategory, bytesRefHash); - return computeCategory( + TokenListCategory mergedCategory = computeCategory( foreignCategory.getBaseWeightedTokenIds(), foreignCategory.getCommonUniqueTokenIds(), - foreignCategory.getCommonUniqueTokenWeight(), + foreignCategory.getBaseWeight(), + // These next two lines are crude approximations + // TODO: improve the calculation of this min and max + WeightCalculator.getMinMatchingWeight(foreignCategory.getBaseWeight()), + WeightCalculator.getMaxMatchingWeight(foreignCategory.getBaseWeight()), foreignCategory.getBaseUnfilteredLength(), foreignCategory.getMaxUnfilteredStringLength(), foreignCategory.getNumMatches() ); + if (logger.isDebugEnabled() && categoriesByNumMatches.size() == sizeBefore) { + logger.debug( + "Merged wire category [{}] into existing category to form [{}]", + serializableCategory, + new SerializableTokenListCategory(mergedCategory, bytesRefHash) + ); + } + return mergedCategory; } private synchronized TokenListCategory computeCategory( List weightedTokenIds, List workTokenUniqueIds, int workWeight, + int minReweightedTotalWeight, + int maxReweightedTotalWeight, int unfilteredStringLen, int maxUnfilteredStringLen, long numDocs ) { // Determine the minimum and maximum token weight that could possibly match the weight we've got. - int minWeight = minMatchingWeight(workWeight, lowerThreshold); - int maxWeight = maxMatchingWeight(workWeight, lowerThreshold); + int minWeight = minMatchingWeight(minReweightedTotalWeight, lowerThreshold); + int maxWeight = maxMatchingWeight(maxReweightedTotalWeight, lowerThreshold); // We search previous categories in descending order of the number of matches we've seen for them. int bestSoFarIndex = -1; @@ -193,16 +221,20 @@ private synchronized TokenListCategory computeCategory( if (matchesSearch == false) { // Quickly rule out wildly different token weights prior to doing the expensive similarity calculations. if (baseWeight < minWeight || baseWeight > maxWeight) { + assert baseTokenIds.equals(weightedTokenIds) == false + : "Min [" + minWeight + "] and/or max [" + maxWeight + "] weights calculated incorrectly " + baseTokenIds; continue; } // Rule out categories where adding the current string would unacceptably reduce the number of unique common tokens. 
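The `minWeight`/`maxWeight` bounds computed a few lines up exist so that the expensive edit-distance similarity is never attempted against categories whose total token weight is already too different to reach the threshold: each unit of weight difference costs at least one unit of edit distance. A standalone illustration of the arithmetic, using simplified stand-ins for `minMatchingWeight`/`maxMatchingWeight` (the real functions handle rounding and the zero-weight case differently):

```java
// If a candidate category's base weight falls outside [minBound, maxBound], even a
// perfect token overlap cannot reach the similarity threshold, so it is skipped.
public class WeightPruneDemo {
    static int minBound(int workWeight, float threshold) {
        return (int) Math.floor(workWeight * threshold);   // lightest candidate that could match
    }

    static int maxBound(int workWeight, float threshold) {
        return (int) Math.ceil(workWeight / threshold);    // heaviest candidate that could match
    }

    public static void main(String[] args) {
        int workWeight = 40;      // total weight of the incoming tokens
        float threshold = 0.75f;  // similarity_threshold of 75
        System.out.printf("skip candidates with base weight outside [%d, %d]%n",
            minBound(workWeight, threshold), maxBound(workWeight, threshold));
        // prints: skip candidates with base weight outside [30, 54]
    }
}
```

The bounds are fed `minReweightedTotalWeight`/`maxReweightedTotalWeight` rather than the raw work weight because a token's weight is context dependent: per the `WeightCalculator` constants in this change, an out-of-dictionary token weighs 1, a dictionary verb weighs 1 + 5·1 = 6, and that rises to 1 + 5·6 = 31 from the third consecutive dictionary word onwards, so the same token can legitimately carry a different weight in the string it is compared against.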
- int origUniqueTokenWeight = compCategory.getOrigUniqueTokenWeight(); - int commonUniqueTokenWeight = compCategory.getCommonUniqueTokenWeight(); int missingCommonTokenWeight = compCategory.missingCommonTokenWeight(workTokenUniqueIds); - float proportionOfOrig = (float) (commonUniqueTokenWeight - missingCommonTokenWeight) / (float) origUniqueTokenWeight; - if (proportionOfOrig < lowerThreshold) { - continue; + if (missingCommonTokenWeight > 0) { + int origUniqueTokenWeight = compCategory.getOrigUniqueTokenWeight(); + int commonUniqueTokenWeight = compCategory.getCommonUniqueTokenWeight(); + float proportionOfOrig = (float) (commonUniqueTokenWeight - missingCommonTokenWeight) / (float) origUniqueTokenWeight; + if (proportionOfOrig < lowerThreshold) { + continue; + } } } @@ -229,8 +261,8 @@ private synchronized TokenListCategory computeCategory( bestSoFarSimilarity = similarity; // Recalculate the minimum and maximum token counts that might produce a better match. - minWeight = minMatchingWeight(workWeight, similarity); - maxWeight = maxMatchingWeight(workWeight, similarity); + minWeight = minMatchingWeight(minReweightedTotalWeight, similarity); + maxWeight = maxMatchingWeight(maxReweightedTotalWeight, similarity); } } @@ -348,17 +380,15 @@ static int maxMatchingWeight(int weight, float threshold) { } /** - * Compute similarity between two vectors + * Compute the similarity between two vectors. */ static float similarity(List left, int leftWeight, List right, int rightWeight) { - float similarity = 1.0f; - int maxWeight = Math.max(leftWeight, rightWeight); if (maxWeight > 0) { - similarity = 1.0f - (float) TokenListSimilarityTester.weightedEditDistance(left, right) / (float) maxWeight; + return 1.0f - (float) TokenListSimilarityTester.weightedEditDistance(left, right) / (float) maxWeight; + } else { + return 1.0f; } - - return similarity; } public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) { @@ -402,6 +432,9 @@ static class WeightCalculator { private static final int MIN_DICTIONARY_LENGTH = 2; private static final int CONSECUTIVE_DICTIONARY_WORDS_FOR_EXTRA_WEIGHT = 3; + private static final int EXTRA_VERB_WEIGHT = 5; + private static final int EXTRA_OTHER_DICTIONARY_WEIGHT = 2; + private static final int ADJACENCY_BOOST_MULTIPLIER = 6; private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary; private int consecutiveHighWeights; @@ -428,10 +461,24 @@ int calculateWeight(String term) { consecutiveHighWeights = 0; return 1; } - ++consecutiveHighWeights; - int posWeight = (pos == CategorizationPartOfSpeechDictionary.PartOfSpeech.VERB) ? 6 : 3; - int adjacencyBoost = (consecutiveHighWeights >= CONSECUTIVE_DICTIONARY_WORDS_FOR_EXTRA_WEIGHT) ? 6 : 0; - return posWeight + adjacencyBoost; + int posWeight = (pos == CategorizationPartOfSpeechDictionary.PartOfSpeech.VERB) + ? EXTRA_VERB_WEIGHT + : EXTRA_OTHER_DICTIONARY_WEIGHT; + int adjacencyBoost = (++consecutiveHighWeights >= CONSECUTIVE_DICTIONARY_WORDS_FOR_EXTRA_WEIGHT) + ? ADJACENCY_BOOST_MULTIPLIER + : 1; + return 1 + (posWeight * adjacencyBoost); + } + + static int getMinMatchingWeight(int weight) { + return (weight <= ADJACENCY_BOOST_MULTIPLIER) ? weight : (1 + (weight - 1) / ADJACENCY_BOOST_MULTIPLIER); + } + + static int getMaxMatchingWeight(int weight) { + return (weight <= Math.min(EXTRA_VERB_WEIGHT, EXTRA_OTHER_DICTIONARY_WEIGHT) + || weight > Math.max(EXTRA_VERB_WEIGHT + 1, EXTRA_OTHER_DICTIONARY_WEIGHT + 1)) + ? 
weight + : (1 + (weight - 1) * ADJACENCY_BOOST_MULTIPLIER); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategory.java similarity index 92% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategory.java index 4a95b6950a5cc..fce206203cb45 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategory.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.aggs.categorization2; +package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.Accountable; import org.elasticsearch.common.util.set.Sets; @@ -152,6 +152,7 @@ public TokenListCategory( // As well as being unique, the unique token IDs must be in the base token IDs. assert uniqueTokenIds.stream().map(TokenAndWeight::getTokenId).distinct().count() == uniqueTokenIds.size() : "Unique token IDs contains duplicates " + uniqueTokenIds; + assert isSorted(uniqueTokenIds) : "Unique token IDs is not sorted " + uniqueTokenIds; assert Sets.intersection( uniqueTokenIds.stream().map(TokenAndWeight::getTokenId).collect(Collectors.toSet()), baseWeightedTokenIds.stream().map(TokenAndWeight::getTokenId).collect(Collectors.toSet()) @@ -206,6 +207,7 @@ public void addString( List uniqueTokenIds, long numMatches ) { + assert isSorted(uniqueTokenIds) : "Unique token IDs is not sorted " + uniqueTokenIds; assert numMatches > 0 : "number of matches must be positive, got " + numMatches; mergeWith(unfilteredLength, weightedTokenIds, 0, weightedTokenIds.size(), uniqueTokenIds, numMatches); } @@ -257,27 +259,33 @@ public List getSubAggs() { * both lists in parallel looking for differences. 
     */
    private void updateCommonUniqueTokenIds(List<TokenAndWeight> newUniqueTokenIds) {
+        assert commonUniqueTokenWeight == commonUniqueTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum()
+            : "commonUniqueTokenWeight not up to date";
+
+        commonUniqueTokenWeight = 0;
         int initialSize = commonUniqueTokenIds.size();
+        int commonIndex = 0;
         int newIndex = 0;
         int outputIndex = 0;
-        for (int commonIndex = 0; commonIndex < initialSize; ++commonIndex) {
+        while (commonIndex < initialSize) {
+            if (newIndex >= newUniqueTokenIds.size()) {
+                ++commonIndex;
+                continue;
+            }
             TokenAndWeight commonTokenAndWeight = commonUniqueTokenIds.get(commonIndex);
-            TokenAndWeight newTokenAndWeight;
-            if (newIndex >= newUniqueTokenIds.size()
-                || commonTokenAndWeight.getTokenId() < (newTokenAndWeight = newUniqueTokenIds.get(newIndex)).getTokenId()) {
-                commonUniqueTokenWeight -= commonTokenAndWeight.getWeight();
-            } else {
-                if (commonTokenAndWeight.getTokenId() == newTokenAndWeight.getTokenId()) {
-                    if (commonTokenAndWeight.getWeight() == newTokenAndWeight.getWeight()) {
-                        commonUniqueTokenIds.set(outputIndex++, commonTokenAndWeight);
-                    } else {
-                        commonUniqueTokenWeight -= commonTokenAndWeight.getWeight();
-                    }
-                }
-                ++newIndex;
+            int cmp = commonTokenAndWeight.compareTo(newUniqueTokenIds.get(newIndex));
+            if (cmp < 0) {
+                ++commonIndex;
+                continue;
+            }
+            if (cmp == 0) {
+                commonUniqueTokenIds.set(outputIndex++, commonTokenAndWeight);
+                commonUniqueTokenWeight += commonTokenAndWeight.getWeight();
+                ++commonIndex;
             }
+            ++newIndex;
         }
         if (outputIndex < initialSize) {
             commonUniqueTokenIds.subList(outputIndex, initialSize).clear();
@@ -286,6 +294,8 @@ private void updateCommonUniqueTokenIds(List<TokenAndWeight> newUniqueTokenIds)
             assert outputIndex == initialSize
                 : "should be impossible for output index to exceed initial size, but got " + outputIndex + " > " + initialSize;
         }
+        assert commonUniqueTokenWeight == commonUniqueTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum()
+            : "commonUniqueTokenWeight not up to date";
     }

     /**
@@ -360,7 +370,7 @@ && isTokenIdCommon(baseWeightedTokenIds.get(orderedCommonTokenBeginIndex)) == fa
                 if (newToken.getTokenId() != baseToken.getTokenId()) {
                     ++newIndex;
                 } else {
-                    tryWeight += newToken.getWeight() + baseToken.getWeight();
+                    tryWeight += baseToken.getWeight();
                     break;
                 }
             }
@@ -484,26 +494,25 @@ long getBucketOrd() {
     }

     public int missingCommonTokenWeight(List<TokenAndWeight> uniqueTokenIds) {
+        assert isSorted(uniqueTokenIds) : "Unique token IDs is not sorted " + uniqueTokenIds;
+
         int presentWeight = 0;
         int commonIndex = 0;
         int testIndex = 0;
         while (commonIndex < commonUniqueTokenIds.size() && testIndex < uniqueTokenIds.size()) {
-            switch (Integer.signum(commonUniqueTokenIds.get(commonIndex).compareTo(uniqueTokenIds.get(testIndex)))) {
-                case -1 -> ++commonIndex;
-                case 0 -> {
-                    // Don't increment the weight if a given token appears a different
-                    // number of times in the two strings.
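Both the rewritten `updateCommonUniqueTokenIds` above and the rewritten `missingCommonTokenWeight` continued below walk two sorted `TokenAndWeight` lists in step, and the new behaviour counts a common token as present whenever its token ID matches, even if its weight differs. A self-contained sketch of that walk, with plain int pairs standing in for `TokenAndWeight` (names and values are illustrative):

```java
public class MissingWeightDemo {
    /**
     * Weight of the common tokens NOT found in the sorted test list. Mirrors the shape
     * of the new missingCommonTokenWeight: token ID equality alone marks a token present.
     */
    static int missingWeight(int[][] commonIdWeightPairs, int[] sortedTestIds) {
        int total = 0;
        int present = 0;
        for (int[] pair : commonIdWeightPairs) {
            total += pair[1];
        }
        int i = 0;
        int j = 0;
        while (i < commonIdWeightPairs.length && j < sortedTestIds.length) {
            int cmp = Integer.compare(commonIdWeightPairs[i][0], sortedTestIds[j]);
            if (cmp < 0) {
                ++i;                                    // common token absent from the test list
            } else if (cmp == 0) {
                present += commonIdWeightPairs[i][1];   // present, regardless of the test weight
                ++i;
                ++j;
            } else {
                ++j;                                    // test token was never common
            }
        }
        return total - present;
    }

    public static void main(String[] args) {
        // common tokens: {id=1,w=3}, {id=2,w=2}, {id=3,w=5}; test IDs: 1 and 3
        System.out.println(missingWeight(new int[][] { { 1, 3 }, { 2, 2 }, { 3, 5 } }, new int[] { 1, 3 }));
        // prints 2: only token 2 (weight 2) is missing
    }
}
```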
- int testWeight = uniqueTokenIds.get(testIndex).getWeight(); - if (commonUniqueTokenIds.get(commonIndex).getWeight() == testWeight) { - presentWeight += testWeight; - } - ++commonIndex; - ++testIndex; - } - case 1 -> ++testIndex; - default -> throw new IllegalStateException("signum should not return numbers other than -1, 0 and 1"); + TokenAndWeight commonTokenAndWeight = commonUniqueTokenIds.get(commonIndex); + int cmp = commonTokenAndWeight.compareTo(uniqueTokenIds.get(testIndex)); + if (cmp < 0) { + ++commonIndex; + continue; } + if (cmp == 0) { + // If the token ID matches then consider the token present even if the weight in the test list is different. + presentWeight += commonTokenAndWeight.getWeight(); + ++commonIndex; + } + ++testIndex; } // The missing weight will be the total weight less the weight of those @@ -538,6 +547,7 @@ && isMissingCommonTokenWeightZero(otherUniqueTokenIds) * @return Is every common unique token for this category present with the same weight in the supplied {@code uniqueTokenIds}? */ public boolean isMissingCommonTokenWeightZero(List uniqueTokenIds) { + assert isSorted(uniqueTokenIds) : "Unique token IDs is not sorted " + uniqueTokenIds; int uniqueTokenIdsSize = uniqueTokenIds.size(); int testIndex = 0; @@ -551,8 +561,7 @@ public boolean isMissingCommonTokenWeightZero(List uniqueTokenId return false; } } - if (testTokenAndWeight.getTokenId() != commonTokenAndWeight.getTokenId() - || testTokenAndWeight.getWeight() != commonTokenAndWeight.getWeight()) { + if (testTokenAndWeight.getTokenId() != commonTokenAndWeight.getTokenId()) { return false; } ++testIndex; @@ -714,4 +723,15 @@ public String toString() { return "{" + tokenId + ", " + weight + "}"; } } + + static boolean isSorted(List list) { + TokenAndWeight previousTokenAndWeight = null; + for (TokenAndWeight tokenAndWeight : list) { + if (previousTokenAndWeight != null && tokenAndWeight.compareTo(previousTokenAndWeight) < 0) { + return false; + } + previousTokenAndWeight = tokenAndWeight; + } + return true; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListSimilarityTester.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListSimilarityTester.java similarity index 96% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListSimilarityTester.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListSimilarityTester.java index 1827bc5921185..25f1056f7dede 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListSimilarityTester.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListSimilarityTester.java @@ -5,9 +5,9 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.aggs.categorization2; +package org.elasticsearch.xpack.ml.aggs.categorization; -import org.elasticsearch.xpack.ml.aggs.categorization2.TokenListCategory.TokenAndWeight; +import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight; import java.util.List; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java deleted file mode 100644 index 603ad5f98fe71..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ /dev/null @@ -1,405 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.aggs.categorization; - -import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.search.aggregations.AggregationExecutionException; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.PriorityQueue; -import java.util.stream.Collectors; - -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; -import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; -import static org.apache.lucene.util.RamUsageEstimator.sizeOfMap; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; - -/** - * Tree node classes for the categorization token tree. 
- * - * Two major node types exist: - * - Inner: which are nodes that have children token nodes - * - Leaf: Which collection multiple {@link TextCategorization} based on similarity restrictions - */ -abstract class TreeNode implements Accountable { - - private long count; - - TreeNode(long count) { - this.count = count; - } - - abstract void mergeWith(TreeNode otherNode); - - abstract boolean isLeaf(); - - final void incCount(long count) { - this.count += count; - } - - final long getCount() { - return count; - } - - // TODO add option for calculating the cost of adding the new group - abstract TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory); - - abstract TextCategorization getCategorization(int[] tokenIds); - - abstract List getAllChildrenTextCategorizations(); - - abstract void collapseTinyChildren(); - - static class LeafTreeNode extends TreeNode { - private final List textCategorizations; - private final int similarityThreshold; - - LeafTreeNode(long count, int similarityThreshold) { - super(count); - this.textCategorizations = new ArrayList<>(); - this.similarityThreshold = similarityThreshold; - if (similarityThreshold < 1 || similarityThreshold > 100) { - throw new IllegalArgumentException("similarityThreshold must be between 1 and 100"); - } - } - - @Override - public boolean isLeaf() { - return true; - } - - @Override - void mergeWith(TreeNode treeNode) { - if (treeNode == null) { - return; - } - if (treeNode.isLeaf() == false) { - throw new UnsupportedOperationException( - "cannot merge leaf node with non-leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" - ); - } - incCount(treeNode.getCount()); - LeafTreeNode otherLeaf = (LeafTreeNode) treeNode; - for (TextCategorization group : otherLeaf.textCategorizations) { - if (getAndUpdateTextCategorization(group.getCategorization(), group.getCount()).isPresent() == false) { - putNewTextCategorization(group); - } - } - } - - @Override - public long ramBytesUsed() { - return Long.BYTES // count - + NUM_BYTES_OBJECT_REF // list reference - + Integer.BYTES // similarityThreshold - + sizeOfCollection(textCategorizations); - } - - @Override - public TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory) { - return getAndUpdateTextCategorization(tokenIds, docCount).orElseGet(() -> { - // Need to update the tree if possible - TextCategorization categorization = putNewTextCategorization(treeNodeFactory.newCategorization(docCount, tokenIds)); - // Get the regular size bytes from the TextCategorization and how much it costs to reference it - treeNodeFactory.incSize(categorization.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF); - return categorization; - }); - } - - @Override - List getAllChildrenTextCategorizations() { - return textCategorizations; - } - - @Override - void collapseTinyChildren() {} - - private Optional getAndUpdateTextCategorization(int[] tokenIds, long docCount) { - return getBestCategorization(tokenIds).map(bestGroupAndSimilarity -> { - if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { - bestGroupAndSimilarity.v1().addTokens(tokenIds, docCount); - return bestGroupAndSimilarity.v1(); - } - return null; - }); - } - - TextCategorization putNewTextCategorization(TextCategorization categorization) { - textCategorizations.add(categorization); - return categorization; - } - - private Optional> getBestCategorization(int[] tokenIds) { - if (textCategorizations.isEmpty()) { - return Optional.empty(); 
- } - if (textCategorizations.size() == 1) { - return Optional.of( - new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity(tokenIds).getSimilarity()) - ); - } - TextCategorization.Similarity maxSimilarity = null; - TextCategorization bestGroup = null; - for (TextCategorization textCategorization : this.textCategorizations) { - TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity(tokenIds); - if (maxSimilarity == null || groupSimilarity.compareTo(maxSimilarity) > 0) { - maxSimilarity = groupSimilarity; - bestGroup = textCategorization; - } - } - return Optional.of(new Tuple<>(bestGroup, maxSimilarity.getSimilarity())); - } - - @Override - public TextCategorization getCategorization(final int[] tokenIds) { - return getBestCategorization(tokenIds).map(Tuple::v1).orElse(null); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - LeafTreeNode that = (LeafTreeNode) o; - return that.similarityThreshold == similarityThreshold && Objects.equals(textCategorizations, that.textCategorizations); - } - - @Override - public int hashCode() { - return Objects.hash(textCategorizations, similarityThreshold); - } - } - - static class InnerTreeNode extends TreeNode { - - // TODO: Change to LongObjectMap? - private final Map children; - private final int childrenTokenPos; - private final int maxChildren; - private final PriorityQueue smallestChild; - - InnerTreeNode(long count, int childrenTokenPos, int maxChildren) { - super(count); - children = new HashMap<>(); - this.childrenTokenPos = childrenTokenPos; - this.maxChildren = maxChildren; - this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(NativeIntLongPair::count)); - } - - @Override - boolean isLeaf() { - return false; - } - - @Override - public TextCategorization getCategorization(final int[] tokenIds) { - return getChild(tokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) - .map(node -> node.getCategorization(tokenIds)) - .orElse(null); - } - - @Override - public long ramBytesUsed() { - return Long.BYTES // count - + NUM_BYTES_OBJECT_REF // children reference - + Integer.BYTES // childrenTokenPos - + Integer.BYTES // maxChildren - + NUM_BYTES_OBJECT_REF // smallestChildReference - + sizeOfMap(children, NUM_BYTES_OBJECT_REF) // children, - // Number of items in the queue, reference to tuple, and then the tuple references - + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + Integer.BYTES + Long.BYTES); - } - - @Override - public TextCategorization addText(final int[] tokenIds, final long docCount, final CategorizationTokenTree treeNodeFactory) { - final int currentToken = tokenIds[childrenTokenPos]; - TreeNode child = getChild(currentToken).map(node -> { - node.incCount(docCount); - if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == currentToken) { - smallestChild.add(smallestChild.poll()); - } - return node; - }).orElseGet(() -> { - TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, tokenIds); - // The size of the node + entry (since it is a map entry) + extra reference for priority queue - treeNodeFactory.incSize( - newNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF - ); - return addChild(currentToken, newNode); - }); - return child.addText(tokenIds, docCount, treeNodeFactory); - } - - @Override - void collapseTinyChildren() { - if (this.isLeaf()) { - 
return; - } - if (children.size() <= 1) { - return; - } - Optional maybeWildChild = getChild(WILD_CARD_ID).or(() -> { - if (smallestChild.size() > 0 && (double) smallestChild.peek().count / this.getCount() <= 1.0 / maxChildren) { - TreeNode tinyChild = children.remove(smallestChild.poll().tokenId); - return Optional.of(addChild(WILD_CARD_ID, tinyChild)); - } - return Optional.empty(); - }); - if (maybeWildChild.isPresent()) { - TreeNode wildChild = maybeWildChild.get(); - NativeIntLongPair tinyNode; - while ((tinyNode = smallestChild.poll()) != null) { - // If we have no more tiny nodes, stop iterating over them - if ((double) tinyNode.count / this.getCount() > 1.0 / maxChildren) { - smallestChild.add(tinyNode); - break; - } else { - wildChild.mergeWith(children.remove(tinyNode.tokenId)); - } - } - } - children.values().forEach(TreeNode::collapseTinyChildren); - } - - @Override - void mergeWith(TreeNode treeNode) { - if (treeNode == null) { - return; - } - incCount(treeNode.count); - if (treeNode.isLeaf()) { - throw new UnsupportedOperationException( - "cannot merge non-leaf node with leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" - ); - } - InnerTreeNode innerTreeNode = (InnerTreeNode) treeNode; - TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD_ID); - addChild(WILD_CARD_ID, siblingWildChild); - NativeIntLongPair siblingChild; - while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) { - TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.tokenId); - addChild(siblingChild.tokenId, nephewNode); - } - } - - private TreeNode addChild(int tokenId, TreeNode node) { - if (node == null) { - return null; - } - Optional existingChild = getChild(tokenId).map(existingNode -> { - existingNode.mergeWith(node); - if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == tokenId) { - smallestChild.poll(); - smallestChild.add(NativeIntLongPair.of(tokenId, existingNode.getCount())); - } - return existingNode; - }); - if (existingChild.isPresent()) { - return existingChild.get(); - } - if (children.size() == maxChildren) { - return getChild(WILD_CARD_ID).map(wildChild -> { - final TreeNode toMerge; - final TreeNode toReturn; - if (smallestChild.isEmpty() == false && node.getCount() > smallestChild.peek().count) { - toMerge = children.remove(smallestChild.poll().tokenId); - addChildAndUpdateSmallest(tokenId, node); - toReturn = node; - } else { - toMerge = node; - toReturn = wildChild; - } - wildChild.mergeWith(toMerge); - return toReturn; - }).orElseThrow(() -> new AggregationExecutionException("Missing wild_card child even though maximum children reached")); - } - // we are about to hit the limit, add a wild card if we need to and then add the new child as appropriate - if (children.size() == maxChildren - 1) { - // If we already have a wild token, simply adding the new token is acceptable as we won't breach our limit - if (children.containsKey(WILD_CARD_ID)) { - addChildAndUpdateSmallest(tokenId, node); - } else { // if we don't have a wild card child, we need to add one now - if (tokenId == WILD_CARD_ID) { - addChildAndUpdateSmallest(tokenId, node); - } else { - if (smallestChild.isEmpty() == false && node.count > smallestChild.peek().count) { - addChildAndUpdateSmallest(WILD_CARD_ID, children.remove(smallestChild.poll().tokenId)); - addChildAndUpdateSmallest(tokenId, node); - } else { - addChildAndUpdateSmallest(WILD_CARD_ID, node); - } - } - } - } else { - addChildAndUpdateSmallest(tokenId, node); - } - return node; - } 
- - private void addChildAndUpdateSmallest(int tokenId, TreeNode node) { - children.put(tokenId, node); - if (tokenId != WILD_CARD_ID) { - smallestChild.add(NativeIntLongPair.of(tokenId, node.count)); - } - } - - private Optional getChild(int tokenId) { - return Optional.ofNullable(children.get(tokenId)); - } - - public List getAllChildrenTextCategorizations() { - return children.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).collect(Collectors.toList()); - } - - boolean hasChild(int tokenId) { - return children.containsKey(tokenId); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InnerTreeNode treeNode = (InnerTreeNode) o; - return childrenTokenPos == treeNode.childrenTokenPos - && getCount() == treeNode.getCount() - && Objects.equals(children, treeNode.children) - && Objects.equals(smallestChild, treeNode.smallestChild); - } - - @Override - public int hashCode() { - return Objects.hash(children, childrenTokenPos, smallestChild, getCount()); - } - } - - private static class NativeIntLongPair { - private final int tokenId; - private final long count; - - static NativeIntLongPair of(int tokenId, long count) { - return new NativeIntLongPair(tokenId, count); - } - - NativeIntLongPair(int tokenId, long count) { - this.tokenId = tokenId; - this.count = count; - } - - public long count() { - return count; - } - } - -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java index 0ada736a2e108..dcce5b64ed5d5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java @@ -19,25 +19,15 @@ protected UnmappedCategorizationAggregation( String name, int requiredSize, long minDocCount, - int maxChildren, - int maxDepth, int similarityThreshold, Map metadata ) { - super(name, requiredSize, minDocCount, maxChildren, maxDepth, similarityThreshold, metadata); + super(name, requiredSize, minDocCount, similarityThreshold, metadata); } @Override public InternalCategorizationAggregation create(List buckets) { - return new UnmappedCategorizationAggregation( - name, - getRequiredSize(), - getMinDocCount(), - getMaxUniqueTokens(), - getMaxMatchTokens(), - getSimilarityThreshold(), - super.metadata - ); + return new UnmappedCategorizationAggregation(name, getRequiredSize(), getMinDocCount(), getSimilarityThreshold(), super.metadata); } @Override @@ -47,15 +37,7 @@ public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) @Override public InternalAggregation reduce(List aggregations, AggregationReduceContext reduceContext) { - return new UnmappedCategorizationAggregation( - name, - getRequiredSize(), - getMinDocCount(), - getMaxUniqueTokens(), - getMaxMatchTokens(), - getSimilarityThreshold(), - super.metadata - ); + return new UnmappedCategorizationAggregation(name, getRequiredSize(), getMinDocCount(), getSimilarityThreshold(), super.metadata); } @Override @@ -67,5 +49,4 @@ public boolean isMapped() { public List getBuckets() { return List.of(); } - } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationBytesRefHash.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationBytesRefHash.java deleted file mode 100644 index 7a17507c2f14a..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationBytesRefHash.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.aggs.categorization2; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.logging.LoggerMessageFormat; -import org.elasticsearch.common.util.BytesRefHash; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.search.aggregations.AggregationExecutionException; - -class CategorizationBytesRefHash implements Releasable { - - private final BytesRefHash bytesRefHash; - - CategorizationBytesRefHash(BytesRefHash bytesRefHash) { - this.bytesRefHash = bytesRefHash; - } - - int[] getIds(BytesRef[] tokens) { - int[] ids = new int[tokens.length]; - for (int i = 0; i < tokens.length; i++) { - ids[i] = put(tokens[i]); - } - return ids; - } - - BytesRef[] getDeeps(int[] ids) { - BytesRef[] tokens = new BytesRef[ids.length]; - for (int i = 0; i < tokens.length; i++) { - tokens[i] = getDeep(ids[i]); - } - return tokens; - } - - BytesRef getDeep(long id) { - BytesRef shallow = bytesRefHash.get(id, new BytesRef()); - return BytesRef.deepCopyOf(shallow); - } - - int put(BytesRef bytesRef) { - long hash = bytesRefHash.add(bytesRef); - if (hash < 0) { - // BytesRefHash returns -1 - hash if the entry already existed, but we just want to return the hash - return (int) (-1L - hash); - } - if (hash > Integer.MAX_VALUE) { - throw new AggregationExecutionException( - LoggerMessageFormat.format( - "more than [{}] unique terms encountered. " - + "Consider restricting the documents queried or adding [{}] in the {} configuration", - Integer.MAX_VALUE, - CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(), - CategorizeTextAggregationBuilder.NAME - ) - ); - } - return (int) hash; - } - - @Override - public void close() { - bytesRefHash.close(); - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregationBuilder.java deleted file mode 100644 index e4a1596b489a9..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregationBuilder.java +++ /dev/null @@ -1,320 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.ml.aggs.categorization2; - -import org.elasticsearch.Version; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; -import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.aggregations.AggregatorFactories; -import org.elasticsearch.search.aggregations.AggregatorFactory; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; -import org.elasticsearch.search.aggregations.support.AggregationContext; -import org.elasticsearch.xcontent.ObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; -import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME; -import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.REQUIRED_SIZE_FIELD_NAME; -import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME; -import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME; -import static org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig.Builder.isValidRegex; - -public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder { - - static final TermsAggregator.BucketCountThresholds DEFAULT_BUCKET_COUNT_THRESHOLDS = new TermsAggregator.BucketCountThresholds( - 1, - 0, - 10, - -1 - ); - - public static final String NAME = "categorize_text2"; - - static final ParseField FIELD_NAME = new ParseField("field"); - static final ParseField SIMILARITY_THRESHOLD = new ParseField("similarity_threshold"); - // The next two are unused, but accepted and ignored to avoid breaking client code - static final ParseField MAX_UNIQUE_TOKENS = new ParseField("max_unique_tokens").withAllDeprecated(); - static final ParseField MAX_MATCHED_TOKENS = new ParseField("max_matched_tokens").withAllDeprecated(); - static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters"); - static final ParseField CATEGORIZATION_ANALYZER = new ParseField("categorization_analyzer"); - - public static final ObjectParser PARSER = ObjectParser.fromBuilder( - CategorizeTextAggregationBuilder.NAME, - CategorizeTextAggregationBuilder::new - ); - static { - PARSER.declareString(CategorizeTextAggregationBuilder::setFieldName, FIELD_NAME); - PARSER.declareInt(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); - // The next two are unused, but accepted and ignored to avoid breaking client code - PARSER.declareInt((p, c) -> {}, MAX_UNIQUE_TOKENS); - PARSER.declareInt((p, c) -> {}, MAX_MATCHED_TOKENS); - PARSER.declareField( - CategorizeTextAggregationBuilder::setCategorizationAnalyzerConfig, - (p, c) -> CategorizationAnalyzerConfig.buildFromXContentFragment(p, false), - CATEGORIZATION_ANALYZER, - ObjectParser.ValueType.OBJECT_OR_STRING - ); - 
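The two legacy tuning parameters above are wired with `withAllDeprecated()` plus a no-op consumer, which is what lets pre-8.3.0 request bodies keep parsing (with a deprecation warning) while having no effect on the new implementation. In isolation the pattern looks like this (the field name here is made up for illustration):

```java
// Accept-and-ignore pattern for a removed setting: parsing succeeds, the value is
// discarded, and a deprecation warning is logged. "legacy_setting" is illustrative.
static final ParseField LEGACY_SETTING = new ParseField("legacy_setting").withAllDeprecated();

static {
    PARSER.declareInt((builder, ignoredValue) -> {}, LEGACY_SETTING);
}
```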
PARSER.declareStringArray(CategorizeTextAggregationBuilder::setCategorizationFilters, CATEGORIZATION_FILTERS); - PARSER.declareInt(CategorizeTextAggregationBuilder::shardSize, TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME); - PARSER.declareLong(CategorizeTextAggregationBuilder::minDocCount, TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME); - PARSER.declareLong(CategorizeTextAggregationBuilder::shardMinDocCount, TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME); - PARSER.declareInt(CategorizeTextAggregationBuilder::size, REQUIRED_SIZE_FIELD_NAME); - } - - public static CategorizeTextAggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException { - return PARSER.parse(parser, new CategorizeTextAggregationBuilder(aggregationName), null); - } - - private TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds( - DEFAULT_BUCKET_COUNT_THRESHOLDS - ); - private CategorizationAnalyzerConfig categorizationAnalyzerConfig; - private String fieldName; - // Default of 70% matches the C++ code - private int similarityThreshold = 70; - - private CategorizeTextAggregationBuilder(String name) { - super(name); - } - - public CategorizeTextAggregationBuilder(String name, String fieldName) { - super(name); - this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); - } - - @Override - public boolean supportsSampling() { - return true; - } - - public String getFieldName() { - return fieldName; - } - - public CategorizeTextAggregationBuilder setFieldName(String fieldName) { - this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); - return this; - } - - public CategorizeTextAggregationBuilder(StreamInput in) throws IOException { - super(in); - this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in); - this.fieldName = in.readString(); - this.similarityThreshold = in.readVInt(); - this.categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new); - } - - public double getSimilarityThreshold() { - return similarityThreshold; - } - - public CategorizeTextAggregationBuilder setSimilarityThreshold(int similarityThreshold) { - this.similarityThreshold = similarityThreshold; - if (similarityThreshold < 1 || similarityThreshold > 100) { - throw ExceptionsHelper.badRequestException( - "[{}] must be in the range [1, 100]. 
Found [{}] in [{}]", - SIMILARITY_THRESHOLD.getPreferredName(), - similarityThreshold, - name - ); - } - return this; - } - - public CategorizeTextAggregationBuilder setCategorizationAnalyzerConfig(CategorizationAnalyzerConfig categorizationAnalyzerConfig) { - if (this.categorizationAnalyzerConfig != null) { - throw ExceptionsHelper.badRequestException( - "[{}] cannot be used with [{}] - instead specify them as pattern_replace char_filters in the analyzer", - CATEGORIZATION_FILTERS.getPreferredName(), - CATEGORIZATION_ANALYZER.getPreferredName() - ); - } - this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; - return this; - } - - public CategorizeTextAggregationBuilder setCategorizationFilters(List categorizationFilters) { - if (categorizationFilters == null || categorizationFilters.isEmpty()) { - return this; - } - if (categorizationAnalyzerConfig != null) { - throw ExceptionsHelper.badRequestException( - "[{}] cannot be used with [{}] - instead specify them as pattern_replace char_filters in the analyzer", - CATEGORIZATION_FILTERS.getPreferredName(), - CATEGORIZATION_ANALYZER.getPreferredName() - ); - } - if (categorizationFilters.stream().distinct().count() != categorizationFilters.size()) { - throw ExceptionsHelper.badRequestException(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_DUPLICATES); - } - if (categorizationFilters.stream().anyMatch(String::isEmpty)) { - throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_EMPTY)); - } - for (String filter : categorizationFilters) { - if (isValidRegex(filter) == false) { - throw ExceptionsHelper.badRequestException( - Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_INVALID_REGEX, filter) - ); - } - } - this.categorizationAnalyzerConfig = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(categorizationFilters); - return this; - } - - /** - * @param size indicating how many buckets should be returned - */ - public CategorizeTextAggregationBuilder size(int size) { - if (size <= 0) { - throw ExceptionsHelper.badRequestException( - "[{}] must be greater than 0. Found [{}] in [{}]", - REQUIRED_SIZE_FIELD_NAME.getPreferredName(), - size, - name - ); - } - bucketCountThresholds.setRequiredSize(size); - return this; - } - - /** - * @param shardSize - indicating the number of buckets each shard - * will return to the coordinating node (the node that coordinates the - * search execution). The higher the shard size is, the more accurate the - * results are. - */ - public CategorizeTextAggregationBuilder shardSize(int shardSize) { - if (shardSize <= 0) { - throw ExceptionsHelper.badRequestException( - "[{}] must be greater than 0. Found [{}] in [{}]", - SHARD_SIZE_FIELD_NAME.getPreferredName(), - shardSize, - name - ); - } - bucketCountThresholds.setShardSize(shardSize); - return this; - } - - /** - * @param minDocCount the minimum document count a text category should have in order to appear in - * the response. - */ - public CategorizeTextAggregationBuilder minDocCount(long minDocCount) { - if (minDocCount < 0) { - throw ExceptionsHelper.badRequestException( - "[{}] must be greater than or equal to 0. Found [{}] in [{}]", - MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), - minDocCount, - name - ); - } - bucketCountThresholds.setMinDocCount(minDocCount); - return this; - } - - /** - * @param shardMinDocCount the minimum document count a text category should have on the shard in order to - * appear in the response. 
- */ - public CategorizeTextAggregationBuilder shardMinDocCount(long shardMinDocCount) { - if (shardMinDocCount < 0) { - throw ExceptionsHelper.badRequestException( - "[{}] must be greater than or equal to 0. Found [{}] in [{}]", - SHARD_MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), - shardMinDocCount, - name - ); - } - bucketCountThresholds.setShardMinDocCount(shardMinDocCount); - return this; - } - - protected CategorizeTextAggregationBuilder( - CategorizeTextAggregationBuilder clone, - AggregatorFactories.Builder factoriesBuilder, - Map metadata - ) { - super(clone, factoriesBuilder, metadata); - this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(clone.bucketCountThresholds); - this.fieldName = clone.fieldName; - this.similarityThreshold = clone.similarityThreshold; - this.categorizationAnalyzerConfig = clone.categorizationAnalyzerConfig; - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - bucketCountThresholds.writeTo(out); - out.writeString(fieldName); - out.writeVInt(similarityThreshold); - out.writeOptionalWriteable(categorizationAnalyzerConfig); - } - - @Override - protected AggregatorFactory doBuild( - AggregationContext context, - AggregatorFactory parent, - AggregatorFactories.Builder subfactoriesBuilder - ) throws IOException { - return new CategorizeTextAggregatorFactory( - name, - fieldName, - similarityThreshold, - bucketCountThresholds, - categorizationAnalyzerConfig, - context, - parent, - subfactoriesBuilder, - metadata - ); - } - - @Override - protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - bucketCountThresholds.toXContent(builder, params); - builder.field(FIELD_NAME.getPreferredName(), fieldName); - builder.field(SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold); - if (categorizationAnalyzerConfig != null) { - categorizationAnalyzerConfig.toXContent(builder, params); - } - builder.endObject(); - return null; - } - - @Override - protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { - return new CategorizeTextAggregationBuilder(this, factoriesBuilder, metadata); - } - - @Override - public BucketCardinality bucketCardinality() { - return BucketCardinality.MANY; - } - - @Override - public String getType() { - return NAME; - } - - @Override - public Version getMinimalSupportedVersion() { - return Version.V_8_3_0; - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregator.java deleted file mode 100644 index f62642067e788..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregator.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.ml.aggs.categorization2; - -import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.util.BytesRefHash; -import org.elasticsearch.common.util.ObjectArray; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.search.aggregations.Aggregator; -import org.elasticsearch.search.aggregations.AggregatorFactories; -import org.elasticsearch.search.aggregations.CardinalityUpperBound; -import org.elasticsearch.search.aggregations.InternalAggregation; -import org.elasticsearch.search.aggregations.LeafBucketCollector; -import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; -import org.elasticsearch.search.aggregations.bucket.DeferableBucketAggregator; -import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; -import org.elasticsearch.search.aggregations.support.AggregationContext; -import org.elasticsearch.search.lookup.SourceLookup; -import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; -import org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation.Bucket; -import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -public class CategorizeTextAggregator extends DeferableBucketAggregator { - - private final TermsAggregator.BucketCountThresholds bucketCountThresholds; - private final SourceLookup sourceLookup; - private final MappedFieldType fieldType; - private final CategorizationAnalyzer analyzer; - private final String sourceFieldName; - private ObjectArray<TokenListCategorizer> categorizers; - private final int similarityThreshold; - private final LongKeyedBucketOrds bucketOrds; - private final CategorizationBytesRefHash bytesRefHash; - private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary; - - protected CategorizeTextAggregator( - String name, - AggregatorFactories factories, - AggregationContext context, - Aggregator parent, - String sourceFieldName, - MappedFieldType fieldType, - TermsAggregator.BucketCountThresholds bucketCountThresholds, - int similarityThreshold, - CategorizationAnalyzerConfig categorizationAnalyzerConfig, - Map<String, Object> metadata - ) throws IOException { - super(name, factories, context, parent, metadata); - this.sourceLookup = context.lookup().source(); - this.sourceFieldName = sourceFieldName; - this.fieldType = fieldType; - CategorizationAnalyzerConfig analyzerConfig = Optional.ofNullable(categorizationAnalyzerConfig) - .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(List.of())); - final String analyzerName = analyzerConfig.getAnalyzer(); - if (analyzerName != null) { - Analyzer globalAnalyzer = context.getNamedAnalyzer(analyzerName); - if (globalAnalyzer == null) { - throw new IllegalArgumentException("Failed to find global analyzer [" + analyzerName + "]"); - } - this.analyzer = new CategorizationAnalyzer(globalAnalyzer, false); - } else { - this.analyzer = new CategorizationAnalyzer( - context.buildCustomAnalyzer( - context.getIndexSettings(), - false, - analyzerConfig.getTokenizer(), - analyzerConfig.getCharFilters(), -
analyzerConfig.getTokenFilters() - ), - true - ); - } - this.categorizers = bigArrays().newObjectArray(1); - this.similarityThreshold = similarityThreshold; - this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); - this.bucketCountThresholds = bucketCountThresholds; - this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, bigArrays())); - // TODO: make it possible to choose a language instead of or as well as English for the part-of-speech dictionary - this.partOfSpeechDictionary = CategorizationPartOfSpeechDictionary.getInstance(); - } - - @Override - protected void doClose() { - super.doClose(); - Releasables.close(this.analyzer, this.bytesRefHash, this.bucketOrds, this.categorizers); - } - - @Override - public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOException { - Bucket[][] topBucketsPerOrd = new Bucket[ordsToCollect.length][]; - for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { - final TokenListCategorizer categorizer = categorizers.get(ordsToCollect[ordIdx]); - if (categorizer == null) { - topBucketsPerOrd[ordIdx] = new Bucket[0]; - continue; - } - int size = (int) Math.min(bucketOrds.bucketsInOrd(ordIdx), bucketCountThresholds.getShardSize()); - topBucketsPerOrd[ordIdx] = categorizer.toOrderedBuckets(size); - } - buildSubAggsForAllBuckets(topBucketsPerOrd, Bucket::getBucketOrd, Bucket::setAggregations); - InternalAggregation[] results = new InternalAggregation[ordsToCollect.length]; - for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { - results[ordIdx] = new InternalCategorizationAggregation( - name, - bucketCountThresholds.getRequiredSize(), - bucketCountThresholds.getMinDocCount(), - similarityThreshold, - metadata(), - Arrays.asList(topBucketsPerOrd[ordIdx]) - ); - } - return results; - } - - @Override - public InternalAggregation buildEmptyAggregation() { - return new InternalCategorizationAggregation( - name, - bucketCountThresholds.getRequiredSize(), - bucketCountThresholds.getMinDocCount(), - similarityThreshold, - metadata() - ); - } - - @Override - protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - return new LeafBucketCollectorBase(sub, null) { - @Override - public void collect(int doc, long owningBucketOrd) throws IOException { - categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); - TokenListCategorizer categorizer = categorizers.get(owningBucketOrd); - if (categorizer == null) { - categorizer = new TokenListCategorizer(bytesRefHash, partOfSpeechDictionary, (float) similarityThreshold / 100.0f); - addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); - categorizers.set(owningBucketOrd, categorizer); - } - collectFromSource(doc, owningBucketOrd, categorizer); - } - - private void collectFromSource(int doc, long owningBucketOrd, TokenListCategorizer categorizer) throws IOException { - sourceLookup.setSegmentAndDocument(ctx, doc); - Iterator itr = sourceLookup.extractRawValuesWithoutCaching(sourceFieldName).stream().map(obj -> { - if (obj instanceof BytesRef) { - return fieldType.valueForDisplay(obj).toString(); - } - return (obj == null) ? 
null : obj.toString(); - }).iterator(); - while (itr.hasNext()) { - String string = itr.next(); - try (TokenStream ts = analyzer.tokenStream(fieldType.name(), string)) { - processTokenStream(owningBucketOrd, ts, string.length(), doc, categorizer); - } - } - } - - private void processTokenStream( - long owningBucketOrd, - TokenStream ts, - int unfilteredLength, - int doc, - TokenListCategorizer categorizer - ) throws IOException { - long previousSize = categorizer.ramBytesUsed(); - TokenListCategory category = categorizer.computeCategory(ts, unfilteredLength, docCountProvider.getDocCount(doc)); - if (category == null) { - return; - } - long sizeDiff = categorizer.ramBytesUsed() - previousSize; - addRequestCircuitBreakerBytes(sizeDiff); - long bucketOrd = bucketOrds.add(owningBucketOrd, category.getId()); - if (bucketOrd < 0) { // already seen - collectExistingBucket(sub, doc, -1 - bucketOrd); - } else { - category.setBucketOrd(bucketOrd); - collectBucket(sub, doc, bucketOrd); - } - } - }; - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregatorFactory.java deleted file mode 100644 index f4c7cb11c9350..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregatorFactory.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.aggs.categorization2; - -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.search.aggregations.Aggregator; -import org.elasticsearch.search.aggregations.AggregatorFactories; -import org.elasticsearch.search.aggregations.AggregatorFactory; -import org.elasticsearch.search.aggregations.CardinalityUpperBound; -import org.elasticsearch.search.aggregations.InternalAggregation; -import org.elasticsearch.search.aggregations.NonCollectingAggregator; -import org.elasticsearch.search.aggregations.bucket.BucketUtils; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; -import org.elasticsearch.search.aggregations.support.AggregationContext; -import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; - -import java.io.IOException; -import java.util.Map; - -public class CategorizeTextAggregatorFactory extends AggregatorFactory { - - private final MappedFieldType fieldType; - private final int similarityThreshold; - private final CategorizationAnalyzerConfig categorizationAnalyzerConfig; - private final TermsAggregator.BucketCountThresholds bucketCountThresholds; - - public CategorizeTextAggregatorFactory( - String name, - String fieldName, - int similarityThreshold, - TermsAggregator.BucketCountThresholds bucketCountThresholds, - CategorizationAnalyzerConfig categorizationAnalyzerConfig, - AggregationContext context, - AggregatorFactory parent, - AggregatorFactories.Builder subFactoriesBuilder, - Map metadata - ) throws IOException { - super(name, context, parent, subFactoriesBuilder, metadata); - this.fieldType = context.getFieldType(fieldName); - this.similarityThreshold = similarityThreshold; - this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; - this.bucketCountThresholds = 
bucketCountThresholds; - } - - protected Aggregator createUnmapped(Aggregator parent, Map metadata) throws IOException { - final InternalAggregation aggregation = new UnmappedCategorizationAggregation( - name, - bucketCountThresholds.getRequiredSize(), - bucketCountThresholds.getMinDocCount(), - similarityThreshold, - metadata - ); - return new NonCollectingAggregator(name, context, parent, factories, metadata) { - @Override - public InternalAggregation buildEmptyAggregation() { - return aggregation; - } - }; - } - - @Override - protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map metadata) - throws IOException { - if (fieldType == null) { - return createUnmapped(parent, metadata); - } - TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); - if (bucketCountThresholds.getShardSize() == CategorizeTextAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { - // The user has not made a shardSize selection. Use default - // heuristic to avoid any wrong-ranking caused by distributed - // counting - // TODO significant text does a 2x here, should we as well? - bucketCountThresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(bucketCountThresholds.getRequiredSize())); - } - bucketCountThresholds.ensureValidity(); - - return new CategorizeTextAggregator( - name, - factories, - context, - parent, - fieldType.name(), - fieldType, - bucketCountThresholds, - similarityThreshold, - categorizationAnalyzerConfig, - metadata - ); - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/InternalCategorizationAggregation.java deleted file mode 100644 index 5914e9ef91319..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/InternalCategorizationAggregation.java +++ /dev/null @@ -1,337 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
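// (Not part of the original sources: a note on the shard_size fallback in createInternal() above.)
// BucketUtils.suggestShardSideQueueSize is the same over-request heuristic the terms aggregation uses,
// roughly (size * 1.5 + 10) at the time of writing, so with the default required size of 10 each shard
// returns up to 25 categories before the final merge:
//
//     int requiredSize = 10; // the default `size`
//     int shardSize = BucketUtils.suggestShardSideQueueSize(requiredSize);
//     assert shardSize == 25; // 10 * 1.5 + 10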
- */ - -package org.elasticsearch.xpack.ml.aggs.categorization2; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.util.BytesRefHash; -import org.elasticsearch.search.aggregations.AggregationReduceContext; -import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.InternalAggregation; -import org.elasticsearch.search.aggregations.InternalAggregations; -import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation; -import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; -import org.elasticsearch.search.aggregations.support.SamplingContext; -import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.job.results.CategoryDefinition; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -public class InternalCategorizationAggregation extends InternalMultiBucketAggregation< - InternalCategorizationAggregation, - InternalCategorizationAggregation.Bucket> { - - static class BucketKey implements ToXContentFragment, Comparable<BucketKey> { - - private final BytesRef[] key; - - BucketKey(SerializableTokenListCategory serializableCategory) { - this.key = serializableCategory.getKeyTokens(); - } - - BucketKey(BytesRef[] key) { - this.key = key; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.value(toString()); - } - - @Override - public String toString() { - return Arrays.stream(key).map(BytesRef::utf8ToString).collect(Collectors.joining(" ")); - } - - @Override - public int hashCode() { - return Arrays.hashCode(key); - } - - @Override - public boolean equals(Object other) { - if (other == this) { - return true; - } - if (other == null || getClass() != other.getClass()) { - return false; - } - BucketKey that = (BucketKey) other; - return Arrays.equals(this.key, that.key); - } - - public BytesRef[] keyAsTokens() { - return key; - } - - @Override - public int compareTo(BucketKey o) { - return Arrays.compare(key, o.key); - } - } - - public static class Bucket extends InternalMultiBucketAggregation.InternalBucket - implements - MultiBucketsAggregation.Bucket, - Comparable<Bucket> { - - private final SerializableTokenListCategory serializableCategory; - private final BucketKey key; - private long bucketOrd; - private InternalAggregations aggregations; - - public Bucket(SerializableTokenListCategory serializableCategory, long bucketOrd) { - this(serializableCategory, bucketOrd, InternalAggregations.EMPTY); - } - - public Bucket(SerializableTokenListCategory serializableCategory, long bucketOrd, InternalAggregations aggregations) { - this.serializableCategory = serializableCategory; - this.key = new BucketKey(serializableCategory); - this.bucketOrd = bucketOrd; - this.aggregations = Objects.requireNonNull(aggregations); - } - - public Bucket(StreamInput in) throws IOException { - serializableCategory = new SerializableTokenListCategory(in); - key = new BucketKey(serializableCategory); - bucketOrd = -1; - aggregations = InternalAggregations.readFrom(in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - serializableCategory.writeTo(out); -
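// The wire format mirrors the StreamInput constructor above: the category is written first, then the
// sub-aggregations. bucketOrd is deliberately not serialized - it is only meaningful on the shard that
// produced the bucket, so it is reset to -1 on the receiving side.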
aggregations.writeTo(out); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CommonFields.DOC_COUNT.getPreferredName(), serializableCategory.getNumMatches()); - builder.field(CommonFields.KEY.getPreferredName()); - key.toXContent(builder, params); - builder.field(CategoryDefinition.MAX_MATCHING_LENGTH.getPreferredName(), serializableCategory.maxMatchingStringLen()); - aggregations.toXContentInternal(builder, params); - builder.endObject(); - return builder; - } - - BucketKey getRawKey() { - return key; - } - - @Override - public Object getKey() { - return key; - } - - @Override - public String getKeyAsString() { - return key.toString(); - } - - @Override - public long getDocCount() { - return serializableCategory.getNumMatches(); - } - - @Override - public Aggregations getAggregations() { - return aggregations; - } - - void setAggregations(InternalAggregations aggregations) { - this.aggregations = aggregations; - } - - long getBucketOrd() { - return bucketOrd; - } - - SerializableTokenListCategory getSerializableCategory() { - return serializableCategory; - } - - @Override - public String toString() { - return "Bucket{key=" - + getKeyAsString() - + ", docCount=" - + serializableCategory.getNumMatches() - + ", aggregations=" - + aggregations.asMap() - + "}\n"; - } - - @Override - public int compareTo(Bucket other) { - return Long.signum(this.serializableCategory.getNumMatches() - other.serializableCategory.getNumMatches()); - } - } - - private final List buckets; - private final int similarityThreshold; - private final int requiredSize; - private final long minDocCount; - - protected InternalCategorizationAggregation( - String name, - int requiredSize, - long minDocCount, - int similarityThreshold, - Map metadata - ) { - this(name, requiredSize, minDocCount, similarityThreshold, metadata, new ArrayList<>()); - } - - protected InternalCategorizationAggregation( - String name, - int requiredSize, - long minDocCount, - int similarityThreshold, - Map metadata, - List buckets - ) { - super(name, metadata); - this.buckets = buckets; - this.similarityThreshold = similarityThreshold; - this.minDocCount = minDocCount; - this.requiredSize = requiredSize; - } - - public InternalCategorizationAggregation(StreamInput in) throws IOException { - super(in); - this.similarityThreshold = in.readVInt(); - this.buckets = in.readList(Bucket::new); - this.requiredSize = readSize(in); - this.minDocCount = in.readVLong(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeVInt(similarityThreshold); - out.writeList(buckets); - writeSize(requiredSize, out); - out.writeVLong(minDocCount); - } - - @Override - public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { - builder.startArray(CommonFields.BUCKETS.getPreferredName()); - for (Bucket bucket : buckets) { - bucket.toXContent(builder, params); - } - builder.endArray(); - return builder; - } - - @Override - public InternalCategorizationAggregation create(List buckets) { - return new InternalCategorizationAggregation(name, requiredSize, minDocCount, similarityThreshold, super.metadata, buckets); - } - - @Override - public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { - return new Bucket(prototype.serializableCategory, prototype.bucketOrd, aggregations); - } - - @Override - protected Bucket reduceBucket(List buckets, AggregationReduceContext context) 
{ - throw new UnsupportedOperationException("For optimization purposes, typical bucket path is not supported"); - } - - @Override - public List getBuckets() { - return buckets; - } - - @Override - public String getWriteableName() { - return CategorizeTextAggregationBuilder.NAME; - } - - @Override - public InternalAggregation reduce(List aggregations, AggregationReduceContext reduceContext) { - try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays()))) { - TokenListCategorizer categorizer = new TokenListCategorizer( - hash, - null, // part-of-speech dictionary is not needed for the reduce phase as weights are already decided - (float) similarityThreshold / 100.0f - ); - // Merge all the categories into the newly created empty categorizer to combine them - for (InternalAggregation aggregation : aggregations) { - InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; - for (Bucket bucket : categorizationAggregation.buckets) { - categorizer.mergeWireCategory(bucket.serializableCategory).addSubAggs((InternalAggregations) bucket.getAggregations()); - if (reduceContext.isCanceled().get()) { - break; - } - } - } - final int size = reduceContext.isFinalReduce() - ? Math.min(requiredSize, categorizer.getCategoryCount()) - : categorizer.getCategoryCount(); - Bucket[] mergedBuckets = categorizer.toOrderedBuckets(size, reduceContext.isFinalReduce() ? minDocCount : 0, reduceContext); - // TODO: not sure if this next line is correct - if we discarded some categories due to size or minDocCount is this handled? - reduceContext.consumeBucketsAndMaybeBreak(mergedBuckets.length); - // Keep the top categories top, but then sort by the key for those with duplicate counts - if (reduceContext.isFinalReduce()) { - Arrays.sort(mergedBuckets, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); - } - return new InternalCategorizationAggregation( - name, - requiredSize, - minDocCount, - similarityThreshold, - metadata, - Arrays.asList(mergedBuckets) - ); - } - } - - @Override - public InternalAggregation finalizeSampling(SamplingContext samplingContext) { - return new InternalCategorizationAggregation( - name, - requiredSize, - minDocCount, - similarityThreshold, - metadata, - buckets.stream() - .map( - b -> new Bucket( - new SerializableTokenListCategory(b.getSerializableCategory(), samplingContext.scaleUp(b.getDocCount())), - b.getBucketOrd(), - InternalAggregations.finalizeSampling(b.aggregations, samplingContext) - ) - ) - .collect(Collectors.toList()) - ); - } - - public int getSimilarityThreshold() { - return similarityThreshold; - } - - public int getRequiredSize() { - return requiredSize; - } - - public long getMinDocCount() { - return minDocCount; - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/UnmappedCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/UnmappedCategorizationAggregation.java deleted file mode 100644 index eb7a2be14f371..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization2/UnmappedCategorizationAggregation.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
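// (Not part of the original sources: a note on the reduce phase implemented above.) Rather than merging
// buckets pairwise, reduce() feeds every per-shard SerializableTokenListCategory back through a fresh
// TokenListCategorizer via mergeWireCategory(), letting the categorizer re-cluster near-duplicate
// categories coming from different shards. This is also why reduceBucket() throws
// UnsupportedOperationException - the usual per-bucket reduction path is never taken.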
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.aggs.categorization2; - -import org.elasticsearch.search.aggregations.AggregationReduceContext; -import org.elasticsearch.search.aggregations.InternalAggregation; -import org.elasticsearch.search.aggregations.InternalAggregations; - -import java.util.List; -import java.util.Map; - -class UnmappedCategorizationAggregation extends InternalCategorizationAggregation { - protected UnmappedCategorizationAggregation( - String name, - int requiredSize, - long minDocCount, - int similarityThreshold, - Map metadata - ) { - super(name, requiredSize, minDocCount, similarityThreshold, metadata); - } - - @Override - public InternalCategorizationAggregation create(List buckets) { - return new UnmappedCategorizationAggregation(name, getRequiredSize(), getMinDocCount(), getSimilarityThreshold(), super.metadata); - } - - @Override - public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { - throw new UnsupportedOperationException("not supported for UnmappedCategorizationAggregation"); - } - - @Override - public InternalAggregation reduce(List aggregations, AggregationReduceContext reduceContext) { - return new UnmappedCategorizationAggregation(name, getRequiredSize(), getMinDocCount(), getSimilarityThreshold(), super.metadata); - } - - @Override - public boolean isMapped() { - return false; - } - - @Override - public List getBuckets() { - return List.of(); - } -} diff --git a/x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/aggs/categorization2/ml-en.dict b/x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/aggs/categorization/ml-en.dict similarity index 100% rename from x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/aggs/categorization2/ml-en.dict rename to x-pack/plugin/ml/src/main/resources/org/elasticsearch/xpack/ml/aggs/categorization/ml-en.dict diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationPartOfSpeechDictionaryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationPartOfSpeechDictionaryTests.java similarity index 94% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationPartOfSpeechDictionaryTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationPartOfSpeechDictionaryTests.java index 35008ddb962ce..948f2105eb6f2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationPartOfSpeechDictionaryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationPartOfSpeechDictionaryTests.java @@ -5,10 +5,10 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.aggs.categorization2; +package org.elasticsearch.xpack.ml.aggs.categorization; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.ml.aggs.categorization2.CategorizationPartOfSpeechDictionary.PartOfSpeech; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary.PartOfSpeech; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTestCase.java similarity index 88% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationTestCase.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTestCase.java index 481e23a8a5708..172257028e414 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationTestCase.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTestCase.java @@ -5,13 +5,13 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.aggs.categorization2; +package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.ml.aggs.categorization2.TokenListCategory.TokenAndWeight; +import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight; import org.junit.After; import org.junit.Before; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java index 2bcc010b26694..3467992d7b6b6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java @@ -17,9 +17,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder.MAX_MAX_MATCHED_TOKENS; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder.MAX_MAX_UNIQUE_TOKENS; - public class CategorizeTextAggregationBuilderTests extends BaseAggregationTestCase { @Override @@ -37,12 +34,6 @@ protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() { if (setFilters == false) { builder.setCategorizationAnalyzerConfig(CategorizationAnalyzerConfigTests.createRandomized().build()); } - if (randomBoolean()) { - builder.setMaxUniqueTokens(randomIntBetween(1, MAX_MAX_UNIQUE_TOKENS)); - } - if (randomBoolean()) { - builder.setMaxMatchedTokens(randomIntBetween(1, MAX_MAX_MATCHED_TOKENS)); - } if (randomBoolean()) { builder.setSimilarityThreshold(randomIntBetween(1, 100)); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java index 413c1a050a85f..ebd7b4ce4da61 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -64,9 +64,9 @@ public void testCategorizationWithoutSubAggs() throws Exception { CategorizeTextAggregatorTests::writeTestDocs, (InternalCategorizationAggregation result) -> { assertThat(result.getBuckets(), hasSize(2)); - assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getDocCount(), equalTo(6L)); assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); - assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L)); assertThat( result.getBuckets().get(1).getKeyAsString(), equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") @@ -89,20 +89,20 @@ public void testCategorizationWithSubAggs() throws Exception { CategorizeTextAggregatorTests::writeTestDocs, (InternalCategorizationAggregation result) -> { assertThat(result.getBuckets(), hasSize(2)); - assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getDocCount(), equalTo(6L)); assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); - assertThat(((Max) result.getBuckets().get(0).aggregations.get("max")).value(), equalTo(5.0)); - assertThat(((Min) result.getBuckets().get(0).aggregations.get("min")).value(), equalTo(0.0)); - assertThat(((Avg) result.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + assertThat(((Max) result.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(5.0)); + assertThat(((Min) result.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(2.5)); - assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L)); assertThat( result.getBuckets().get(1).getKeyAsString(), equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") ); - assertThat(((Max) result.getBuckets().get(1).aggregations.get("max")).value(), equalTo(4.0)); - assertThat(((Min) result.getBuckets().get(1).aggregations.get("min")).value(), equalTo(0.0)); - assertThat(((Avg) result.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(2.0)); + assertThat(((Max) result.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(4.0)); + assertThat(((Min) result.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.0)); }, new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), longField(NUMERIC_FIELD_NAME) @@ -123,9 +123,9 @@ public void testCategorizationWithMultiBucketSubAggs() throws Exception { CategorizeTextAggregatorTests::writeTestDocs, (InternalCategorizationAggregation result) -> { assertThat(result.getBuckets(), hasSize(2)); - assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getDocCount(), equalTo(6L)); assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); - Histogram histo = result.getBuckets().get(0).aggregations.get("histo"); + Histogram histo = result.getBuckets().get(0).getAggregations().get("histo"); 
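// The move from direct field access to accessors in these hunks follows from the new Bucket layout:
// the document count now lives inside the bucket's SerializableTokenListCategory (getDocCount() returns
// serializableCategory.getNumMatches()), so the old public docCount and aggregations fields no longer exist.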
assertThat(histo.getBuckets(), hasSize(3)); for (Histogram.Bucket bucket : histo.getBuckets()) { assertThat(bucket.getDocCount(), equalTo(2L)); @@ -140,12 +140,12 @@ public void testCategorizationWithMultiBucketSubAggs() throws Exception { assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).value(), equalTo(4.0)); assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5)); - assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L)); assertThat( result.getBuckets().get(1).getKeyAsString(), equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") ); - histo = result.getBuckets().get(1).aggregations.get("histo"); + histo = result.getBuckets().get(1).getAggregations().get("histo"); assertThat(histo.getBuckets(), hasSize(3)); assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(1L)); assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L)); @@ -175,49 +175,49 @@ public void testCategorizationAsSubAgg() throws Exception { assertThat(result.getBuckets().get(0).getDocCount(), equalTo(3L)); InternalCategorizationAggregation categorizationAggregation = result.getBuckets().get(0).getAggregations().get("my_agg"); assertThat(categorizationAggregation.getBuckets(), hasSize(2)); - assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(2L)); assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); - assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).value(), equalTo(1.0)); - assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).value(), equalTo(0.0)); - assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(0.5)); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(1.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5)); - assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L)); assertThat( categorizationAggregation.getBuckets().get(1).getKeyAsString(), equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") ); - assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).value(), equalTo(0.0)); - assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).value(), equalTo(0.0)); - assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(0.0)); + assertThat(((Max) categorizationAggregation.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(0.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(0.0)); // Second histo bucket assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L)); categorizationAggregation = 
result.getBuckets().get(1).getAggregations().get("my_agg"); assertThat(categorizationAggregation.getBuckets(), hasSize(1)); - assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(2L)); assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); - assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).value(), equalTo(3.0)); - assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).value(), equalTo(2.0)); - assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(3.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(2.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(2.5)); // Third histo bucket assertThat(result.getBuckets().get(2).getDocCount(), equalTo(3L)); categorizationAggregation = result.getBuckets().get(2).getAggregations().get("my_agg"); assertThat(categorizationAggregation.getBuckets(), hasSize(2)); - assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(2L)); assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); - assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).value(), equalTo(5.0)); - assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).value(), equalTo(4.0)); - assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(4.5)); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(5.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(4.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(4.5)); - assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L)); + assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L)); assertThat( categorizationAggregation.getBuckets().get(1).getKeyAsString(), equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") ); - assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).value(), equalTo(4.0)); - assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).value(), equalTo(4.0)); - assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(4.0)); + assertThat(((Max) categorizationAggregation.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(4.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(4.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(4.0)); }, new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), longField(NUMERIC_FIELD_NAME)); } @@ -235,12 +235,12 @@ public void testCategorizationWithSubAggsManyDocs() throws 
Exception { CategorizeTextAggregatorTests::writeManyTestDocs, (InternalCategorizationAggregation result) -> { assertThat(result.getBuckets(), hasSize(2)); - assertThat(result.getBuckets().get(0).docCount, equalTo(30_000L)); + assertThat(result.getBuckets().get(0).getDocCount(), equalTo(30000L)); assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); - Histogram histo = result.getBuckets().get(0).aggregations.get("histo"); + Histogram histo = result.getBuckets().get(0).getAggregations().get("histo"); assertThat(histo.getBuckets(), hasSize(3)); for (Histogram.Bucket bucket : histo.getBuckets()) { - assertThat(bucket.getDocCount(), equalTo(10_000L)); + assertThat(bucket.getDocCount(), equalTo(10000L)); } assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(1.0)); assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(0.0)); @@ -252,16 +252,16 @@ public void testCategorizationWithSubAggsManyDocs() throws Exception { assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).value(), equalTo(4.0)); assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5)); - assertThat(result.getBuckets().get(1).docCount, equalTo(10_000L)); + assertThat(result.getBuckets().get(1).getDocCount(), equalTo(10000L)); assertThat( result.getBuckets().get(1).getKeyAsString(), equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") ); - histo = result.getBuckets().get(1).aggregations.get("histo"); + histo = result.getBuckets().get(1).getAggregations().get("histo"); assertThat(histo.getBuckets(), hasSize(3)); - assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(5_000L)); + assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(5000L)); assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L)); - assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(5_000L)); + assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(5000L)); assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0)); assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0)); }, @@ -328,7 +328,7 @@ private static void writeTestDocs(RandomIndexWriter w) throws IOException { } private static void writeManyTestDocs(RandomIndexWriter w) throws IOException { - for (int i = 0; i < 5_000; i++) { + for (int i = 0; i < 5000; i++) { writeTestDocs(w); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java deleted file mode 100644 index e7f78a01d0130..0000000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.ml.aggs.categorization; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.util.BytesRefHash; -import org.elasticsearch.test.ESTestCase; -import org.junit.After; -import org.junit.Before; - -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; -import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; -import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; - -public class InnerTreeNodeTests extends ESTestCase { - - private final CategorizationTokenTree factory = new CategorizationTokenTree(3, 4, 60); - private CategorizationBytesRefHash bytesRefHash; - - @Before - public void createRefHash() { - bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); - } - - @After - public void closeRefHash() { - bytesRefHash.close(); - } - - public void testAddText() { - TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); - assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); - - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), - getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") - ); - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), - getTokens(bytesRefHash, "foo3", "bar", "baz", "biz") - ); - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), - getTokens(bytesRefHash, "*", "bar", "baz", "biz") - ); - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), - getTokens(bytesRefHash, "foo", "bar", "baz", "*") - ); - } - - public void testAddTokensWithLargerIncoming() { - TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 100, factory); - assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); - - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), - getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") - ); - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), - getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") - ); - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), - getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz") - ); - assertThat( - innerTreeNode.getCategorization(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")).getCategorization(), - equalTo(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) - ); - } - - public void testCollapseTinyChildren() { - TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); - TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); - 
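// (Context for these deleted tree tests: in the pre-8.3.0 implementation each inner node had a fixed
// child budget - 3 or 4 in these tests - and once it was exhausted, further distinct tokens at that
// position were folded into the "*" wildcard child, which is what the assertions around "foo4" and
// "foosmall" exercise.)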
assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); - - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), - getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") - ); - innerTreeNode.incCount(1000); - assertArrayEquals( - innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), - getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") - ); - innerTreeNode.incCount(1); - innerTreeNode.collapseTinyChildren(); - assertThat(innerTreeNode.hasChild(bytesRefHash.put(new BytesRef("foosmall"))), is(false)); - assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); - } - - public void testMergeWith() { - TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 3); - innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); - innerTreeNode.incCount(1000); - innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory); - - expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 60))); - - TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); - innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory); - innerTreeNode.incCount(1); - innerTreeNode.addText(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz"), 1, factory); - - innerTreeNode.mergeWith(mergeWith); - assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); - assertArrayEquals( - innerTreeNode.getCategorization(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz")).getCategorization(), - getTokens(bytesRefHash, "*", "bar", "baz", "biz") - ); - } -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java index 6c4b43bb08acb..add1638a58b86 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java @@ -7,8 +7,9 @@ package org.elasticsearch.xpack.ml.aggs.categorization; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.search.aggregations.Aggregation; @@ -18,19 +19,34 @@ import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.ml.MachineLearning; +import org.junit.After; +import org.junit.Before; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.Matchers.equalTo; + public class InternalCategorizationAggregationTests extends InternalMultiBucketAggregationTestCase { + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createHash() { + bytesRefHash = new 
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java
index 6c4b43bb08acb..add1638a58b86 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java
@@ -7,8 +7,9 @@
 
 package org.elasticsearch.xpack.ml.aggs.categorization;
 
-import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.aggregations.Aggregation;
@@ -18,19 +19,34 @@
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.junit.After;
+import org.junit.Before;
 
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public class InternalCategorizationAggregationTests extends InternalMultiBucketAggregationTestCase<InternalCategorizationAggregation> {
 
+    private CategorizationBytesRefHash bytesRefHash;
+
+    @Before
+    public void createHash() {
+        bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, BigArrays.NON_RECYCLING_INSTANCE));
+    }
+
+    @After
+    public void destroyHash() {
+        bytesRefHash.close();
+    }
+
     @Override
     protected SearchPlugin registerPlugin() {
         return new MachineLearning(Settings.EMPTY);
@@ -55,7 +71,7 @@ protected void assertReduced(InternalCategorizationAggregation reduced, List<In
 
         Map<Object, Long> expectedReducedCounts = new HashMap<>(totalCounts);
         expectedReducedCounts.keySet().retainAll(reducedCounts.keySet());
-        assertEquals(expectedReducedCounts, reducedCounts);
+        assertThat(reducedCounts, equalTo(expectedReducedCounts));
     }
 
     @Override
@@ -63,13 +79,6 @@ protected Predicate<String> excludePathsFromXContentInsertion() {
         return p -> p.contains("key");
     }
 
-    static InternalCategorizationAggregation.BucketKey randomKey() {
-        int numVals = randomIntBetween(1, 50);
-        return new InternalCategorizationAggregation.BucketKey(
-            Stream.generate(() -> randomAlphaOfLength(10)).limit(numVals).map(BytesRef::new).toArray(BytesRef[]::new)
-        );
-    }
-
     @Override
     protected InternalCategorizationAggregation createTestInstance(
         String name,
@@ -78,22 +87,16 @@ protected InternalCategorizationAggregation createTestInstance(
         Map<String, Object> metadata,
         InternalAggregations aggregations
     ) {
         List<InternalCategorizationAggregation.Bucket> buckets = new ArrayList<>();
         final int numBuckets = randomNumberOfBuckets();
-        HashSet<InternalCategorizationAggregation.BucketKey> keys = new HashSet<>();
         for (int i = 0; i < numBuckets; ++i) {
-            InternalCategorizationAggregation.BucketKey key = randomValueOtherThanMany(
-                l -> keys.add(l) == false,
-                InternalCategorizationAggregationTests::randomKey
+            buckets.add(
+                new InternalCategorizationAggregation.Bucket(SerializableTokenListCategoryTests.createTestInstance(bytesRefHash), -1)
             );
-            int docCount = randomIntBetween(1, 100);
-            buckets.add(new InternalCategorizationAggregation.Bucket(key, docCount, aggregations));
         }
         Collections.sort(buckets);
         return new InternalCategorizationAggregation(
             name,
             randomIntBetween(10, 100),
             randomLongBetween(1, 10),
-            randomIntBetween(1, 500),
-            randomIntBetween(1, 10),
             randomIntBetween(1, 100),
             metadata,
             buckets
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java
deleted file mode 100644
index 2bef18993e019..0000000000000
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.aggs.categorization;
-
-import org.elasticsearch.common.util.BytesRefHash;
-import org.elasticsearch.test.ESTestCase;
-import org.junit.After;
-import org.junit.Before;
-
-import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens;
-import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays;
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.hasSize;
-
-public class LeafTreeNodeTests extends ESTestCase {
-
-    private final CategorizationTokenTree factory = new CategorizationTokenTree(10, 10, 60);
-
-    private CategorizationBytesRefHash bytesRefHash;
-
-    @Before
-    public void createRefHash() {
-        bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays()));
-    }
-
-    @After
-    public void closeRefHash() {
-        bytesRefHash.close();
-    }
-
-    public void testAddGroup() {
-        TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60);
-        TextCategorization group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory);
-
-        assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz"));
-        assertThat(group.getCount(), equalTo(1L));
-        assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(1));
-        long previousBytesUsed = leafTreeNode.ramBytesUsed();
-
-        group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"), 1, factory);
-        assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"));
-        assertThat(group.getCount(), equalTo(1L));
-        assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2));
-        assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed));
-        previousBytesUsed = leafTreeNode.ramBytesUsed();
-
-        group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "different"), 3, factory);
-        assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*"));
-        assertThat(group.getCount(), equalTo(4L));
-        assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2));
-        assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed()));
-    }
-
-    public void testMergeWith() {
-        TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60);
-        leafTreeNode.mergeWith(null);
-        assertThat(leafTreeNode, equalTo(new TreeNode.LeafTreeNode(0, 60)));
-
-        expectThrows(UnsupportedOperationException.class, () -> leafTreeNode.mergeWith(new TreeNode.InnerTreeNode(1, 2, 3)));
-
-        leafTreeNode.incCount(5);
-        leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 5, factory);
-
-        TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 60);
-        leafTreeNode.incCount(1);
-        toMerge.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory);
-        leafTreeNode.incCount(1);
-        toMerge.addText(getTokens(bytesRefHash, "foo", "bart", "bat", "built"), 1, factory);
-        leafTreeNode.mergeWith(toMerge);
-
-        assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2));
-        assertThat(leafTreeNode.getCount(), equalTo(7L));
-        assertArrayEquals(
-            leafTreeNode.getAllChildrenTextCategorizations().get(0).getCategorization(),
-            getTokens(bytesRefHash, "foo", "bar", "baz", "*")
-        );
-        assertArrayEquals(
-            leafTreeNode.getAllChildrenTextCategorizations().get(1).getCategorization(),
-            getTokens(bytesRefHash, "foo", "bart", "bat", "built")
-        );
-    }
-}
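These leaf tests encode the other half of the old design: within a leaf, a new string joined an existing categorization when enough of its token positions matched (the 60 in these tests is a similarity threshold, as a percentage), and every position that disagreed was replaced by `*`. A rough sketch of that merge rule, under the assumption that both sides tokenize to the same length (hypothetical names, not the removed classes):

import java.util.Arrays;

final class CategoryMerger {
    static final String WILD_CARD = "*";

    // Percentage of positions that match exactly, compared against the threshold.
    static int similarity(String[] categorization, String[] incoming) {
        int matched = 0;
        for (int i = 0; i < categorization.length; i++) {
            if (categorization[i].equals(incoming[i])) {
                matched++;
            }
        }
        return 100 * matched / categorization.length;
    }

    // Merges an incoming token list into an existing categorization by
    // wildcarding every position where the two disagree.
    static String[] merge(String[] categorization, String[] incoming) {
        String[] merged = Arrays.copyOf(categorization, categorization.length);
        for (int i = 0; i < merged.length; i++) {
            if (merged[i].equals(incoming[i]) == false) {
                merged[i] = WILD_CARD;
            }
        }
        return merged;
    }
}

With a threshold of 60, [foo bar baz biz] and [foo bar baz different] are 75% similar and merge to [foo bar baz *], as testAddGroup asserted; [foo bar bozo bizzy] is only 50% similar and stayed a separate categorization.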
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java
index 283c200c27d0f..9144dfd11182c 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java
@@ -8,18 +8,25 @@
 package org.elasticsearch.xpack.ml.aggs.categorization;
 
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.search.aggregations.Aggregation;
+import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation;
 import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.job.results.CategoryDefinition;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.function.Supplier;
 
+// TODO: how close to the actual InternalCategorizationAggregation.Bucket class does this have to be to add any value?
 class ParsedCategorization extends ParsedMultiBucketAggregation<ParsedCategorization.ParsedBucket> {
 
 @Override
@@ -50,6 +57,7 @@ public List<? extends MultiBucketsAggregation.Bucket> getBuckets() {
 
     public static class ParsedBucket extends ParsedMultiBucketAggregation.ParsedBucket implements MultiBucketsAggregation.Bucket {
 
         private InternalCategorizationAggregation.BucketKey key;
+        private int maxMatchingLength;
 
         protected void setKeyAsString(String keyAsString) {
             if (keyAsString == null) {
@@ -68,6 +76,10 @@ protected void setKeyAsString(String keyAsString) {
             );
         }
 
+        private void setMaxMatchingLength(int maxMatchingLength) {
+            this.maxMatchingLength = maxMatchingLength;
+        }
+
         @Override
         public Object getKey() {
             return key;
@@ -75,7 +87,7 @@ public Object getKey() {
 
         @Override
         public String getKeyAsString() {
-            return key.asString();
+            return key.toString();
         }
 
         @Override
@@ -83,6 +95,17 @@ protected XContentBuilder keyToXContent(XContentBuilder builder) throws IOExcept
             return builder.field(CommonFields.KEY.getPreferredName(), getKey());
         }
 
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            keyToXContent(builder);
+            builder.field(CategoryDefinition.MAX_MATCHING_LENGTH.getPreferredName(), maxMatchingLength);
+            builder.field(CommonFields.DOC_COUNT.getPreferredName(), getDocCount());
+            getAggregations().toXContentInternal(builder, params);
+            builder.endObject();
+            return builder;
+        }
+
         static InternalCategorizationAggregation.BucketKey parsedKey(final XContentParser parser) throws IOException {
             if (parser.currentToken() == XContentParser.Token.VALUE_STRING) {
                 String toSplit = parser.text();
@@ -99,15 +122,48 @@ static InternalCategorizationAggregation.BucketKey parsedKey(final XContentParse
             }
         }
 
-        static ParsedBucket fromXContent(final XContentParser parser) throws IOException {
-            return ParsedMultiBucketAggregation.ParsedBucket.parseXContent(
-                parser,
-                false,
-                ParsedBucket::new,
-                (p, bucket) -> bucket.key = parsedKey(p)
-            );
+        protected static ParsedBucket parseCategorizationBucketXContent(
+            final XContentParser parser,
+            final Supplier<ParsedBucket> bucketSupplier,
+            final CheckedBiConsumer<XContentParser, ParsedBucket, IOException> keyConsumer
+        ) throws IOException {
+            final ParsedBucket bucket = bucketSupplier.get();
+            XContentParser.Token token;
+            String currentFieldName = parser.currentName();
+
+            List<Aggregation> aggregations = new ArrayList<>();
+            while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
+                if (token == XContentParser.Token.FIELD_NAME) {
+                    currentFieldName = parser.currentName();
+                } else if (token.isValue()) {
+                    if (CommonFields.KEY_AS_STRING.getPreferredName().equals(currentFieldName)) {
+                        bucket.setKeyAsString(parser.text());
+                    } else if (CommonFields.KEY.getPreferredName().equals(currentFieldName)) {
+                        keyConsumer.accept(parser, bucket);
+                    } else if (CommonFields.DOC_COUNT.getPreferredName().equals(currentFieldName)) {
+                        bucket.setDocCount(parser.longValue());
+                    } else if (CategoryDefinition.MAX_MATCHING_LENGTH.getPreferredName().equals(currentFieldName)) {
+                        bucket.setMaxMatchingLength(parser.intValue());
+                    }
+                } else if (token == XContentParser.Token.START_OBJECT) {
+                    if (CommonFields.KEY.getPreferredName().equals(currentFieldName)) {
+                        keyConsumer.accept(parser, bucket);
+                    } else {
+                        XContentParserUtils.parseTypedKeysObject(
+                            parser,
+                            Aggregation.TYPED_KEYS_DELIMITER,
+                            Aggregation.class,
+                            aggregations::add
+                        );
+                    }
+                }
+            }
+            bucket.setAggregations(new Aggregations(aggregations));
+            return bucket;
         }
 
+        static ParsedBucket fromXContent(final XContentParser parser) throws IOException {
+            return parseCategorizationBucketXContent(parser, ParsedBucket::new, (p, bucket) -> bucket.key = parsedKey(p));
+        }
     }
-
 }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/SerializableTokenListCategoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategoryTests.java
similarity index 96%
rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/SerializableTokenListCategoryTests.java
rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategoryTests.java
index b1fd65fd162a6..a001b802e8b3b 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/SerializableTokenListCategoryTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategoryTests.java
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.aggs.categorization2;
+package org.elasticsearch.xpack.ml.aggs.categorization;
 
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.BigArrays;
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java
deleted file mode 100644
index 59129f8801937..0000000000000
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.aggs.categorization;
-
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BytesRefHash;
-import org.elasticsearch.common.util.MockBigArrays;
-import org.elasticsearch.common.util.MockPageCacheRecycler;
-import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
-import org.elasticsearch.test.ESTestCase;
-import org.junit.After;
-import org.junit.Before;
-
-import java.io.IOException;
-
-import static org.hamcrest.Matchers.closeTo;
-import static org.hamcrest.Matchers.equalTo;
-
-public class TextCategorizationTests extends ESTestCase {
-
-    static BigArrays mockBigArrays() {
-        return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
-    }
-
-    private CategorizationBytesRefHash bytesRefHash;
-
-    @Before
-    public void createRefHash() {
-        bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays()));
-    }
-
-    @After
-    public void closeRefHash() throws IOException {
-        bytesRefHash.close();
-    }
-
-    public void testSimilarity() {
-        TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1);
-        TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens(bytesRefHash, "not", "matching", "anything", "nope"));
-        assertThat(sims.getSimilarity(), equalTo(0.0));
-        assertThat(sims.getWildCardCount(), equalTo(0));
-
-        sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"));
-        assertThat(sims.getSimilarity(), equalTo(1.0));
-        assertThat(sims.getWildCardCount(), equalTo(0));
-
-        sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "fooagain", "notbar", "biz"));
-        assertThat(sims.getSimilarity(), closeTo(0.5, 0.0001));
-        assertThat(sims.getWildCardCount(), equalTo(0));
-    }
-
-    public void testAddTokens() {
-        TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1);
-        lg.addTokens(getTokens(bytesRefHash, "foo", "bar", "baz", "bozo"), 2);
-        assertThat(lg.getCount(), equalTo(3L));
-        assertArrayEquals(lg.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*"));
-    }
-
-    static int[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) {
-        BytesRef[] refs = new BytesRef[tokens.length];
-        int i = 0;
-        for (String token : tokens) {
-            refs[i++] = new BytesRef(token);
-        }
-        return bytesRefHash.getIds(refs);
-    }
-
-}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategorizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizerTests.java
similarity index 90%
rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategorizerTests.java
rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizerTests.java
index 65b0179fe8457..e4648f4a52404 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategorizerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizerTests.java
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.aggs.categorization2;
+package org.elasticsearch.xpack.ml.aggs.categorization;
 
 import org.apache.lucene.analysis.TokenStream;
 import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
@@ -79,16 +79,27 @@ public void testWeightCalculator() throws IOException {
         assertThat(weightCalculator.calculateWeight("my_host"), equalTo(1)); // not dictionary word
         assertThat(weightCalculator.calculateWeight("web"), equalTo(3)); // dictionary word, not verb, not 3rd in a row
         assertThat(weightCalculator.calculateWeight("service"), equalTo(3)); // dictionary word, not verb, not 3rd in a row
-        assertThat(weightCalculator.calculateWeight("starting"), equalTo(12)); // dictionary word, verb, 3rd in a row
+        assertThat(weightCalculator.calculateWeight("starting"), equalTo(31)); // dictionary word, verb, 3rd in a row
         assertThat(weightCalculator.calculateWeight("user123"), equalTo(1)); // not dictionary word
         assertThat(weightCalculator.calculateWeight("a"), equalTo(1)); // too short for dictionary weighting
         assertThat(weightCalculator.calculateWeight("cool"), equalTo(3)); // dictionary word, not verb, not 3rd in a row
        assertThat(weightCalculator.calculateWeight("web"), equalTo(3)); // dictionary word, not verb, not 3rd in a row
-        assertThat(weightCalculator.calculateWeight("service"), equalTo(9)); // dictionary word, not verb, 3rd in a row
-        assertThat(weightCalculator.calculateWeight("called"), equalTo(9)); // dictionary word, not verb, 4th in a row
+        assertThat(weightCalculator.calculateWeight("service"), equalTo(13)); // dictionary word, not verb, 3rd in a row
+        assertThat(weightCalculator.calculateWeight("called"), equalTo(13)); // dictionary word, not verb, 4th in a row
         assertThat(weightCalculator.calculateWeight("my_service"), equalTo(1)); // not dictionary word
         assertThat(weightCalculator.calculateWeight("is"), equalTo(3)); // dictionary word, not verb, not 3rd in a row
         assertThat(weightCalculator.calculateWeight("starting"), equalTo(6)); // dictionary word, verb, not 3rd in a row
+
+        assertThat(TokenListCategorizer.WeightCalculator.getMinMatchingWeight(1), equalTo(1));
+        assertThat(TokenListCategorizer.WeightCalculator.getMaxMatchingWeight(1), equalTo(1));
+        assertThat(TokenListCategorizer.WeightCalculator.getMinMatchingWeight(3), equalTo(3));
+        assertThat(TokenListCategorizer.WeightCalculator.getMaxMatchingWeight(3), equalTo(13));
+        assertThat(TokenListCategorizer.WeightCalculator.getMinMatchingWeight(6), equalTo(6));
+        assertThat(TokenListCategorizer.WeightCalculator.getMaxMatchingWeight(6), equalTo(31));
+        assertThat(TokenListCategorizer.WeightCalculator.getMinMatchingWeight(13), equalTo(3));
+        assertThat(TokenListCategorizer.WeightCalculator.getMaxMatchingWeight(13), equalTo(13));
+        assertThat(TokenListCategorizer.WeightCalculator.getMinMatchingWeight(31), equalTo(6));
+        assertThat(TokenListCategorizer.WeightCalculator.getMaxMatchingWeight(31), equalTo(31));
     }
 
     public void testApacheData() throws IOException {
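The new assertions pin down how matching weights pair up across the "emphasized" tiers: per the inline comments, a plain dictionary word weighs 3 but 13 once it is the third-or-later dictionary word in a row, and a verb weighs 6 or 31 under the same rule, so the min/max helpers map each weight to the ends of its tier. If the implementation follows that reading, a sketch of the mapping could look like this (hypothetical code, not the actual TokenListCategorizer.WeightCalculator):

final class MatchingWeights {
    // Weight tiers implied by the assertions above: {1}, {3, 13}, {6, 31}.
    static int minMatchingWeight(int weight) {
        return switch (weight) {
            case 13 -> 3;  // boosted dictionary word matches its plain form
            case 31 -> 6;  // boosted verb matches its plain form
            default -> weight;
        };
    }

    static int maxMatchingWeight(int weight) {
        return switch (weight) {
            case 3 -> 13;
            case 6 -> 31;
            default -> weight;
        };
    }
}

This reproduces all ten min/max expectations in the test, for example minMatchingWeight(31) == 6 and maxMatchingWeight(6) == 31.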
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategoryTests.java
similarity index 92%
rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategoryTests.java
rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategoryTests.java
index f4add94bf6c1e..8675788af7202 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListCategoryTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategoryTests.java
@@ -5,12 +5,12 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.aggs.categorization2;
+package org.elasticsearch.xpack.ml.aggs.categorization;
 
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BytesRefHash;
-import org.elasticsearch.xpack.ml.aggs.categorization2.TokenListCategory.TokenAndWeight;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -18,6 +18,7 @@
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 
 public class TokenListCategoryTests extends CategorizationTestCase {
@@ -34,7 +35,7 @@ public void testCommonTokensSameOrder() {
             tw("seashore", 2)
         );
 
-        List<TokenAndWeight> baseUniqueTokenIds = baseTokenIds.stream().sorted().distinct().collect(Collectors.toList());
+        List<TokenAndWeight> baseUniqueTokenIds = baseTokenIds.stream().sorted().distinct().toList();
 
         TokenListCategory category = new TokenListCategory(1, baseString.length(), baseTokenIds, baseUniqueTokenIds, 1);
 
@@ -48,7 +49,7 @@ public void testCommonTokensSameOrder() {
             tw("the", 2),
             tw("seashore", 2)
         );
-        List<TokenAndWeight> newUniqueTokenIds = newTokenIds.stream().sorted().distinct().collect(Collectors.toList());
+        List<TokenAndWeight> newUniqueTokenIds = newTokenIds.stream().sorted().distinct().toList();
 
         category.addString(newString.length(), newTokenIds, newUniqueTokenIds, 1);
 
@@ -82,7 +83,7 @@ public void testCommonTokensDifferentOrder() {
             tw("the", 2),
             tw("seashore", 2)
         );
-        List<TokenAndWeight> baseUniqueTokenIds = baseTokenIds.stream().sorted().distinct().collect(Collectors.toList());
+        List<TokenAndWeight> baseUniqueTokenIds = baseTokenIds.stream().sorted().distinct().toList();
 
         TokenListCategory category = new TokenListCategory(1, baseString.length(), baseTokenIds, baseUniqueTokenIds, 1);
 
@@ -96,7 +97,7 @@ public void testCommonTokensDifferentOrder() {
             tw("she", 2),
             tw("does", 2)
         );
-        List<TokenAndWeight> newUniqueTokenIds1 = newTokenIds1.stream().sorted().distinct().collect(Collectors.toList());
+        List<TokenAndWeight> newUniqueTokenIds1 = newTokenIds1.stream().sorted().distinct().toList();
 
         category.addString(newString1.length(), newTokenIds1, newUniqueTokenIds1, 1);
 
@@ -122,7 +123,7 @@ public void testCommonTokensDifferentOrder() {
             tw("the", 2),
             tw("seashore", 2)
         );
-        List<TokenAndWeight> newUniqueTokenIds2 = newTokenIds2.stream().sorted().distinct().collect(Collectors.toList());
+        List<TokenAndWeight> newUniqueTokenIds2 = newTokenIds2.stream().sorted().distinct().toList();
 
         category.addString(newString2.length(), newTokenIds2, newUniqueTokenIds2, 1);
 
@@ -146,7 +147,7 @@ public void testCommonTokensDifferentOrder() {
 
         String newString3 = "the rock";
         List<TokenAndWeight> newTokenIds3 = List.of(tw("the", 2), tw("rock", 2));
-        List<TokenAndWeight> newUniqueTokenIds3 = newTokenIds3.stream().sorted().distinct().collect(Collectors.toList());
+        List<TokenAndWeight> newUniqueTokenIds3 = newTokenIds3.stream().sorted().distinct().toList();
 
         category.addString(newString3.length(), newTokenIds3, newUniqueTokenIds3, 1);
 
@@ -206,6 +207,16 @@ public void testRoundTripBetweenNodes() {
         assertThat(node1RoundTrippedCategory2.ramBytesUsed(), equalTo(node1RoundTrippedCategory2.ramBytesUsedSlow()));
     }
 
+    public void testMissingCommonTokenWeightZeroForSupersets() {
+        TokenListCategory category = createTestInstance(bytesRefHash, 1);
+        List<TokenAndWeight> uniqueTokenIds = new ArrayList<>(category.getCommonUniqueTokenIds());
+        for (int i = 0; i < 5; ++i) {
+            assertThat(category.missingCommonTokenWeight(uniqueTokenIds), is(0));
+            uniqueTokenIds.add(new TokenAndWeight(randomIntBetween(1, 10), randomIntBetween(1, 10)));
+            uniqueTokenIds.sort(TokenAndWeight::compareTo);
+        }
+    }
+
     public static TokenListCategory createTestInstance(CategorizationBytesRefHash bytesRefHash, int id) {
 
         int unfilteredStringLength = 0;
@@ -228,7 +239,7 @@ public static TokenListCategory createTestInstance(CategorizationBytesRefHash by
             .entrySet()
             .stream()
             .map(entry -> new TokenAndWeight(entry.getKey(), entry.getValue()))
-            .collect(Collectors.toList());
+            .toList();
 
         return new TokenListCategory(id, unfilteredStringLength, baseWeightedTokenIds, uniqueWeightedTokenIds, randomLongBetween(1, 10));
     }
 }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListSimilarityTesterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListSimilarityTesterTests.java
similarity index 97%
rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListSimilarityTesterTests.java
rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListSimilarityTesterTests.java
index b882a842fcd79..43c053307e30d 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/TokenListSimilarityTesterTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListSimilarityTesterTests.java
@@ -5,9 +5,9 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.aggs.categorization2;
+package org.elasticsearch.xpack.ml.aggs.categorization;
 
-import org.elasticsearch.xpack.ml.aggs.categorization2.TokenListCategory.TokenAndWeight;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight;
 
 import java.util.List;
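The deleted builder test that follows randomizes the aggregation's settings and sets either categorization filters or a full categorization analyzer, never both, since the two request options are mutually exclusive. A minimal sketch of how such a mutual-exclusion guard can be enforced in a builder (hypothetical names and messages, not the actual CategorizeTextAggregationBuilder):

import java.util.List;

final class CategorizationConfigBuilder {
    private List<String> categorizationFilters;
    private Object categorizationAnalyzerConfig;

    CategorizationConfigBuilder setCategorizationFilters(List<String> filters) {
        if (categorizationAnalyzerConfig != null) {
            // reject the combination up front rather than at execution time
            throw new IllegalArgumentException(
                "[categorization_filters] cannot be used at the same time as [categorization_analyzer]"
            );
        }
        this.categorizationFilters = filters;
        return this;
    }

    CategorizationConfigBuilder setCategorizationAnalyzerConfig(Object analyzerConfig) {
        if (categorizationFilters != null) {
            throw new IllegalArgumentException(
                "[categorization_analyzer] cannot be used at the same time as [categorization_filters]"
            );
        }
        this.categorizationAnalyzerConfig = analyzerConfig;
        return this;
    }
}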
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregationBuilderTests.java
deleted file mode 100644
index 039a906514d92..0000000000000
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregationBuilderTests.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.aggs.categorization2;
-
-import org.elasticsearch.plugins.Plugin;
-import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
-import org.elasticsearch.xpack.ml.MachineLearning;
-import org.elasticsearch.xpack.ml.job.config.CategorizationAnalyzerConfigTests;
-
-import java.util.Collection;
-import java.util.Collections;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-public class CategorizeTextAggregationBuilderTests extends BaseAggregationTestCase<CategorizeTextAggregationBuilder> {
-
-    @Override
-    protected Collection<Class<? extends Plugin>> getExtraPlugins() {
-        return Collections.singletonList(MachineLearning.class);
-    }
-
-    @Override
-    protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() {
-        CategorizeTextAggregationBuilder builder = new CategorizeTextAggregationBuilder(randomAlphaOfLength(10), randomAlphaOfLength(10));
-        final boolean setFilters = randomBoolean();
-        if (setFilters) {
-            builder.setCategorizationFilters(Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()));
-        }
-        if (setFilters == false) {
-            builder.setCategorizationAnalyzerConfig(CategorizationAnalyzerConfigTests.createRandomized().build());
-        }
-        if (randomBoolean()) {
-            builder.setSimilarityThreshold(randomIntBetween(1, 100));
-        }
-        if (randomBoolean()) {
-            builder.minDocCount(randomLongBetween(1, 100));
-        }
-        if (randomBoolean()) {
-            builder.shardMinDocCount(randomLongBetween(1, 100));
-        }
-        if (randomBoolean()) {
-            builder.size(randomIntBetween(1, 100));
-        }
-        if (randomBoolean()) {
-            builder.shardSize(randomIntBetween(1, 100));
-        }
-        return builder;
-    }
-}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregatorTests.java
deleted file mode 100644
index e380427efa43a..0000000000000
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizeTextAggregatorTests.java
+++ /dev/null
@@ -1,335 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.aggs.categorization2;
-
-import org.apache.lucene.document.SortedNumericDocValuesField;
-import org.apache.lucene.document.StoredField;
-import org.apache.lucene.search.MatchAllDocsQuery;
-import org.apache.lucene.tests.index.RandomIndexWriter;
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.env.Environment;
-import org.elasticsearch.env.TestEnvironment;
-import org.elasticsearch.index.mapper.TextFieldMapper;
-import org.elasticsearch.indices.analysis.AnalysisModule;
-import org.elasticsearch.plugins.SearchPlugin;
-import org.elasticsearch.search.aggregations.AggregatorTestCase;
-import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
-import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder;
-import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram;
-import org.elasticsearch.search.aggregations.metrics.Avg;
-import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
-import org.elasticsearch.search.aggregations.metrics.Max;
-import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder;
-import org.elasticsearch.search.aggregations.metrics.Min;
-import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
-import org.elasticsearch.xpack.ml.MachineLearning;
-
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
-
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.hasSize;
-
-public class CategorizeTextAggregatorTests extends AggregatorTestCase {
-
-    @Override
-    protected AnalysisModule createAnalysisModule() throws Exception {
-        return new AnalysisModule(
-            TestEnvironment.newEnvironment(
-                Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build()
-            ),
-            List.of(new MachineLearning(Settings.EMPTY))
-        );
-    }
-
-    @Override
-    protected List<SearchPlugin> getSearchPlugins() {
-        return List.of(new MachineLearning(Settings.EMPTY));
-    }
-
-    private static final String TEXT_FIELD_NAME = "text";
-    private static final String NUMERIC_FIELD_NAME = "value";
-
-    public void testCategorizationWithoutSubAggs() throws Exception {
-        testCase(
-            new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME),
-            new MatchAllDocsQuery(),
-            CategorizeTextAggregatorTests::writeTestDocs,
-            (InternalCategorizationAggregation result) -> {
-                assertThat(result.getBuckets(), hasSize(2));
-                assertThat(result.getBuckets().get(0).getDocCount(), equalTo(6L));
-                assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
-                assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L));
-                assertThat(
-                    result.getBuckets().get(1).getKeyAsString(),
-                    equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
-                );
-            },
-            new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME),
-            longField(NUMERIC_FIELD_NAME)
-        );
-    }
-
-    public void testCategorizationWithSubAggs() throws Exception {
-        CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation(
-            new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)
-        )
-            .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME))
-            .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME));
-        testCase(
-            aggBuilder,
-            new MatchAllDocsQuery(),
-            CategorizeTextAggregatorTests::writeTestDocs,
-            (InternalCategorizationAggregation result) -> {
-                assertThat(result.getBuckets(), hasSize(2));
-                assertThat(result.getBuckets().get(0).getDocCount(), equalTo(6L));
-                assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
-                assertThat(((Max) result.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(5.0));
-                assertThat(((Min) result.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(0.0));
-                assertThat(((Avg) result.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(2.5));
-
-                assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L));
-                assertThat(
-                    result.getBuckets().get(1).getKeyAsString(),
-                    equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
-                );
-                assertThat(((Max) result.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(4.0));
-                assertThat(((Min) result.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(0.0));
-                assertThat(((Avg) result.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.0));
-            },
-            new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME),
-            longField(NUMERIC_FIELD_NAME)
-        );
-    }
-
-    public void testCategorizationWithMultiBucketSubAggs() throws Exception {
-        CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation(
-            new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME)
-                .interval(2)
-                .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME))
-                .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME))
-                .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME))
-        );
-        testCase(
-            aggBuilder,
-            new MatchAllDocsQuery(),
-            CategorizeTextAggregatorTests::writeTestDocs,
-            (InternalCategorizationAggregation result) -> {
-                assertThat(result.getBuckets(), hasSize(2));
-                assertThat(result.getBuckets().get(0).getDocCount(), equalTo(6L));
-                assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
-                Histogram histo = result.getBuckets().get(0).getAggregations().get("histo");
-                assertThat(histo.getBuckets(), hasSize(3));
-                for (Histogram.Bucket bucket : histo.getBuckets()) {
-                    assertThat(bucket.getDocCount(), equalTo(2L));
-                }
-                assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(1.0));
-                assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(0.0));
-                assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5));
-                assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(3.0));
-                assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(2.0));
-                assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5));
-                assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).value(), equalTo(5.0));
-                assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).value(), equalTo(4.0));
-                assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5));
-
-                assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L));
-                assertThat(
-                    result.getBuckets().get(1).getKeyAsString(),
-                    equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
-                );
-                histo = result.getBuckets().get(1).getAggregations().get("histo");
-                assertThat(histo.getBuckets(), hasSize(3));
-                assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(1L));
-                assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L));
-                assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(1L));
-                assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0));
-                assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0));
-            },
-            new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME),
-            longField(NUMERIC_FIELD_NAME)
-        );
-    }
-
-    public void testCategorizationAsSubAgg() throws Exception {
-        HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME)
-            .interval(2)
-            .subAggregation(
-                new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation(
-                    new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)
-                )
-                    .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME))
-                    .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME))
-            );
-        testCase(aggBuilder, new MatchAllDocsQuery(), CategorizeTextAggregatorTests::writeTestDocs, (InternalHistogram result) -> {
-            assertThat(result.getBuckets(), hasSize(3));
-
-            // First histo bucket
-            assertThat(result.getBuckets().get(0).getDocCount(), equalTo(3L));
-            InternalCategorizationAggregation categorizationAggregation = result.getBuckets().get(0).getAggregations().get("my_agg");
-            assertThat(categorizationAggregation.getBuckets(), hasSize(2));
-            assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(2L));
-            assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
-            assertThat(((Max) categorizationAggregation.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(1.0));
-            assertThat(((Min) categorizationAggregation.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(0.0));
-            assertThat(((Avg) categorizationAggregation.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5));
-
-            assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L));
-            assertThat(
-                categorizationAggregation.getBuckets().get(1).getKeyAsString(),
-                equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
-            );
-            assertThat(((Max) categorizationAggregation.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(0.0));
-            assertThat(((Min) categorizationAggregation.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(0.0));
-            assertThat(((Avg) categorizationAggregation.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(0.0));
-
-            // Second histo bucket
-            assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L));
-            categorizationAggregation = result.getBuckets().get(1).getAggregations().get("my_agg");
-            assertThat(categorizationAggregation.getBuckets(), hasSize(1));
-            assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(2L));
-            assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
-            assertThat(((Max) categorizationAggregation.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(3.0));
-            assertThat(((Min) categorizationAggregation.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(2.0));
-            assertThat(((Avg) categorizationAggregation.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(2.5));
-
-            // Third histo bucket
-            assertThat(result.getBuckets().get(2).getDocCount(), equalTo(3L));
-            categorizationAggregation = result.getBuckets().get(2).getAggregations().get("my_agg");
-            assertThat(categorizationAggregation.getBuckets(), hasSize(2));
-            assertThat(categorizationAggregation.getBuckets().get(0).getDocCount(), equalTo(2L));
-            assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
-            assertThat(((Max) categorizationAggregation.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(5.0));
-            assertThat(((Min) categorizationAggregation.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(4.0));
-            assertThat(((Avg) categorizationAggregation.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(4.5));
-
-            assertThat(categorizationAggregation.getBuckets().get(1).getDocCount(), equalTo(1L));
-            assertThat(
-                categorizationAggregation.getBuckets().get(1).getKeyAsString(),
-                equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
-            );
-            assertThat(((Max) categorizationAggregation.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(4.0));
-            assertThat(((Min) categorizationAggregation.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(4.0));
-            assertThat(((Avg) categorizationAggregation.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(4.0));
-        }, new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), longField(NUMERIC_FIELD_NAME));
-    }
-
-    public void testCategorizationWithSubAggsManyDocs() throws Exception {
-        CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation(
-            new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME)
-                .interval(2)
-                .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME))
-                .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME))
-                .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME))
-        );
-        testCase(
-            aggBuilder,
-            new MatchAllDocsQuery(),
-            CategorizeTextAggregatorTests::writeManyTestDocs,
-            (InternalCategorizationAggregation result) -> {
-                assertThat(result.getBuckets(), hasSize(2));
-                assertThat(result.getBuckets().get(0).getDocCount(), equalTo(30000L));
-                assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
-                Histogram histo = result.getBuckets().get(0).getAggregations().get("histo");
-                assertThat(histo.getBuckets(), hasSize(3));
-                for (Histogram.Bucket bucket : histo.getBuckets()) {
-                    assertThat(bucket.getDocCount(), equalTo(10000L));
-                }
-                assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).value(), equalTo(1.0));
-                assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).value(), equalTo(0.0));
-                assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5));
-                assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).value(), equalTo(3.0));
-                assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).value(), equalTo(2.0));
-                assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5));
-                assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).value(), equalTo(5.0));
-                assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).value(), equalTo(4.0));
-                assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5));
-
-                assertThat(result.getBuckets().get(1).getDocCount(), equalTo(10000L));
-                assertThat(
-                    result.getBuckets().get(1).getKeyAsString(),
-                    equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
-                );
-                histo = result.getBuckets().get(1).getAggregations().get("histo");
-                assertThat(histo.getBuckets(), hasSize(3));
-                assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(5000L));
-                assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L));
-                assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(5000L));
-                assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0));
-                assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0));
-            },
-            new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME),
-            longField(NUMERIC_FIELD_NAME)
-        );
-    }
-
-    private static void writeTestDocs(RandomIndexWriter w) throws IOException {
-        w.addDocument(
-            Arrays.asList(
-                new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0)
-            )
-        );
-        w.addDocument(
-            Arrays.asList(
-                new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 1)
-            )
-        );
-        w.addDocument(
-            Arrays.asList(
-                new StoredField(
-                    "_source",
-                    new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}")
-                ),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0)
-            )
-        );
-        w.addDocument(
-            Arrays.asList(
-                new StoredField(
-                    "_source",
-                    new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}")
-                ),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4)
-            )
-        );
-        w.addDocument(
-            Arrays.asList(
-                new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 2)
-            )
-        );
-        w.addDocument(
-            Arrays.asList(
-                new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 3)
-            )
-        );
-        w.addDocument(
-            Arrays.asList(
-                new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4)
-            )
-        );
-        w.addDocument(
-            Arrays.asList(
-                new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")),
-                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 5)
-            )
-        );
-    }
-
-    private static void writeManyTestDocs(RandomIndexWriter w) throws IOException {
-        for (int i = 0; i < 5000; i++) {
-            writeTestDocs(w);
-        }
-    }
-}
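The expectations in the deleted aggregator tests are easy to sanity-check by hand: the six "Node started" docs carry the values 0 through 5, so a histogram with interval 2 splits them into [0,1], [2,3] and [4,5], giving per-bucket averages of 0.5, 2.5 and 4.5. A quick sketch of that arithmetic, just reproducing the test data:

import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

class HistogramExpectations {
    public static void main(String[] args) {
        // Values attached to the six "Node started" docs in writeTestDocs.
        List<Integer> values = List.of(0, 1, 2, 3, 4, 5);
        // interval(2) buckets are keyed by floor(value / 2) * 2
        Map<Integer, Double> avgPerBucket = values.stream()
            .collect(Collectors.groupingBy(v -> (v / 2) * 2, TreeMap::new, Collectors.averagingInt(v -> v)));
        System.out.println(avgPerBucket); // {0=0.5, 2=2.5, 4=4.5}
    }
}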
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/InternalCategorizationAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/InternalCategorizationAggregationTests.java
deleted file mode 100644
index a2e666e1e2bdb..0000000000000
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/InternalCategorizationAggregationTests.java
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.aggs.categorization2;
-
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BytesRefHash;
-import org.elasticsearch.common.util.CollectionUtils;
-import org.elasticsearch.plugins.SearchPlugin;
-import org.elasticsearch.search.aggregations.Aggregation;
-import org.elasticsearch.search.aggregations.InternalAggregations;
-import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation;
-import org.elasticsearch.test.InternalMultiBucketAggregationTestCase;
-import org.elasticsearch.xcontent.NamedXContentRegistry;
-import org.elasticsearch.xcontent.ParseField;
-import org.elasticsearch.xpack.ml.MachineLearning;
-import org.junit.After;
-import org.junit.Before;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.function.Predicate;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-import static org.hamcrest.Matchers.equalTo;
-
-public class InternalCategorizationAggregationTests extends InternalMultiBucketAggregationTestCase<InternalCategorizationAggregation> {
-
-    private CategorizationBytesRefHash bytesRefHash;
-
-    @Before
-    public void createHash() {
-        bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, BigArrays.NON_RECYCLING_INSTANCE));
-    }
-
-    @After
-    public void destroyHash() {
-        bytesRefHash.close();
-    }
-
-    @Override
-    protected SearchPlugin registerPlugin() {
-        return new MachineLearning(Settings.EMPTY);
-    }
-
-    @Override
-    protected List<NamedXContentRegistry.Entry> getNamedXContents() {
-        return CollectionUtils.appendToCopy(
-            super.getNamedXContents(),
-            new NamedXContentRegistry.Entry(
-                Aggregation.class,
-                new ParseField(CategorizeTextAggregationBuilder.NAME),
-                (p, c) -> ParsedCategorization.fromXContent(p, (String) c)
-            )
-        );
-    }
-
-    @Override
-    protected void assertReduced(InternalCategorizationAggregation reduced, List<InternalCategorizationAggregation> inputs) {
-        Map<Object, Long> reducedCounts = toCounts(reduced.getBuckets().stream());
-        Map<Object, Long> totalCounts = toCounts(inputs.stream().map(InternalCategorizationAggregation::getBuckets).flatMap(List::stream));
-
-        Map<Object, Long> expectedReducedCounts = new HashMap<>(totalCounts);
-        expectedReducedCounts.keySet().retainAll(reducedCounts.keySet());
-        assertThat(reducedCounts, equalTo(expectedReducedCounts));
-    }
-
-    @Override
-    protected Predicate<String> excludePathsFromXContentInsertion() {
-        return p -> p.contains("key");
-    }
-
-    @Override
-    protected InternalCategorizationAggregation createTestInstance(
-        String name,
-        Map<String, Object> metadata,
-        InternalAggregations aggregations
-    ) {
-        List<InternalCategorizationAggregation.Bucket> buckets = new ArrayList<>();
-        final int numBuckets = randomNumberOfBuckets();
-        for (int i = 0; i < numBuckets; ++i) {
-            buckets.add(
-                new InternalCategorizationAggregation.Bucket(SerializableTokenListCategoryTests.createTestInstance(bytesRefHash), -1)
-            );
-        }
-        Collections.sort(buckets);
-        return new InternalCategorizationAggregation(
-            name,
-            randomIntBetween(10, 100),
-            randomLongBetween(1, 10),
-            randomIntBetween(1, 100),
-            metadata,
-            buckets
-        );
-    }
-
-    @Override
-    protected Class<? extends ParsedMultiBucketAggregation> implementationClass() {
-        return ParsedCategorization.class;
-    }
-
-    private static Map<Object, Long> toCounts(Stream<InternalCategorizationAggregation.Bucket> buckets) {
-        return buckets.collect(
-            Collectors.toMap(
-                InternalCategorizationAggregation.Bucket::getKey,
-                InternalCategorizationAggregation.Bucket::getDocCount,
-                Long::sum
-            )
-        );
-    }
-}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/ParsedCategorization.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/ParsedCategorization.java
deleted file mode 100644
index ba44c87a1b2b8..0000000000000
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization2/ParsedCategorization.java
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.aggs.categorization2;
-
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.CheckedBiConsumer;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.xcontent.XContentParserUtils;
-import org.elasticsearch.search.aggregations.Aggregation;
-import org.elasticsearch.search.aggregations.Aggregations;
-import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation;
-import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
-import org.elasticsearch.xcontent.ObjectParser;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.ml.job.results.CategoryDefinition;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.function.Supplier;
-
-// TODO: how close to the actual InternalCategorizationAggregation.Bucket class does this have to be to add any value?
-class ParsedCategorization extends ParsedMultiBucketAggregation<ParsedCategorization.ParsedBucket> {
-
-    @Override
-    public String getType() {
-        return CategorizeTextAggregationBuilder.NAME;
-    }
-
-    private static final ObjectParser<ParsedCategorization, Void> PARSER = new ObjectParser<>(
-        ParsedCategorization.class.getSimpleName(),
-        true,
-        ParsedCategorization::new
-    );
-    static {
-        declareMultiBucketAggregationFields(PARSER, ParsedBucket::fromXContent, ParsedBucket::fromXContent);
-    }
-
-    public static ParsedCategorization fromXContent(XContentParser parser, String name) throws IOException {
-        ParsedCategorization aggregation = PARSER.parse(parser, null);
-        aggregation.setName(name);
-        return aggregation;
-    }
-
-    @Override
-    public List<? extends MultiBucketsAggregation.Bucket> getBuckets() {
-        return buckets;
-    }
-
-    public static class ParsedBucket extends ParsedMultiBucketAggregation.ParsedBucket implements MultiBucketsAggregation.Bucket {
-
-        private InternalCategorizationAggregation.BucketKey key;
-        private int maxMatchingLength;
-
-        protected void setKeyAsString(String keyAsString) {
-            if (keyAsString == null) {
-                key = null;
-                return;
-            }
-            if (keyAsString.isEmpty()) {
-                key = new InternalCategorizationAggregation.BucketKey(new BytesRef[0]);
-                return;
-            }
-            String[] split = Strings.tokenizeToStringArray(keyAsString, " ");
-            key = new InternalCategorizationAggregation.BucketKey(
-                split == null
-                    ? new BytesRef[] { new BytesRef(keyAsString) }
-                    : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new)
-            );
-        }
-
-        private void setMaxMatchingLength(int maxMatchingLength) {
-            this.maxMatchingLength = maxMatchingLength;
-        }
-
-        @Override
-        public Object getKey() {
-            return key;
-        }
-
-        @Override
-        public String getKeyAsString() {
-            return key.toString();
-        }
-
-        @Override
-        protected XContentBuilder keyToXContent(XContentBuilder builder) throws IOException {
-            return builder.field(CommonFields.KEY.getPreferredName(), getKey());
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            keyToXContent(builder);
-            builder.field(CategoryDefinition.MAX_MATCHING_LENGTH.getPreferredName(), maxMatchingLength);
-            builder.field(CommonFields.DOC_COUNT.getPreferredName(), getDocCount());
-            getAggregations().toXContentInternal(builder, params);
-            builder.endObject();
-            return builder;
-        }
-
-        static InternalCategorizationAggregation.BucketKey parsedKey(final XContentParser parser) throws IOException {
-            if (parser.currentToken() == XContentParser.Token.VALUE_STRING) {
-                String toSplit = parser.text();
-                String[] split = Strings.tokenizeToStringArray(toSplit, " ");
-                return new InternalCategorizationAggregation.BucketKey(
-                    split == null
-                        ? new BytesRef[] { new BytesRef(toSplit) }
-                        : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new)
-                );
-            } else {
-                return new InternalCategorizationAggregation.BucketKey(
-                    XContentParserUtils.parseList(parser, p -> new BytesRef(p.binaryValue())).toArray(BytesRef[]::new)
-                );
-            }
-        }
-
-        protected static ParsedBucket parseCategorizationBucketXContent(
-            final XContentParser parser,
-            final Supplier<ParsedBucket> bucketSupplier,
-            final CheckedBiConsumer<XContentParser, ParsedBucket, IOException> keyConsumer
-        ) throws IOException {
-            final ParsedBucket bucket = bucketSupplier.get();
-            XContentParser.Token token;
-            String currentFieldName = parser.currentName();
-
-            List<Aggregation> aggregations = new ArrayList<>();
-            while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
-                if (token == XContentParser.Token.FIELD_NAME) {
-                    currentFieldName = parser.currentName();
-                } else if (token.isValue()) {
-                    if (CommonFields.KEY_AS_STRING.getPreferredName().equals(currentFieldName)) {
-                        bucket.setKeyAsString(parser.text());
-                    } else if (CommonFields.KEY.getPreferredName().equals(currentFieldName)) {
-                        keyConsumer.accept(parser, bucket);
-                    } else if (CommonFields.DOC_COUNT.getPreferredName().equals(currentFieldName)) {
-                        bucket.setDocCount(parser.longValue());
-                    } else if (CategoryDefinition.MAX_MATCHING_LENGTH.getPreferredName().equals(currentFieldName)) {
-                        bucket.setMaxMatchingLength(parser.intValue());
-                    }
-                } else if (token == XContentParser.Token.START_OBJECT) {
-                    if (CommonFields.KEY.getPreferredName().equals(currentFieldName)) {
-                        keyConsumer.accept(parser, bucket);
-                    } else {
-                        XContentParserUtils.parseTypedKeysObject(
-                            parser,
-                            Aggregation.TYPED_KEYS_DELIMITER,
-                            Aggregation.class,
-                            aggregations::add
-                        );
-                    }
-                }
-            }
-            bucket.setAggregations(new Aggregations(aggregations));
-            return bucket;
-        }
-
-        static ParsedBucket fromXContent(final XContentParser parser) throws IOException {
-            return parseCategorizationBucketXContent(parser, ParsedBucket::new, (p, bucket) -> bucket.key = parsedKey(p));
-        }
-    }
-}
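The YAML tests that follow assert the request-validation error for an out-of-range similarity_threshold. A guard of roughly this shape would produce that message (a sketch only, with a hypothetical holder class rather than the real CategorizeTextAggregationBuilder):

final class SimilarityThresholdHolder {
    private int similarityThreshold = 70; // default documented for the new implementation

    void setSimilarityThreshold(int similarityThreshold) {
        if (similarityThreshold < 1 || similarityThreshold > 100) {
            // matches the error asserted by the REST test below
            throw new IllegalArgumentException("[similarity_threshold] must be in the range [1, 100]");
        }
        this.similarityThreshold = similarityThreshold;
    }

    int similarityThreshold() {
        return similarityThreshold;
    }
}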
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml
index a484755a79bfc..45ed89c4c77b4 100644
--- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml
@@ -71,8 +71,6 @@ setup:
               "categorize_text": {
                 "field": "text",
                 "size": 10,
-                "max_unique_tokens": 2,
-                "max_matched_tokens": 1,
                 "similarity_threshold": 11
               }
             }
@@ -81,13 +79,13 @@ setup:
   - length: { aggregations.categories.buckets: 2 }
 
   - match: { aggregations.categories.buckets.0.doc_count: 4 }
-  - match: { aggregations.categories.buckets.0.key: "Node *" }
+  - match: { aggregations.categories.buckets.0.key: "Node" }
   - match: { aggregations.categories.buckets.1.doc_count: 3 }
-  - match: { aggregations.categories.buckets.1.key: "User Foo logging *" }
+  - match: { aggregations.categories.buckets.1.key: "User Foo logging" }
 ---
 "Test categorization aggregation against unsupported field":
   - do:
-      catch: /categorize_text agg \[categories\] only works on analyzable text fields/
+      catch: /categorize_text agg \[categories\] only works on text and keyword fields/
       search:
         index: to_categorize
        body: >
@@ -105,71 +103,6 @@ setup:
 ---
 "Test categorization aggregation with poor settings":
-  - do:
-      catch: /\[max_unique_tokens\] must be greater than 0 and less than or equal \[100\]/
-      search:
-        index: to_categorize
-        body: >
-          {
-            "size": 0,
-            "aggs": {
-              "categories": {
-                "categorize_text": {
-                  "field": "text",
-                  "max_unique_tokens": -2
-                }
-              }
-            }
-          }
-  - do:
-      catch: /\[max_unique_tokens\] must be greater than 0 and less than or equal \[100\]/
-      search:
-        index: to_categorize
-        body: >
-          {
-            "size": 0,
-            "aggs": {
-              "categories": {
-                "categorize_text": {
-                  "field": "text",
-                  "max_unique_tokens": 101
-                }
-              }
-            }
-          }
-  - do:
-      catch: /\[max_matched_tokens\] must be greater than 0 and less than or equal \[100\]/
-      search:
-        index: to_categorize
-        body: >
-          {
-            "size": 0,
-            "aggs": {
-              "categories": {
-                "categorize_text": {
-                  "field": "text",
-                  "max_matched_tokens": -2
-                }
-              }
-            }
-          }
-  - do:
-      catch: /\[max_matched_tokens\] must be greater than 0 and less than or equal \[100\]/
-      search:
-        index: to_categorize
-        body: >
-          {
-            "size": 0,
-            "aggs": {
-              "categories": {
-                "categorize_text": {
-                  "field": "text",
-                  "max_matched_tokens": 101
-                }
-              }
-            }
-          }
   - do:
       catch: /\[similarity_threshold\] must be in the range \[1, 100\]/
       search: