From 7a7fffcb5a852561a473a022861ce5faec2725ee Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 4 Oct 2021 11:49:16 -0400 Subject: [PATCH] [ML] Text/Log categorization multi-bucket aggregation (#71752) This commit adds a new multi-bucket aggregation: `categorize_text` The aggregation follows a similar design to significant text in that it reads from `_source` and re-analyzes the the text as it is read. Key difference is that it does not use the indexed field's analyzer, but instead relies on the `ml_standard` tokenizer with specialized ML token filters. The tokenizer + filters are the same that machine learning categorization anomaly jobs utilize. The high level logical flow is as follows: - at each shard, read in the text field with a custom analyzer using `ml_standard` tokenizer - Read in the particular tokens from the analyzer - Feed these tokens to a token tree algorithm (an adaptation of the drain categorization algorithm) - Gather the individual log categories (the leaf nodes), sort them by doc_count, ship those buckets to be merged - Merge all buckets that have the EXACT same key - Once all buckets are merged, pass those keys + counts to a new token tree for additional merging - That tree builds the final buckets and that is returned to the user Algorithm explanation: - Each log is parsed with the ml-standard tokenizer - each token is passed into a token tree - For `max_match_token` each token is stored in the tree and at `max_match_token+1` (or `len(tokens)`) a log group is created - If another log group exists at that leaf, merge it if they have `similarity_threshold` percentage of tokens in common - merging simply replaces tokens that are different in the group with `*` - If a layer in the tree has `max_unique_tokens` we add a `*` child and any new tokens are passed through there. Catch here is that on the final merge, we first attempt to merge together subtrees with the smallest number of documents. Especially if the new sub tree has more documents counted. ## Aggregation configuration. Here is an example on some openstack logs ```js POST openstack/_search?size=0 { "aggs": { "categories": { "categorize_text": { "field": "message", // The field to categorize "similarity_threshold": 20, // merge log groups if they are this similar "max_unique_tokens": 20, // Max Number of children per token position "max_match_token": 4, // Maximum tokens to build prefix trees "size": 1 } } } } ``` This will return buckets like ```json "aggregations" : { "categories" : { "buckets" : [ { "doc_count" : 806, "key" : "nova-api.log.1.2017-05-16_13 INFO nova.osapi_compute.wsgi.server * HTTP/1.1 status len time" } ] } } ``` --- .../AggConstructionContentionBenchmark.java | 17 + docs/build.gradle | 33 ++ docs/reference/aggregations/bucket.asciidoc | 2 + .../categorize-text-aggregation.asciidoc | 469 ++++++++++++++++++ .../elasticsearch/search/SearchService.java | 1 + .../ParsedMultiBucketAggregation.java | 2 +- .../support/AggregationContext.java | 45 ++ .../index/mapper/MapperServiceTestCase.java | 17 + .../aggregations/AggregatorTestCase.java | 17 + .../test/AbstractBuilderTestCase.java | 14 +- .../core/ml/job/config/AnalysisConfig.java | 2 +- .../config/CategorizationAnalyzerConfig.java | 6 +- .../ml/qa/ml-with-security/build.gradle | 3 + .../CategorizationAggregationIT.java | 160 ++++++ .../xpack/ml/MachineLearning.java | 15 + .../CategorizationBytesRefHash.java | 83 ++++ .../CategorizationTokenTree.java | 143 ++++++ .../CategorizeTextAggregationBuilder.java | 349 +++++++++++++ .../CategorizeTextAggregator.java | 239 +++++++++ .../CategorizeTextAggregatorFactory.java | 112 +++++ .../InternalCategorizationAggregation.java | 453 +++++++++++++++++ .../categorization/TextCategorization.java | 117 +++++ .../ml/aggs/categorization/TreeNode.java | 406 +++++++++++++++ .../UnmappedCategorizationAggregation.java | 71 +++ .../CategorizationAnalyzer.java | 14 +- .../xpack/ml/LocalStateMachineLearning.java | 18 + ...CategorizeTextAggregationBuilderTests.java | 60 +++ .../CategorizeTextAggregatorTests.java | 342 +++++++++++++ .../categorization/InnerTreeNodeTests.java | 123 +++++ ...nternalCategorizationAggregationTests.java | 117 +++++ .../categorization/LeafTreeNodeTests.java | 88 ++++ .../categorization/ParsedCategorization.java | 113 +++++ .../TextCategorizationTests.java | 75 +++ .../test/ml/categorization_agg.yml | 153 ++++++ 34 files changed, 3871 insertions(+), 8 deletions(-) create mode 100644 docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc create mode 100644 x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java create mode 100644 x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java index f496608e1c273..43a2d1930d0e8 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java @@ -22,6 +22,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -197,6 +198,22 @@ public long nowInMillis() { return 0; } + @Override + public Analyzer getNamedAnalyzer(String analyzer) { + return null; + } + + @Override + public Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) { + return null; + } + @Override protected IndexFieldData buildFieldData(MappedFieldType ft) { IndexFieldDataCache indexFieldDataCache = indicesFieldDataCache.buildIndexFieldDataCache(new IndexFieldDataCache.Listener() { diff --git a/docs/build.gradle b/docs/build.gradle index f84bc91cf7489..5ecaf7dc6faff 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -1071,6 +1071,39 @@ buildRestTests.setups['farequote_datafeed'] = buildRestTests.setups['farequote_j "indexes":"farequote" } ''' +buildRestTests.setups['categorize_text'] = ''' + - do: + indices.create: + index: log-messages + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + time: + type: date + message: + type: text + + - do: + bulk: + index: log-messages + refresh: true + body: | + {"index": {"_id":"1"}} + {"time":"2016-02-07T00:01:00+0000", "message": "2016-02-07T00:00:00+0000 Node 3 shutting down"} + {"index": {"_id":"2"}} + {"time":"2016-02-07T00:02:00+0000", "message": "2016-02-07T00:00:00+0000 Node 5 starting up"} + {"index": {"_id":"3"}} + {"time":"2016-02-07T00:03:00+0000", "message": "2016-02-07T00:00:00+0000 Node 4 shutting down"} + {"index": {"_id":"4"}} + {"time":"2016-02-08T00:01:00+0000", "message": "2016-02-08T00:00:00+0000 Node 5 shutting down"} + {"index": {"_id":"5"}} + {"time":"2016-02-08T00:02:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_325 logging on"} + {"index": {"_id":"6"}} + {"time":"2016-02-08T00:04:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_864 logged off"} +''' buildRestTests.setups['server_metrics_index'] = ''' - do: indices.create: diff --git a/docs/reference/aggregations/bucket.asciidoc b/docs/reference/aggregations/bucket.asciidoc index 302e196caf3ce..dfdaca18e6cfb 100644 --- a/docs/reference/aggregations/bucket.asciidoc +++ b/docs/reference/aggregations/bucket.asciidoc @@ -20,6 +20,8 @@ include::bucket/adjacency-matrix-aggregation.asciidoc[] include::bucket/autodatehistogram-aggregation.asciidoc[] +include::bucket/categorize-text-aggregation.asciidoc[] + include::bucket/children-aggregation.asciidoc[] include::bucket/composite-aggregation.asciidoc[] diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc new file mode 100644 index 0000000000000..cc0a0e787f844 --- /dev/null +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -0,0 +1,469 @@ +[[search-aggregations-bucket-categorize-text-aggregation]] +=== Categorize text aggregation +++++ +Categorize text +++++ + +experimental::[] + +A multi-bucket aggregation that groups semi-structured text into buckets. Each `text` field is re-analyzed +using a custom analyzer. The resulting tokens are then categorized creating buckets of similarly formatted +text values. This aggregation works best with machine generated text like system logs. + +NOTE: If you have considerable memory allocated to your JVM but are receiving circuit breaker exceptions from this + aggregation, you may be attempting to categorize text that is poorly formatted for categorization. Consider + adding `categorization_filters` or running under <> or + <> to explore the created categories. + +[[bucket-categorize-text-agg-syntax]] +==== Parameters + +`field`:: +(Required, string) +The semi-structured text field to categorize. + +`max_unique_tokens`:: +(Optional, integer, default: `50`) +The maximum number of unique tokens at any position up to `max_matched_tokens`. +Must be larger than 1. Smaller values use less memory and create fewer categories. +Larger values will use more memory and create narrower categories. + +`max_matched_tokens`:: +(Optional, integer, default: `5`) +The maximum number of token positions to match on before attempting to merge categories. +Larger values will use more memory and create narrower categories. + +Example: +`max_matched_tokens` of 2 would disallow merging of the categories +[`foo` `bar` `baz`] +[`foo` `baz` `bozo`] +As the first 2 tokens are required to match for the category. + +NOTE: Once `max_unique_tokens` is reached at a given position, a new `*` token is +added and all new tokens at that position are matched by the `*` token. + +`similarity_threshold`:: +(Optional, integer, default: `50`) +The minimum percentage of tokens that must match for text to be added to the +category bucket. +Must be between 1 and 100. The larger the value the narrower the categories. +Larger values will increase memory usage and create narrower categories. + +`categorization_filters`:: +(Optional, array of strings) +This property expects an array of regular expressions. The expressions +are used to filter out matching sequences from the categorization field values. +You can use this functionality to fine tune the categorization by excluding +sequences from consideration when categories are defined. For example, you can +exclude SQL statements that appear in your log files. This +property cannot be used at the same time as `categorization_analyzer`. If you +only want to define simple regular expression filters that are applied prior to +tokenization, setting this property is the easiest method. If you also want to +customize the tokenizer or post-tokenization filtering, use the +`categorization_analyzer` property instead and include the filters as +`pattern_replace` character filters. + +`categorization_analyzer`:: +(Optional, object or string) +The categorization analyzer specifies how the text is analyzed and tokenized before +being categorized. The syntax is very similar to that used to define the `analyzer` in the +<>. This +property cannot be used at the same time as `categorization_filters`. ++ +The `categorization_analyzer` field can be specified either as a string or as an +object. If it is a string it must refer to a +<> or one added by another plugin. If it +is an object it has the following properties: ++ +.Properties of `categorization_analyzer` +[%collapsible%open] +===== +`char_filter`:::: +(array of strings or objects) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=char-filter] + +`tokenizer`:::: +(string or object) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tokenizer] + +`filter`:::: +(array of strings or objects) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=filter] +===== +end::categorization-analyzer[] + +`shard_size`:: +(Optional, integer) +The number of categorization buckets to return from each shard before merging +all the results. + +`size`:: +(Optional, integer, default: `10`) +The number of buckets to return. + +`min_doc_count`:: +(Optional, integer) +The minimum number of documents for a bucket to be returned to the results. + +`shard_min_doc_count`:: +(Optional, integer) +The minimum number of documents for a bucket to be returned from the shard before +merging. + +==== Basic use + + +WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. This aggregation should be + used in conjunction with <>. Additionally, you may consider + using the aggregation as a child of either the <> or + <> aggregation. + This will typically improve speed and memory use. + +Example: + +[source,console] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "categories": { + "categorize_text": { + "field": "message" + } + } + } +} +-------------------------------------------------- +// TEST[setup:categorize_text] + +Response: + +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "categories" : { + "buckets" : [ + { + "doc_count" : 3, + "key" : "Node shutting down" + }, + { + "doc_count" : 1, + "key" : "Node starting up" + }, + { + "doc_count" : 1, + "key" : "User foo_325 logging on" + }, + { + "doc_count" : 1, + "key" : "User foo_864 logged off" + } + ] + } + } +} +-------------------------------------------------- + + +Here is an example using `categorization_filters` + +[source,console] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "categories": { + "categorize_text": { + "field": "message", + "categorization_filters": ["\\w+\\_\\d{3}"] <1> + } + } + } +} +-------------------------------------------------- +// TEST[setup:categorize_text] + +<1> The filters to apply to the analyzed tokens. It filters + out tokens like `bar_123`. + +Note how the `foo_` tokens are not part of the +category results + +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "categories" : { + "buckets" : [ + { + "doc_count" : 3, + "key" : "Node shutting down" + }, + { + "doc_count" : 1, + "key" : "Node starting up" + }, + { + "doc_count" : 1, + "key" : "User logged off" + }, + { + "doc_count" : 1, + "key" : "User logging on" + } + ] + } + } +} +-------------------------------------------------- + +Here is an example using `categorization_filters`. +The default analyzer is a whitespace analyzer with a custom token filter +which filters out tokens that start with any number. +But, it may be that a token is a known highly-variable token (formatted usernames, emails, etc.). In that case, it is good to supply +custom `categorization_filters` to filter out those tokens for better categories. These filters will also reduce memory usage as fewer +tokens are held in memory for the categories. + +[source,console] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "categories": { + "categorize_text": { + "field": "message", + "categorization_filters": ["\\w+\\_\\d{3}"], <1> + "max_matched_tokens": 2, <2> + "similarity_threshold": 30 <3> + } + } + } +} +-------------------------------------------------- +// TEST[setup:categorize_text] +<1> The filters to apply to the analyzed tokens. It filters +out tokens like `bar_123`. +<2> Require at least 2 tokens before the log categories attempt to merge together +<3> Require 30% of the tokens to match before expanding a log categories + to add a new log entry + +The resulting categories are now broad, matching the first token +and merging the log groups. + +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "categories" : { + "buckets" : [ + { + "doc_count" : 4, + "key" : "Node *" + }, + { + "doc_count" : 2, + "key" : "User *" + } + ] + } + } +} +-------------------------------------------------- + +This aggregation can have both sub-aggregations and itself be a sub-aggregation. This allows gathering the top daily categories and the +top sample doc as below. + +[source,console] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "daily": { + "date_histogram": { + "field": "time", + "fixed_interval": "1d" + }, + "aggs": { + "categories": { + "categorize_text": { + "field": "message", + "categorization_filters": ["\\w+\\_\\d{3}"] + }, + "aggs": { + "hit": { + "top_hits": { + "size": 1, + "sort": ["time"], + "_source": "message" + } + } + } + } + } + } + } +} +-------------------------------------------------- +// TEST[setup:categorize_text] + +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "daily" : { + "buckets" : [ + { + "key_as_string" : "2016-02-07T00:00:00.000Z", + "key" : 1454803200000, + "doc_count" : 3, + "categories" : { + "buckets" : [ + { + "doc_count" : 2, + "key" : "Node shutting down", + "hit" : { + "hits" : { + "total" : { + "value" : 2, + "relation" : "eq" + }, + "max_score" : null, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "1", + "_score" : null, + "_source" : { + "message" : "2016-02-07T00:00:00+0000 Node 3 shutting down" + }, + "sort" : [ + 1454803260000 + ] + } + ] + } + } + }, + { + "doc_count" : 1, + "key" : "Node starting up", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : null, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "2", + "_score" : null, + "_source" : { + "message" : "2016-02-07T00:00:00+0000 Node 5 starting up" + }, + "sort" : [ + 1454803320000 + ] + } + ] + } + } + } + ] + } + }, + { + "key_as_string" : "2016-02-08T00:00:00.000Z", + "key" : 1454889600000, + "doc_count" : 3, + "categories" : { + "buckets" : [ + { + "doc_count" : 1, + "key" : "Node shutting down", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : null, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "4", + "_score" : null, + "_source" : { + "message" : "2016-02-08T00:00:00+0000 Node 5 shutting down" + }, + "sort" : [ + 1454889660000 + ] + } + ] + } + } + }, + { + "doc_count" : 1, + "key" : "User logged off", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : null, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "6", + "_score" : null, + "_source" : { + "message" : "2016-02-08T00:00:00+0000 User foo_864 logged off" + }, + "sort" : [ + 1454889840000 + ] + } + ] + } + } + }, + { + "doc_count" : 1, + "key" : "User logging on", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : null, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "5", + "_score" : null, + "_source" : { + "message" : "2016-02-08T00:00:00+0000 User foo_325 logging on" + }, + "sort" : [ + 1454889720000 + ] + } + ] + } + } + } + ] + } + } + ] + } + } +} +-------------------------------------------------- diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 01f00e1a00721..175f028435ed6 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -986,6 +986,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc context.terminateAfter(source.terminateAfter()); if (source.aggregations() != null && includeAggregations) { AggregationContext aggContext = new ProductionAggregationContext( + indicesService.getAnalysis(), context.getSearchExecutionContext(), bigArrays, source.aggregations().bytesToPreallocate(), diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java b/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java index 76ca0a917fb5d..48bd678ce5f80 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java @@ -48,7 +48,7 @@ protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) return builder; } - protected static , T extends ParsedBucket> void declareMultiBucketAggregationFields( + public static , T extends ParsedBucket> void declareMultiBucketAggregationFields( final ObjectParser objectParser, final CheckedFunction bucketParser, final CheckedFunction keyedBucketParser diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java index 5008fbe08eac4..8ce69009f09e0 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java @@ -18,6 +18,8 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -95,6 +97,30 @@ public final FieldContext buildFieldContext(String field) { return new FieldContext(field, buildFieldData(ft), ft); } + /** + * Returns an existing registered analyzer that should NOT be closed when finished being used. + * @param analyzer The custom analyzer name + * @return The existing named analyzer. + */ + public abstract Analyzer getNamedAnalyzer(String analyzer) throws IOException; + + /** + * Creates a new custom analyzer that should be closed when finished being used. + * @param indexSettings The current index settings or null + * @param normalizer Is a normalizer + * @param tokenizer The tokenizer name or definition to use + * @param charFilters The char filter name or definition to use + * @param tokenFilters The token filter name or definition to use + * @return A new custom analyzer + */ + public abstract Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) throws IOException; + /** * Lookup the context for an already resolved field type. */ @@ -277,10 +303,12 @@ public static class ProductionAggregationContext extends AggregationContext { private final Supplier isCancelled; private final Function filterQuery; private final boolean enableRewriteToFilterByFilter; + private final AnalysisRegistry analysisRegistry; private final List releaseMe = new ArrayList<>(); public ProductionAggregationContext( + AnalysisRegistry analysisRegistry, SearchExecutionContext context, BigArrays bigArrays, long bytesToPreallocate, @@ -295,6 +323,7 @@ public ProductionAggregationContext( Function filterQuery, boolean enableRewriteToFilterByFilter ) { + this.analysisRegistry = analysisRegistry; this.context = context; if (bytesToPreallocate == 0) { /* @@ -350,6 +379,22 @@ public long nowInMillis() { return context.nowInMillis(); } + @Override + public Analyzer getNamedAnalyzer(String analyzer) throws IOException { + return analysisRegistry.getAnalyzer(analyzer); + } + + @Override + public Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) throws IOException { + return analysisRegistry.buildCustomAnalyzer(indexSettings, normalizer, tokenizer, charFilters, tokenFilters); + } + @Override protected IndexFieldData buildFieldData(MappedFieldType ft) { return context.getForField(ft); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java index f341ae905ce0b..c9d1529e45430 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java @@ -37,6 +37,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.AnalyzerScope; import org.elasticsearch.index.analysis.IndexAnalyzers; +import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -352,6 +353,22 @@ public long nowInMillis() { return 0; } + @Override + public Analyzer getNamedAnalyzer(String analyzer) { + return null; + } + + @Override + public Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) { + return null; + } + @Override public boolean isFieldMapped(String field) { throw new UnsupportedOperationException(); diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 577598b467792..ac684f99e97c7 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -94,6 +94,7 @@ import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesModule; +import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.plugins.SearchPlugin; @@ -137,6 +138,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; @@ -160,6 +162,7 @@ public abstract class AggregatorTestCase extends ESTestCase { private List releasables = new ArrayList<>(); protected ValuesSourceRegistry valuesSourceRegistry; + private AnalysisModule analysisModule; // A list of field types that should not be tested, or are not currently supported private static final List TYPE_TEST_BLACKLIST = List.of( @@ -179,6 +182,19 @@ public void initValuesSourceRegistry() { valuesSourceRegistry = searchModule.getValuesSourceRegistry(); } + @Before + public void initAnalysisRegistry() throws Exception { + analysisModule = createAnalysisModule(); + } + + /** + * @return a new analysis module. Tests that require a fully constructed analysis module (used to create an analysis registry) + * should override this method + */ + protected AnalysisModule createAnalysisModule() throws Exception { + return null; + } + /** * Test cases should override this if they have plugins that need to be loaded, e.g. the plugins their aggregators are in. */ @@ -284,6 +300,7 @@ public void onCache(ShardId shardId, Accountable accountable) {} MultiBucketConsumer consumer = new MultiBucketConsumer(maxBucket, breakerService.getBreaker(CircuitBreaker.REQUEST)); AggregationContext context = new ProductionAggregationContext( + Optional.ofNullable(analysisModule).map(AnalysisModule::getAnalysisRegistry).orElse(null), searchExecutionContext, new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), breakerService), bytesToPreallocate, diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 30bcfd62a8e99..73eef9442246e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -142,6 +142,14 @@ protected Collection> getPlugins() { return Collections.singletonList(TestGeoShapeFieldMapperPlugin.class); } + /** + * Allows additional plugins other than the required `TestGeoShapeFieldMapperPlugin` + * Could probably be removed when dependencies against geo_shape is decoupled + */ + protected Collection> getExtraPlugins() { + return Collections.emptyList(); + } + protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { } @@ -208,9 +216,11 @@ public void beforeTest() throws Exception { // this setup long masterSeed = SeedUtils.parseSeed(RandomizedTest.getContext().getRunnerSeedAsString()); RandomizedTest.getContext().runWithPrivateRandomness(masterSeed, (Callable) () -> { - serviceHolder = new ServiceHolder(nodeSettings, createTestIndexSettings(), getPlugins(), nowInMillis, + Collection> plugins = new ArrayList<>(getPlugins()); + plugins.addAll(getExtraPlugins()); + serviceHolder = new ServiceHolder(nodeSettings, createTestIndexSettings(), plugins, nowInMillis, AbstractBuilderTestCase.this, true); - serviceHolderWithNoType = new ServiceHolder(nodeSettings, createTestIndexSettings(), getPlugins(), nowInMillis, + serviceHolderWithNoType = new ServiceHolder(nodeSettings, createTestIndexSettings(), plugins, nowInMillis, AbstractBuilderTestCase.this, false); return null; }); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java index a27aa10215ef9..6e8cb3845de63 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java @@ -736,7 +736,7 @@ private void verifyCategorizationFiltersAreValidRegex() { } } - private static boolean isValidRegex(String exp) { + public static boolean isValidRegex(String exp) { try { Pattern.compile(exp); return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java index d9430762df5d7..cedc046b26674 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java @@ -86,8 +86,10 @@ public static CategorizationAnalyzerConfig buildFromXContentObject(XContentParse * * The parser is strict when parsing config and lenient when parsing cluster state. */ - static CategorizationAnalyzerConfig buildFromXContentFragment(XContentParser parser, boolean ignoreUnknownFields) throws IOException { - + public static CategorizationAnalyzerConfig buildFromXContentFragment( + XContentParser parser, + boolean ignoreUnknownFields + ) throws IOException { CategorizationAnalyzerConfig.Builder builder = new CategorizationAnalyzerConfig.Builder(); XContentParser.Token token = parser.currentToken(); diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 9c53ed96f3724..448fc8b1fc39b 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -35,6 +35,9 @@ tasks.named("yamlRestTest").configure { 'ml/calendar_crud/Test put calendar given id contains invalid chars', 'ml/calendar_crud/Test delete event from non existing calendar', 'ml/calendar_crud/Test delete job from non existing calendar', + // These are searching tests with aggregations, and do not call any ML endpoints + 'ml/categorization_agg/Test categorization agg simple', + 'ml/categorization_agg/Test categorization aggregation with poor settings', 'ml/custom_all_field/Test querying custom all field', 'ml/datafeeds_crud/Test delete datafeed with missing id', 'ml/datafeeds_crud/Test put datafeed referring to missing job_id', diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java new file mode 100644 index 0000000000000..bd85984b3619f --- /dev/null +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.metrics.Max; +import org.elasticsearch.search.aggregations.metrics.Min; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; +import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.junit.Before; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notANumber; + +public class CategorizationAggregationIT extends BaseMlIntegTestCase { + + private static final String DATA_INDEX = "categorization-agg-data"; + + @Before + public void setupCluster() { + internalCluster().ensureAtLeastNumDataNodes(3); + ensureStableCluster(); + createSourceData(); + } + + public void testAggregation() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(3)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node started", 3); + assertCategorizationBucket( + agg.getBuckets().get(1), + "Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception", + 2 + ); + assertCategorizationBucket(agg.getBuckets().get(2), "Node stopped", 1); + } + + public void testAggregationWithOnlyOneBucket() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .size(1) + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(1)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node started", 3); + } + + public void testAggregationWithBroadCategories() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .setSimilarityThreshold(11) + .setMaxUniqueTokens(2) + .setMaxMatchedTokens(1) + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(2)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node *", 4); + assertCategorizationBucket( + agg.getBuckets().get(1), + "Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception", + 2 + ); + } + + private void assertCategorizationBucket(InternalCategorizationAggregation.Bucket bucket, String key, long docCount) { + assertThat(bucket.getKeyAsString(), equalTo(key)); + assertThat(bucket.getDocCount(), equalTo(docCount)); + assertThat(((Max)bucket.getAggregations().get("max")).getValue(), not(notANumber())); + assertThat(((Min)bucket.getAggregations().get("min")).getValue(), not(notANumber())); + } + + private void ensureStableCluster() { + ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60)); + } + + private void createSourceData() { + client().admin().indices().prepareCreate(DATA_INDEX) + .setMapping("time", "type=date,format=epoch_millis", + "msg", "type=text") + .get(); + + long nowMillis = System.currentTimeMillis(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + IndexRequest indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis(), + "msg", "Node 1 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis() + 1, + "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]", + "part", "shutdowns"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis() + 1, + "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 55 caused by foo exception]", + "part", "shutdowns"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(1).millis(), + "msg", "Node 2 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis, + "msg", "Node 3 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis, + "msg", "Node 3 stopped", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + + BulkResponse bulkResponse = bulkRequestBuilder + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + assertThat(bulkResponse.hasFailures(), is(false)); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index fa36eb5a3b425..970f856ce5142 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -264,6 +264,8 @@ import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; import org.elasticsearch.xpack.ml.aggs.correlation.BucketCorrelationAggregationBuilder; import org.elasticsearch.xpack.ml.aggs.correlation.CorrelationNamedContentProvider; import org.elasticsearch.xpack.ml.aggs.heuristic.PValueScore; @@ -1220,6 +1222,7 @@ public List> getExecutorBuilders(Settings settings) { return Arrays.asList(jobComms, utility, datafeed); } + @Override public Map> getCharFilters() { return MapBuilder.>newMapBuilder() .put(FirstNonBlankLineCharFilter.NAME, FirstNonBlankLineCharFilterFactory::new) @@ -1249,6 +1252,18 @@ public List> getSignificanceHeuristics() { ); } + @Override + public List getAggregations() { + return Arrays.asList( + new AggregationSpec( + CategorizeTextAggregationBuilder.NAME, + CategorizeTextAggregationBuilder::new, + CategorizeTextAggregationBuilder.PARSER + ).addResultReader(InternalCategorizationAggregation::new) + .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)) + ); + } + @Override public UnaryOperator> getIndexTemplateMetadataUpgrader() { return UnaryOperator.identity(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java new file mode 100644 index 0000000000000..6246683ddfe6e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.search.aggregations.AggregationExecutionException; + +class CategorizationBytesRefHash implements Releasable { + + /** + * Our special wild card value. + */ + static final BytesRef WILD_CARD_REF = new BytesRef("*"); + /** + * For all WILD_CARD references, the token ID is always -1 + */ + static final int WILD_CARD_ID = -1; + private final BytesRefHash bytesRefHash; + + CategorizationBytesRefHash(BytesRefHash bytesRefHash) { + this.bytesRefHash = bytesRefHash; + } + + int[] getIds(BytesRef[] tokens) { + int[] ids = new int[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + ids[i] = put(tokens[i]); + } + return ids; + } + + BytesRef[] getDeeps(int[] ids) { + BytesRef[] tokens = new BytesRef[ids.length]; + for (int i = 0; i < tokens.length; i++) { + tokens[i] = getDeep(ids[i]); + } + return tokens; + } + + BytesRef getDeep(long id) { + if (id == WILD_CARD_ID) { + return WILD_CARD_REF; + } + BytesRef shallow = bytesRefHash.get(id, new BytesRef()); + return BytesRef.deepCopyOf(shallow); + } + + int put(BytesRef bytesRef) { + if (WILD_CARD_REF.equals(bytesRef)) { + return WILD_CARD_ID; + } + long hash = bytesRefHash.add(bytesRef); + if (hash < 0) { + return (int) (-1L - hash); + } else { + if (hash > Integer.MAX_VALUE) { + throw new AggregationExecutionException( + LoggerMessageFormat.format( + "more than [{}] unique terms encountered. " + + "Consider restricting the documents queried or adding [{}] in the {} configuration", + Integer.MAX_VALUE, + CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(), + CategorizeTextAggregationBuilder.NAME + ) + ); + } + return (int) hash; + } + } + + @Override + public void close() { + bytesRefHash.close(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java new file mode 100644 index 0000000000000..f5b5e6daea956 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.search.aggregations.InternalAggregations; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + + +/** + * Categorized semi-structured text utilizing the drain algorithm: https://arxiv.org/pdf/1806.04356.pdf + * With the following key differences + * - This structure keeps track of the "smallest" sub-tree. So, instead of naively adding a new "*" node, the smallest sub-tree + * is transformed if the incoming token has a higher doc_count. + * - Additionally, similarities are weighted, which allows for nicer merging of existing categories + * - An optional tree reduction step is available to collapse together tiny sub-trees + * + * + * The main implementation is a fixed-sized prefix tree. + * Consequently, this assumes that splits that give us more information come earlier in the text. + * + * Examples: + * + * Given token values: + * + * Node is online + * Node is offline + * + * With a fixed tree depth of 2 we would get the following splits + * 3 // initial root is the number of tokens + * | + * "Node" // first prefix node of value "Node" + * | + * "is" + * / \ + * [Node is online] [Node is offline] //the individual categories for this simple case + * + * If the similarityThreshold was less than 0.6, the result would be a single category [Node is *] + * + */ +public class CategorizationTokenTree implements Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CategorizationTokenTree.class); + private final int maxMatchTokens; + private final int maxUniqueTokens; + private final int similarityThreshold; + private long idGenerator; + private final Map root = new HashMap<>(); + private long sizeInBytes; + + public CategorizationTokenTree(int maxUniqueTokens, int maxMatchTokens, int similarityThreshold) { + assert maxUniqueTokens > 0 && maxMatchTokens >= 0; + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; + this.similarityThreshold = similarityThreshold; + this.sizeInBytes = SHALLOW_SIZE; + } + + public List toIntermediateBuckets(CategorizationBytesRefHash hash) { + return root.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).map(lg -> { + int[] categoryTokenIds = lg.getCategorization(); + BytesRef[] bytesRefs = new BytesRef[categoryTokenIds.length]; + for (int i = 0; i < categoryTokenIds.length; i++) { + bytesRefs[i] = hash.getDeep(categoryTokenIds[i]); + } + InternalCategorizationAggregation.Bucket bucket = new InternalCategorizationAggregation.Bucket( + new InternalCategorizationAggregation.BucketKey(bytesRefs), + lg.getCount(), + InternalAggregations.EMPTY + ); + bucket.bucketOrd = lg.bucketOrd; + return bucket; + }).collect(Collectors.toList()); + } + + void mergeSmallestChildren() { + root.values().forEach(TreeNode::collapseTinyChildren); + } + + /** + * This method does not mutate the underlying structure. Meaning, if a matching categories isn't found, it may return empty. + * + * @param tokenIds The tokens to categorize + * @return The category or `Optional.empty()` if one doesn't exist + */ + public Optional parseTokensConst(final int[] tokenIds) { + TreeNode currentNode = this.root.get(tokenIds.length); + if (currentNode == null) { // we are missing an entire sub tree. New token length found + return Optional.empty(); + } + return Optional.ofNullable(currentNode.getCategorization(tokenIds)); + } + + /** + * This categorizes the passed tokens, potentially mutating the structure by expanding an existing category or adding a new one. + * @param tokenIds The tokens to categorize + * @param docCount The count of docs for the given tokens + * @return An existing categorization or a newly created one + */ + public TextCategorization parseTokens(final int[] tokenIds, long docCount) { + TreeNode currentNode = this.root.get(tokenIds.length); + if (currentNode == null) { // we are missing an entire sub tree. New token length found + currentNode = newNode(docCount, 0, tokenIds); + incSize(currentNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF); + this.root.put(tokenIds.length, currentNode); + } else { + currentNode.incCount(docCount); + } + return currentNode.addText(tokenIds, docCount, this); + } + + TreeNode newNode(long docCount, int tokenPos, int[] tokenIds) { + return tokenPos < maxMatchTokens - 1 && tokenPos < tokenIds.length + ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxUniqueTokens) + : new TreeNode.LeafTreeNode(docCount, similarityThreshold); + } + + TextCategorization newCategorization(long docCount, int[] tokenIds) { + return new TextCategorization(tokenIds, docCount, idGenerator++); + } + + void incSize(long size) { + sizeInBytes += size; + } + + @Override + public long ramBytesUsed() { + return sizeInBytes; + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java new file mode 100644 index 0000000000000..e84167fd1d0a9 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java @@ -0,0 +1,349 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.REQUIRED_SIZE_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME; +import static org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig.Builder.isValidRegex; + +public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder { + + static final TermsAggregator.BucketCountThresholds DEFAULT_BUCKET_COUNT_THRESHOLDS = new TermsAggregator.BucketCountThresholds( + 1, + 0, + 10, + -1 + ); + + static final int MAX_MAX_UNIQUE_TOKENS = 100; + static final int MAX_MAX_MATCHED_TOKENS = 100; + public static final String NAME = "categorize_text"; + + static final ParseField FIELD_NAME = new ParseField("field"); + static final ParseField MAX_UNIQUE_TOKENS = new ParseField("max_unique_tokens"); + static final ParseField SIMILARITY_THRESHOLD = new ParseField("similarity_threshold"); + static final ParseField MAX_MATCHED_TOKENS = new ParseField("max_matched_tokens"); + static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters"); + static final ParseField CATEGORIZATION_ANALYZER = new ParseField("categorization_analyzer"); + + public static final ObjectParser PARSER = ObjectParser.fromBuilder( + CategorizeTextAggregationBuilder.NAME, + CategorizeTextAggregationBuilder::new + ); + static { + PARSER.declareString(CategorizeTextAggregationBuilder::setFieldName, FIELD_NAME); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxUniqueTokens, MAX_UNIQUE_TOKENS); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxMatchedTokens, MAX_MATCHED_TOKENS); + PARSER.declareInt(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); + PARSER.declareField( + CategorizeTextAggregationBuilder::setCategorizationAnalyzerConfig, + (p, c) -> CategorizationAnalyzerConfig.buildFromXContentFragment(p, false), + CATEGORIZATION_ANALYZER, + ObjectParser.ValueType.OBJECT_OR_STRING + ); + PARSER.declareStringArray(CategorizeTextAggregationBuilder::setCategorizationFilters, CATEGORIZATION_FILTERS); + PARSER.declareInt(CategorizeTextAggregationBuilder::shardSize, TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME); + PARSER.declareLong(CategorizeTextAggregationBuilder::minDocCount, TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME); + PARSER.declareLong(CategorizeTextAggregationBuilder::shardMinDocCount, TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME); + PARSER.declareInt(CategorizeTextAggregationBuilder::size, REQUIRED_SIZE_FIELD_NAME); + } + + public static CategorizeTextAggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException { + return PARSER.parse(parser, new CategorizeTextAggregationBuilder(aggregationName), null); + } + + private TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds( + DEFAULT_BUCKET_COUNT_THRESHOLDS + ); + private CategorizationAnalyzerConfig categorizationAnalyzerConfig; + private String fieldName; + private int maxUniqueTokens = 50; + private int similarityThreshold = 50; + private int maxMatchedTokens = 5; + + private CategorizeTextAggregationBuilder(String name) { + super(name); + } + + public CategorizeTextAggregationBuilder(String name, String fieldName) { + super(name); + this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); + } + + public String getFieldName() { + return fieldName; + } + + public CategorizeTextAggregationBuilder setFieldName(String fieldName) { + this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); + return this; + } + + public CategorizeTextAggregationBuilder(StreamInput in) throws IOException { + super(in); + this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in); + this.fieldName = in.readString(); + this.maxUniqueTokens = in.readVInt(); + this.maxMatchedTokens = in.readVInt(); + this.similarityThreshold = in.readVInt(); + this.categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new); + } + + public int getMaxUniqueTokens() { + return maxUniqueTokens; + } + + public CategorizeTextAggregationBuilder setMaxUniqueTokens(int maxUniqueTokens) { + this.maxUniqueTokens = maxUniqueTokens; + if (maxUniqueTokens <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", + MAX_UNIQUE_TOKENS.getPreferredName(), + MAX_MAX_UNIQUE_TOKENS, + maxUniqueTokens, + name + ); + } + return this; + } + + public double getSimilarityThreshold() { + return similarityThreshold; + } + + public CategorizeTextAggregationBuilder setSimilarityThreshold(int similarityThreshold) { + this.similarityThreshold = similarityThreshold; + if (similarityThreshold < 1 || similarityThreshold > 100) { + throw ExceptionsHelper.badRequestException( + "[{}] must be in the range [1, 100]. Found [{}] in [{}]", + SIMILARITY_THRESHOLD.getPreferredName(), + similarityThreshold, + name + ); + } + return this; + } + + public CategorizeTextAggregationBuilder setCategorizationAnalyzerConfig(CategorizationAnalyzerConfig categorizationAnalyzerConfig) { + this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; + return this; + } + + public CategorizeTextAggregationBuilder setCategorizationFilters(List categorizationFilters) { + if (categorizationFilters == null || categorizationFilters.isEmpty()) { + return this; + } + if (categorizationAnalyzerConfig != null) { + throw ExceptionsHelper.badRequestException( + "[{}] cannot be used with [{}] - instead specify them as pattern_replace char_filters in the analyzer", + CATEGORIZATION_FILTERS.getPreferredName(), + CATEGORIZATION_ANALYZER.getPreferredName() + ); + } + if (categorizationFilters.stream().distinct().count() != categorizationFilters.size()) { + throw ExceptionsHelper.badRequestException(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_DUPLICATES); + } + if (categorizationFilters.stream().anyMatch(String::isEmpty)) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_EMPTY)); + } + for (String filter : categorizationFilters) { + if (isValidRegex(filter) == false) { + throw ExceptionsHelper.badRequestException( + Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_INVALID_REGEX, filter) + ); + } + } + this.categorizationAnalyzerConfig = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(categorizationFilters); + return this; + } + + public int getMaxMatchedTokens() { + return maxMatchedTokens; + } + + public CategorizeTextAggregationBuilder setMaxMatchedTokens(int maxMatchedTokens) { + this.maxMatchedTokens = maxMatchedTokens; + if (maxMatchedTokens <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", + MAX_MATCHED_TOKENS.getPreferredName(), + MAX_MAX_MATCHED_TOKENS, + maxMatchedTokens, + name + ); + } + return this; + } + + /** + * @param size indicating how many buckets should be returned + */ + public CategorizeTextAggregationBuilder size(int size) { + if (size <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0. Found [{}] in [{}]", + REQUIRED_SIZE_FIELD_NAME.getPreferredName(), + size, + name + ); + } + bucketCountThresholds.setRequiredSize(size); + return this; + } + + /** + * @param shardSize - indicating the number of buckets each shard + * will return to the coordinating node (the node that coordinates the + * search execution). The higher the shard size is, the more accurate the + * results are. + */ + public CategorizeTextAggregationBuilder shardSize(int shardSize) { + if (shardSize <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0. Found [{}] in [{}]", + SHARD_SIZE_FIELD_NAME.getPreferredName(), + shardSize, + name + ); + } + bucketCountThresholds.setShardSize(shardSize); + return this; + } + + /** + * @param minDocCount the minimum document count a text category should have in order to appear in + * the response. + */ + public CategorizeTextAggregationBuilder minDocCount(long minDocCount) { + if (minDocCount < 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to 0. Found [{}] in [{}]", + MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), + minDocCount, + name + ); + } + bucketCountThresholds.setMinDocCount(minDocCount); + return this; + } + + /** + * @param shardMinDocCount the minimum document count a text category should have on the shard in order to + * appear in the response. + */ + public CategorizeTextAggregationBuilder shardMinDocCount(long shardMinDocCount) { + if (shardMinDocCount < 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to 0. Found [{}] in [{}]", + SHARD_MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), + shardMinDocCount, + name + ); + } + bucketCountThresholds.setShardMinDocCount(shardMinDocCount); + return this; + } + + protected CategorizeTextAggregationBuilder( + CategorizeTextAggregationBuilder clone, + AggregatorFactories.Builder factoriesBuilder, + Map metadata + ) { + super(clone, factoriesBuilder, metadata); + this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(clone.bucketCountThresholds); + this.fieldName = clone.fieldName; + this.maxUniqueTokens = clone.maxUniqueTokens; + this.maxMatchedTokens = clone.maxMatchedTokens; + this.similarityThreshold = clone.similarityThreshold; + this.categorizationAnalyzerConfig = clone.categorizationAnalyzerConfig; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + bucketCountThresholds.writeTo(out); + out.writeString(fieldName); + out.writeVInt(maxUniqueTokens); + out.writeVInt(maxMatchedTokens); + out.writeVInt(similarityThreshold); + out.writeOptionalWriteable(categorizationAnalyzerConfig); + } + + @Override + protected AggregatorFactory doBuild( + AggregationContext context, + AggregatorFactory parent, + AggregatorFactories.Builder subfactoriesBuilder + ) throws IOException { + return new CategorizeTextAggregatorFactory( + name, + fieldName, + maxUniqueTokens, + maxMatchedTokens, + similarityThreshold, + bucketCountThresholds, + categorizationAnalyzerConfig, + context, + parent, + subfactoriesBuilder, + metadata + ); + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + bucketCountThresholds.toXContent(builder, params); + builder.field(FIELD_NAME.getPreferredName(), fieldName); + builder.field(MAX_UNIQUE_TOKENS.getPreferredName(), maxUniqueTokens); + builder.field(MAX_MATCHED_TOKENS.getPreferredName(), maxMatchedTokens); + builder.field(SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold); + if (categorizationAnalyzerConfig != null) { + categorizationAnalyzerConfig.toXContent(builder, params); + } + builder.endObject(); + return null; + } + + @Override + protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { + return new CategorizeTextAggregationBuilder(this, factoriesBuilder, metadata); + } + + @Override + public BucketCardinality bucketCardinality() { + return BucketCardinality.MANY; + } + + @Override + public String getType() { + return NAME; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java new file mode 100644 index 0000000000000..16058fbdae4f2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -0,0 +1,239 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; +import org.elasticsearch.search.aggregations.bucket.DeferableBucketAggregator; +import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.search.lookup.SourceLookup; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; + +public class CategorizeTextAggregator extends DeferableBucketAggregator { + + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + private final SourceLookup sourceLookup; + private final MappedFieldType fieldType; + private final CategorizationAnalyzer analyzer; + private final String sourceFieldName; + private ObjectArray categorizers; + private final int maxUniqueTokens; + private final int maxMatchTokens; + private final int similarityThreshold; + private final LongKeyedBucketOrds bucketOrds; + private final CategorizationBytesRefHash bytesRefHash; + + protected CategorizeTextAggregator( + String name, + AggregatorFactories factories, + AggregationContext context, + Aggregator parent, + String sourceFieldName, + MappedFieldType fieldType, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + CategorizationAnalyzerConfig categorizationAnalyzerConfig, + Map metadata + ) throws IOException { + super(name, factories, context, parent, metadata); + this.sourceLookup = context.lookup().source(); + this.sourceFieldName = sourceFieldName; + this.fieldType = fieldType; + CategorizationAnalyzerConfig analyzerConfig = Optional.ofNullable(categorizationAnalyzerConfig) + .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(Collections.emptyList())); + final String analyzerName = analyzerConfig.getAnalyzer(); + if (analyzerName != null) { + Analyzer globalAnalyzer = context.getNamedAnalyzer(analyzerName); + if (globalAnalyzer == null) { + throw new IllegalArgumentException("Failed to find global analyzer [" + analyzerName + "]"); + } + this.analyzer = new CategorizationAnalyzer(globalAnalyzer, false); + } else { + this.analyzer = new CategorizationAnalyzer( + context.buildCustomAnalyzer( + context.getIndexSettings(), + false, + analyzerConfig.getTokenizer(), + analyzerConfig.getCharFilters(), + analyzerConfig.getTokenFilters() + ), + true + ); + } + this.categorizers = bigArrays().newObjectArray(1); + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; + this.similarityThreshold = similarityThreshold; + this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); + this.bucketCountThresholds = bucketCountThresholds; + this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, bigArrays())); + } + + @Override + protected void doClose() { + super.doClose(); + Releasables.close(this.analyzer, this.bytesRefHash); + } + + @Override + public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOException { + InternalCategorizationAggregation.Bucket[][] topBucketsPerOrd = + new InternalCategorizationAggregation.Bucket[ordsToCollect.length][]; + for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { + final CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); + if (categorizationTokenTree == null) { + topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[0]; + continue; + } + int size = (int) Math.min(bucketOrds.bucketsInOrd(ordIdx), bucketCountThresholds.getShardSize()); + PriorityQueue ordered = + new InternalCategorizationAggregation.BucketCountPriorityQueue(size); + for (InternalCategorizationAggregation.Bucket bucket : categorizationTokenTree.toIntermediateBuckets(bytesRefHash)) { + if (bucket.docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } + ordered.insertWithOverflow(bucket); + } + topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[ordered.size()]; + for (int i = ordered.size() - 1; i >= 0; --i) { + topBucketsPerOrd[ordIdx][i] = ordered.pop(); + } + } + buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, a) -> b.aggregations = a); + InternalAggregation[] results = new InternalAggregation[ordsToCollect.length]; + for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { + InternalCategorizationAggregation.Bucket[] bucketArray = topBucketsPerOrd[ordIdx]; + Arrays.sort(bucketArray, Comparator.naturalOrder()); + results[ordIdx] = new InternalCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata(), + Arrays.asList(bucketArray) + ); + } + return results; + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata() + ); + } + + @Override + protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + return new LeafBucketCollectorBase(sub, null) { + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); + CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); + if (categorizer == null) { + categorizer = new CategorizationTokenTree(maxUniqueTokens, maxMatchTokens, similarityThreshold); + addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); + categorizers.set(owningBucketOrd, categorizer); + } + collectFromSource(doc, owningBucketOrd, categorizer); + } + + private void collectFromSource(int doc, long owningBucketOrd, CategorizationTokenTree categorizer) throws IOException { + sourceLookup.setSegmentAndDocument(ctx, doc); + Iterator itr = sourceLookup.extractRawValues(sourceFieldName).stream().map(obj -> { + if (obj == null) { + return null; + } + if (obj instanceof BytesRef) { + return fieldType.valueForDisplay(obj).toString(); + } + return obj.toString(); + }).iterator(); + while (itr.hasNext()) { + TokenStream ts = analyzer.tokenStream(fieldType.name(), itr.next()); + processTokenStream(owningBucketOrd, ts, doc, categorizer); + } + } + + private void processTokenStream( + long owningBucketOrd, + TokenStream ts, + int doc, + CategorizationTokenTree categorizer + ) throws IOException { + ArrayList tokens = new ArrayList<>(); + try { + CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); + ts.reset(); + while (ts.incrementToken()) { + tokens.add(bytesRefHash.put(new BytesRef(termAtt))); + } + if (tokens.isEmpty()) { + return; + } + } finally { + ts.close(); + } + long previousSize = categorizer.ramBytesUsed(); + TextCategorization lg = categorizer.parseTokens( + tokens.stream().mapToInt(Integer::valueOf).toArray(), + docCountProvider.getDocCount(doc) + ); + long newSize = categorizer.ramBytesUsed(); + if (newSize - previousSize > 0) { + addRequestCircuitBreakerBytes(newSize - previousSize); + } + + long bucketOrd = bucketOrds.add(owningBucketOrd, lg.getId()); + if (bucketOrd < 0) { // already seen + bucketOrd = -1 - bucketOrd; + collectExistingBucket(sub, doc, bucketOrd); + } else { + lg.bucketOrd = bucketOrd; + collectBucket(sub, doc, bucketOrd); + } + } + }; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java new file mode 100644 index 0000000000000..f63b4ba1f802b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.NonCollectingAggregator; +import org.elasticsearch.search.aggregations.bucket.BucketUtils; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; + +import java.io.IOException; +import java.util.Map; + +public class CategorizeTextAggregatorFactory extends AggregatorFactory { + + private final MappedFieldType fieldType; + private final String indexedFieldName; + private final int maxUniqueTokens; + private final int maxMatchTokens; + private final int similarityThreshold; + private final CategorizationAnalyzerConfig categorizationAnalyzerConfig; + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + + public CategorizeTextAggregatorFactory( + String name, + String fieldName, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + CategorizationAnalyzerConfig categorizationAnalyzerConfig, + AggregationContext context, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, + Map metadata + ) throws IOException { + super(name, context, parent, subFactoriesBuilder, metadata); + this.fieldType = context.getFieldType(fieldName); + if (fieldType != null) { + this.indexedFieldName = fieldType.name(); + } else { + this.indexedFieldName = null; + } + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; + this.similarityThreshold = similarityThreshold; + this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; + this.bucketCountThresholds = bucketCountThresholds; + } + + protected Aggregator createUnmapped(Aggregator parent, Map metadata) throws IOException { + final InternalAggregation aggregation = new UnmappedCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata + ); + return new NonCollectingAggregator(name, context, parent, factories, metadata) { + @Override + public InternalAggregation buildEmptyAggregation() { + return aggregation; + } + }; + } + + @Override + protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map metadata) + throws IOException { + if (fieldType == null) { + return createUnmapped(parent, metadata); + } + TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); + if (bucketCountThresholds.getShardSize() == CategorizeTextAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { + // The user has not made a shardSize selection. Use default + // heuristic to avoid any wrong-ranking caused by distributed + // counting + // TODO significant text does a 2x here, should we as well? + bucketCountThresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(bucketCountThresholds.getRequiredSize())); + } + bucketCountThresholds.ensureValidity(); + + return new CategorizeTextAggregator( + name, + factories, + context, + parent, + indexedFieldName, + fieldType, + bucketCountThresholds, + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + categorizationAnalyzerConfig, + metadata + ); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java new file mode 100644 index 0000000000000..92c51b8d75b4c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -0,0 +1,453 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.aggregations.AggregationExecutionException; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation; +import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_REF; + + +public class InternalCategorizationAggregation extends InternalMultiBucketAggregation< + InternalCategorizationAggregation, + InternalCategorizationAggregation.Bucket> { + + // Carries state allowing for delayed reduction of the bucket + // This allows us to keep from accidentally calling "reduce" on the sub-aggs more than once + private static class DelayedCategorizationBucket { + private final BucketKey key; + private long docCount; + private final List toReduce; + + DelayedCategorizationBucket(BucketKey key, List toReduce, long docCount) { + this.key = key; + this.toReduce = new ArrayList<>(toReduce); + this.docCount = docCount; + } + + public long getDocCount() { + return docCount; + } + + public Bucket reduce(BucketKey key, ReduceContext reduceContext) { + List innerAggs = new ArrayList<>(toReduce.size()); + long docCount = 0; + for (Bucket bucket : toReduce) { + innerAggs.add(bucket.aggregations); + docCount += bucket.docCount; + } + return new Bucket(key, docCount, InternalAggregations.reduce(innerAggs, reduceContext)); + } + + public DelayedCategorizationBucket add(Bucket bucket) { + this.docCount += bucket.docCount; + this.toReduce.add(bucket); + return this; + } + + public DelayedCategorizationBucket add(DelayedCategorizationBucket bucket) { + this.docCount += bucket.docCount; + this.toReduce.addAll(bucket.toReduce); + return this; + } + } + + static class BucketCountPriorityQueue extends PriorityQueue { + BucketCountPriorityQueue(int size) { + super(size); + } + + @Override + protected boolean lessThan(Bucket a, Bucket b) { + return a.docCount < b.docCount; + } + } + + static class BucketKey implements ToXContentFragment, Writeable, Comparable { + + private final BytesRef[] key; + + static BucketKey withCollapsedWildcards(BytesRef[] key) { + if (key.length <= 1) { + return new BucketKey(key); + } + List collapsedWildCards = new ArrayList<>(); + boolean previousTokenWildCard = false; + for (BytesRef token : key) { + if (token.equals(WILD_CARD_REF)) { + if (previousTokenWildCard == false) { + previousTokenWildCard = true; + collapsedWildCards.add(WILD_CARD_REF); + } + } else { + previousTokenWildCard = false; + collapsedWildCards.add(token); + } + } + if (collapsedWildCards.size() == key.length) { + return new BucketKey(key); + } + return new BucketKey(collapsedWildCards.toArray(BytesRef[]::new)); + } + + BucketKey(BytesRef[] key) { + this.key = key; + } + + BucketKey(StreamInput in) throws IOException { + key = in.readArray(StreamInput::readBytesRef, BytesRef[]::new); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(asString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray(StreamOutput::writeBytesRef, key); + } + + public String asString() { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < key.length - 1; i++) { + builder.append(key[i].utf8ToString()).append(" "); + } + builder.append(key[key.length - 1].utf8ToString()); + return builder.toString(); + } + + @Override + public String toString() { + return asString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BucketKey bucketKey = (BucketKey) o; + return Arrays.equals(key, bucketKey.key); + } + + @Override + public int hashCode() { + return Arrays.hashCode(key); + } + + public BytesRef[] keyAsTokens() { + return key; + } + + @Override + public int compareTo(BucketKey o) { + return Arrays.compare(key, o.key); + } + + } + + public static class Bucket extends InternalBucket implements MultiBucketsAggregation.Bucket, Comparable { + // Used on the shard level to keep track of sub aggregations + long bucketOrd; + + final BucketKey key; + final long docCount; + InternalAggregations aggregations; + + public Bucket(BucketKey key, long docCount, InternalAggregations aggregations) { + this.key = key; + this.docCount = docCount; + this.aggregations = aggregations; + } + + public Bucket(StreamInput in) throws IOException { + key = new BucketKey(in); + docCount = in.readVLong(); + aggregations = InternalAggregations.readFrom(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + key.writeTo(out); + out.writeVLong(getDocCount()); + aggregations.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CommonFields.DOC_COUNT.getPreferredName(), docCount); + builder.field(CommonFields.KEY.getPreferredName()); + key.toXContent(builder, params); + aggregations.toXContentInternal(builder, params); + builder.endObject(); + return builder; + } + + BucketKey getRawKey() { + return key; + } + + @Override + public Object getKey() { + return key; + } + + @Override + public String getKeyAsString() { + return key.asString(); + } + + @Override + public long getDocCount() { + return docCount; + } + + @Override + public Aggregations getAggregations() { + return aggregations; + } + + @Override + public String toString() { + return "Bucket{" + "key=" + getKeyAsString() + ", docCount=" + docCount + ", aggregations=" + aggregations.asMap() + "}\n"; + } + + @Override + public int compareTo(Bucket o) { + return key.compareTo(o.key); + } + + } + + private final List buckets; + private final int maxUniqueTokens; + private final int similarityThreshold; + private final int maxMatchTokens; + private final int requiredSize; + private final long minDocCount; + + protected InternalCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + Map metadata + ) { + this(name, requiredSize, minDocCount, maxUniqueTokens, maxMatchTokens, similarityThreshold, metadata, new ArrayList<>()); + } + + protected InternalCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + Map metadata, + List buckets + ) { + super(name, metadata); + this.buckets = buckets; + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; + this.similarityThreshold = similarityThreshold; + this.minDocCount = minDocCount; + this.requiredSize = requiredSize; + } + + public InternalCategorizationAggregation(StreamInput in) throws IOException { + super(in); + this.maxUniqueTokens = in.readVInt(); + this.maxMatchTokens = in.readVInt(); + this.similarityThreshold = in.readVInt(); + this.buckets = in.readList(Bucket::new); + this.requiredSize = readSize(in); + this.minDocCount = in.readVLong(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeVInt(maxUniqueTokens); + out.writeVInt(maxMatchTokens); + out.writeVInt(similarityThreshold); + out.writeList(buckets); + writeSize(requiredSize, out); + out.writeVLong(minDocCount); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.startArray(CommonFields.BUCKETS.getPreferredName()); + for (Bucket bucket : buckets) { + bucket.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + @Override + public InternalCategorizationAggregation create(List buckets) { + return new InternalCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + super.metadata, + buckets + ); + } + + @Override + public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { + return new Bucket(prototype.key, prototype.docCount, aggregations); + } + + @Override + protected Bucket reduceBucket(List buckets, ReduceContext context) { + throw new IllegalArgumentException("For optimization purposes, typical bucket path is not supported"); + } + + @Override + public List getBuckets() { + return buckets; + } + + @Override + public String getWriteableName() { + return CategorizeTextAggregationBuilder.NAME; + } + + @Override + public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays()))) { + CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree( + maxUniqueTokens, + maxMatchTokens, + similarityThreshold + ); + // TODO: Could we do a merge sort similar to terms? + // It would require us returning partial reductions sorted by key, not by doc_count + // First, make sure we have all the counts for equal categorizations + Map reduced = new HashMap<>(); + for (InternalAggregation aggregation : aggregations) { + InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; + for (Bucket bucket : categorizationAggregation.buckets) { + reduced.computeIfAbsent(bucket.key, key -> new DelayedCategorizationBucket(key, new ArrayList<>(1), 0L)).add(bucket); + } + } + + reduced.values() + .stream() + .sorted(Comparator.comparing(DelayedCategorizationBucket::getDocCount).reversed()) + .forEach(bucket -> + // Parse tokens takes document count into account and merging on smallest groups + categorizationTokenTree.parseTokens(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount) + ); + categorizationTokenTree.mergeSmallestChildren(); + Map mergedBuckets = new HashMap<>(); + for (DelayedCategorizationBucket delayedBucket : reduced.values()) { + TextCategorization group = categorizationTokenTree.parseTokensConst(hash.getIds(delayedBucket.key.keyAsTokens())) + .orElseThrow( + () -> new AggregationExecutionException( + "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" + ) + ); + BytesRef[] categoryTokens = hash.getDeeps(group.getCategorization()); + + BucketKey key = reduceContext.isFinalReduce() ? + BucketKey.withCollapsedWildcards(categoryTokens) : + new BucketKey(categoryTokens); + mergedBuckets.computeIfAbsent( + key, + k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L) + ).add(delayedBucket); + } + + final int size = reduceContext.isFinalReduce() == false ? mergedBuckets.size() : Math.min(requiredSize, mergedBuckets.size()); + final PriorityQueue pq = new BucketCountPriorityQueue(size); + for (Map.Entry keyAndBuckets : mergedBuckets.entrySet()) { + final BucketKey key = keyAndBuckets.getKey(); + DelayedCategorizationBucket bucket = keyAndBuckets.getValue(); + Bucket newBucket = bucket.reduce(key, reduceContext); + if ((newBucket.docCount >= minDocCount) || reduceContext.isFinalReduce() == false) { + Bucket removed = pq.insertWithOverflow(newBucket); + if (removed == null) { + reduceContext.consumeBucketsAndMaybeBreak(1); + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed)); + } + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(newBucket)); + } + } + Bucket[] bucketList = new Bucket[pq.size()]; + for (int i = pq.size() - 1; i >= 0; i--) { + bucketList[i] = pq.pop(); + } + // Keep the top categories top, but then sort by the key for those with duplicate counts + if (reduceContext.isFinalReduce()) { + Arrays.sort(bucketList, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); + } + return new InternalCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata, + Arrays.asList(bucketList) + ); + } + } + + public int getMaxUniqueTokens() { + return maxUniqueTokens; + } + + public int getSimilarityThreshold() { + return similarityThreshold; + } + + public int getMaxMatchTokens() { + return maxMatchTokens; + } + + public int getRequiredSize() { + return requiredSize; + } + + public long getMinDocCount() { + return minDocCount; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java new file mode 100644 index 0000000000000..7ea72f489ae2d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; + +import java.util.Arrays; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; + +/** + * A text categorization group that provides methods for: + * - calculating similarity between it and a new text + * - expanding the existing categorization by adding a new array of tokens + */ +class TextCategorization implements Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TextCategorization.class); + private final long id; + private final int[] categorization; + private final long[] tokenCounts; + private long count; + + // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations + long bucketOrd; + + TextCategorization(int[] tokenIds, long count, long id) { + this.id = id; + this.categorization = tokenIds; + this.count = count; + this.tokenCounts = new long[tokenIds.length]; + Arrays.fill(this.tokenCounts, count); + } + + public long getId() { + return id; + } + + int[] getCategorization() { + return categorization; + } + + public long getCount() { + return count; + } + + Similarity calculateSimilarity(int[] tokenIds) { + assert tokenIds.length == this.categorization.length; + int eqParams = 0; + long tokenCount = 0; + long tokensKept = 0; + for (int i = 0; i < tokenIds.length; i++) { + if (tokenIds[i] == this.categorization[i]) { + tokensKept += tokenCounts[i]; + tokenCount += tokenCounts[i]; + } else if (this.categorization[i] == WILD_CARD_ID) { + eqParams++; + } else { + tokenCount += tokenCounts[i]; + } + } + return new Similarity((double) tokensKept / tokenCount, eqParams); + } + + void addTokens(int[] tokenIds, long docCount) { + assert tokenIds.length == this.categorization.length; + for (int i = 0; i < tokenIds.length; i++) { + if (tokenIds[i] != this.categorization[i]) { + this.categorization[i] = WILD_CARD_ID; + } else { + tokenCounts[i] += docCount; + } + } + this.count += docCount; + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + + RamUsageEstimator.sizeOf(categorization) // categorization token Ids + + RamUsageEstimator.sizeOf(tokenCounts); // tokenCounts + } + + static class Similarity implements Comparable { + private final double similarity; + private final int wildCardCount; + + private Similarity(double similarity, int wildCardCount) { + this.similarity = similarity; + this.wildCardCount = wildCardCount; + } + + @Override + public int compareTo(Similarity o) { + int d = Double.compare(similarity, o.similarity); + if (d != 0) { + return d; + } + return Integer.compare(wildCardCount, o.wildCardCount); + } + + public double getSimilarity() { + return similarity; + } + + public int getWildCardCount() { + return wildCardCount; + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java new file mode 100644 index 0000000000000..7b13e93d8f1ea --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.aggregations.AggregationExecutionException; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfMap; +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; + +/** + * Tree node classes for the categorization token tree. + * + * Two major node types exist: + * - Inner: which are nodes that have children token nodes + * - Leaf: Which collection multiple {@link TextCategorization} based on similarity restrictions + */ +abstract class TreeNode implements Accountable { + + private long count; + + TreeNode(long count) { + this.count = count; + } + + abstract void mergeWith(TreeNode otherNode); + + abstract boolean isLeaf(); + + final void incCount(long count) { + this.count += count; + } + + final long getCount() { + return count; + } + + // TODO add option for calculating the cost of adding the new group + abstract TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory); + + abstract TextCategorization getCategorization(int[] tokenIds); + + abstract List getAllChildrenTextCategorizations(); + + abstract void collapseTinyChildren(); + + static class LeafTreeNode extends TreeNode { + private final List textCategorizations; + private final int similarityThreshold; + + LeafTreeNode(long count, int similarityThreshold) { + super(count); + this.textCategorizations = new ArrayList<>(); + this.similarityThreshold = similarityThreshold; + if (similarityThreshold < 1 || similarityThreshold > 100) { + throw new IllegalArgumentException("similarityThreshold must be between 1 and 100"); + } + } + + @Override + public boolean isLeaf() { + return true; + } + + @Override + void mergeWith(TreeNode treeNode) { + if (treeNode == null) { + return; + } + if (treeNode.isLeaf() == false) { + throw new UnsupportedOperationException( + "cannot merge leaf node with non-leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" + ); + } + incCount(treeNode.getCount()); + LeafTreeNode otherLeaf = (LeafTreeNode) treeNode; + for (TextCategorization group : otherLeaf.textCategorizations) { + if (getAndUpdateTextCategorization(group.getCategorization(), group.getCount()).isPresent() == false) { + putNewTextCategorization(group); + } + } + } + + @Override + public long ramBytesUsed() { + return Long.BYTES // count + + NUM_BYTES_OBJECT_REF // list reference + + Integer.BYTES // similarityThreshold + + sizeOfCollection(textCategorizations); + } + + @Override + public TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory) { + return getAndUpdateTextCategorization(tokenIds, docCount).orElseGet(() -> { + // Need to update the tree if possible + TextCategorization categorization = putNewTextCategorization(treeNodeFactory.newCategorization(docCount, tokenIds)); + // Get the regular size bytes from the TextCategorization and how much it costs to reference it + treeNodeFactory.incSize(categorization.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF); + return categorization; + }); + } + + @Override + List getAllChildrenTextCategorizations() { + return textCategorizations; + } + + @Override + void collapseTinyChildren() {} + + private Optional getAndUpdateTextCategorization(int[] tokenIds, long docCount) { + return getBestCategorization(tokenIds).map(bestGroupAndSimilarity -> { + if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { + bestGroupAndSimilarity.v1().addTokens(tokenIds, docCount); + return bestGroupAndSimilarity.v1(); + } + return null; + }); + } + + TextCategorization putNewTextCategorization(TextCategorization categorization) { + textCategorizations.add(categorization); + return categorization; + } + + private Optional> getBestCategorization(int[] tokenIds) { + if (textCategorizations.isEmpty()) { + return Optional.empty(); + } + if (textCategorizations.size() == 1) { + return Optional.of( + new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity( tokenIds).getSimilarity()) + ); + } + TextCategorization.Similarity maxSimilarity = null; + TextCategorization bestGroup = null; + for (TextCategorization textCategorization : this.textCategorizations) { + TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity( tokenIds); + if (maxSimilarity == null || groupSimilarity.compareTo(maxSimilarity) > 0) { + maxSimilarity = groupSimilarity; + bestGroup = textCategorization; + } + } + return Optional.of(new Tuple<>(bestGroup, maxSimilarity.getSimilarity())); + } + + @Override + public TextCategorization getCategorization(final int[] tokenIds) { + return getBestCategorization(tokenIds).map(Tuple::v1).orElse(null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LeafTreeNode that = (LeafTreeNode) o; + return that.similarityThreshold == similarityThreshold + && Objects.equals(textCategorizations, that.textCategorizations); + } + + @Override + public int hashCode() { + return Objects.hash(textCategorizations, similarityThreshold); + } + } + + static class InnerTreeNode extends TreeNode { + + // TODO: Change to LongObjectMap? + private final Map children; + private final int childrenTokenPos; + private final int maxChildren; + private final PriorityQueue smallestChild; + + InnerTreeNode(long count, int childrenTokenPos, int maxChildren) { + super(count); + children = new HashMap<>(); + this.childrenTokenPos = childrenTokenPos; + this.maxChildren = maxChildren; + this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(NativeIntLongPair::count)); + } + + @Override + boolean isLeaf() { + return false; + } + + @Override + public TextCategorization getCategorization(final int[] tokenIds) { + return getChild(tokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) + .map(node -> node.getCategorization(tokenIds)) + .orElse(null); + } + + @Override + public long ramBytesUsed() { + return Long.BYTES // count + + NUM_BYTES_OBJECT_REF // children reference + + Integer.BYTES // childrenTokenPos + + Integer.BYTES // maxChildren + + NUM_BYTES_OBJECT_REF // smallestChildReference + + sizeOfMap(children, NUM_BYTES_OBJECT_REF) // children, + // Number of items in the queue, reference to tuple, and then the tuple references + + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + Integer.BYTES + Long.BYTES); + } + + @Override + public TextCategorization addText(final int[] tokenIds, final long docCount, final CategorizationTokenTree treeNodeFactory) { + final int currentToken = tokenIds[childrenTokenPos]; + TreeNode child = getChild(currentToken).map(node -> { + node.incCount(docCount); + if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == currentToken) { + smallestChild.add(smallestChild.poll()); + } + return node; + }).orElseGet(() -> { + TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, tokenIds); + // The size of the node + entry (since it is a map entry) + extra reference for priority queue + treeNodeFactory.incSize( + newNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF + ); + return addChild(currentToken, newNode); + }); + return child.addText(tokenIds, docCount, treeNodeFactory); + } + + @Override + void collapseTinyChildren() { + if (this.isLeaf()) { + return; + } + if (children.size() <= 1) { + return; + } + Optional maybeWildChild = getChild(WILD_CARD_ID).or(() -> { + if ((double) smallestChild.peek().count / this.getCount() <= 1.0 / maxChildren) { + TreeNode tinyChild = children.remove(smallestChild.poll().tokenId); + return Optional.of(addChild(WILD_CARD_ID, tinyChild)); + } + return Optional.empty(); + }); + if (maybeWildChild.isPresent()) { + TreeNode wildChild = maybeWildChild.get(); + NativeIntLongPair tinyNode; + while ((tinyNode = smallestChild.poll()) != null) { + // If we have no more tiny nodes, stop iterating over them + if ((double) tinyNode.count / this.getCount() > 1.0 / maxChildren) { + smallestChild.add(tinyNode); + break; + } else { + wildChild.mergeWith(children.remove(tinyNode.count)); + } + } + } + children.values().forEach(TreeNode::collapseTinyChildren); + } + + @Override + void mergeWith(TreeNode treeNode) { + if (treeNode == null) { + return; + } + incCount(treeNode.count); + if (treeNode.isLeaf()) { + throw new UnsupportedOperationException( + "cannot merge non-leaf node with leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" + ); + } + InnerTreeNode innerTreeNode = (InnerTreeNode) treeNode; + TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD_ID); + addChild(WILD_CARD_ID, siblingWildChild); + NativeIntLongPair siblingChild; + while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) { + TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.tokenId); + addChild(siblingChild.tokenId, nephewNode); + } + } + + private TreeNode addChild(int tokenId, TreeNode node) { + if (node == null) { + return null; + } + Optional existingChild = getChild(tokenId).map(existingNode -> { + existingNode.mergeWith(node); + if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == tokenId) { + smallestChild.poll(); + smallestChild.add(NativeIntLongPair.of(tokenId, existingNode.getCount())); + } + return existingNode; + }); + if (existingChild.isPresent()) { + return existingChild.get(); + } + if (children.size() == maxChildren) { + return getChild(WILD_CARD_ID).map(wildChild -> { + final TreeNode toMerge; + final TreeNode toReturn; + if (smallestChild.isEmpty() == false && node.getCount() > smallestChild.peek().count) { + toMerge = children.remove(smallestChild.poll().tokenId); + addChildAndUpdateSmallest(tokenId, node); + toReturn = node; + } else { + toMerge = node; + toReturn = wildChild; + } + wildChild.mergeWith(toMerge); + return toReturn; + }).orElseThrow(() -> new AggregationExecutionException("Missing wild_card child even though maximum children reached")); + } + // we are about to hit the limit, add a wild card if we need to and then add the new child as appropriate + if (children.size() == maxChildren - 1) { + // If we already have a wild token, simply adding the new token is acceptable as we won't breach our limit + if (children.containsKey(WILD_CARD_ID)) { + addChildAndUpdateSmallest(tokenId, node); + } else { // if we don't have a wild card child, we need to add one now + if (tokenId == WILD_CARD_ID) { + addChildAndUpdateSmallest(tokenId, node); + } else { + if (smallestChild.isEmpty() == false && node.count > smallestChild.peek().count) { + addChildAndUpdateSmallest(WILD_CARD_ID, children.remove(smallestChild.poll().tokenId)); + addChildAndUpdateSmallest(tokenId, node); + } else { + addChildAndUpdateSmallest(WILD_CARD_ID, node); + } + } + } + } else { + addChildAndUpdateSmallest(tokenId, node); + } + return node; + } + + private void addChildAndUpdateSmallest(int tokenId, TreeNode node) { + children.put(tokenId, node); + if (tokenId != WILD_CARD_ID) { + smallestChild.add(NativeIntLongPair.of(tokenId, node.count)); + } + } + + private Optional getChild(int tokenId) { + return Optional.ofNullable(children.get(tokenId)); + } + + public List getAllChildrenTextCategorizations() { + return children.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).collect(Collectors.toList()); + } + + boolean hasChild(int tokenId) { + return children.containsKey(tokenId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InnerTreeNode treeNode = (InnerTreeNode) o; + return childrenTokenPos == treeNode.childrenTokenPos + && getCount() == treeNode.getCount() + && Objects.equals(children, treeNode.children) + && Objects.equals(smallestChild, treeNode.smallestChild); + } + + @Override + public int hashCode() { + return Objects.hash(children, childrenTokenPos, smallestChild, getCount()); + } + } + + private static class NativeIntLongPair { + private final int tokenId; + private final long count; + + static NativeIntLongPair of(int tokenId, long count) { + return new NativeIntLongPair(tokenId, count); + } + + NativeIntLongPair(int tokenId, long count) { + this.tokenId = tokenId; + this.count = count; + } + + public long count() { + return count; + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java new file mode 100644 index 0000000000000..ae1081f66d09f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; + +import java.util.List; +import java.util.Map; + + +class UnmappedCategorizationAggregation extends InternalCategorizationAggregation { + protected UnmappedCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxChildren, + int maxDepth, + int similarityThreshold, + Map metadata + ) { + super(name, requiredSize, minDocCount, maxChildren, maxDepth, similarityThreshold, metadata); + } + + @Override + public InternalCategorizationAggregation create(List buckets) { + return new UnmappedCategorizationAggregation( + name, + getRequiredSize(), + getMinDocCount(), + getMaxUniqueTokens(), + getMaxMatchTokens(), + getSimilarityThreshold(), + super.metadata + ); + } + + @Override + public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { + throw new UnsupportedOperationException("not supported for UnmappedCategorizationAggregation"); + } + + @Override + public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + return new UnmappedCategorizationAggregation( + name, + getRequiredSize(), + getMinDocCount(), + getMaxUniqueTokens(), + getMaxMatchTokens(), + getSimilarityThreshold(), + super.metadata + ); + } + + @Override + public boolean isMapped() { + return false; + } + + @Override + public List getBuckets() { + return List.of(); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java index 4fc99e1502851..6147bc0256ca5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java @@ -10,11 +10,11 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; -import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -25,7 +25,7 @@ * Converts messages to lists of tokens that will be fed to the ML categorization algorithm. * */ -public class CategorizationAnalyzer implements Closeable { +public class CategorizationAnalyzer implements Releasable { private final Analyzer analyzer; private final boolean closeAnalyzer; @@ -38,6 +38,16 @@ public CategorizationAnalyzer(AnalysisRegistry analysisRegistry, closeAnalyzer = tuple.v2(); } + public CategorizationAnalyzer(Analyzer analyzer, boolean closeAnalyzer) { + this.analyzer = analyzer; + this.closeAnalyzer = closeAnalyzer; + } + + public final TokenStream tokenStream(final String fieldName, + final String text) { + return analyzer.tokenStream(fieldName, text); + } + /** * Release resources held by the analyzer (unless it's global). */ diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index 750f36863a141..d6f147f831df7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -17,6 +17,9 @@ import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.analysis.CharFilterFactory; +import org.elasticsearch.index.analysis.TokenizerFactory; +import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.license.LicenseService; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.ActionPlugin; @@ -33,6 +36,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.ml.MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME; @@ -84,6 +88,20 @@ public void cleanUpFeature( mlPlugin.cleanUpFeature(clusterService, client, finalListener); } + @Override + public List getAggregations() { + return mlPlugin.getAggregations(); + } + + @Override + public Map> getCharFilters() { + return mlPlugin.getCharFilters(); + } + + @Override + public Map> getTokenizers() { + return mlPlugin.getTokenizers(); + } /** * This is only required as we now have to have the GetRollupIndexCapsAction as a valid action in our node. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java new file mode 100644 index 0000000000000..7b907ea3ecd29 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.aggregations.BaseAggregationTestCase; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.config.CategorizationAnalyzerConfigTests; + +import java.util.Collection; +import java.util.Collections; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class CategorizeTextAggregationBuilderTests extends BaseAggregationTestCase { + + @Override + protected Collection> getExtraPlugins() { + return Collections.singletonList(MachineLearning.class); + } + + @Override + protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() { + CategorizeTextAggregationBuilder builder = new CategorizeTextAggregationBuilder(randomAlphaOfLength(10), randomAlphaOfLength(10)); + final boolean setFilters = randomBoolean(); + if (setFilters) { + builder.setCategorizationFilters(Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList())); + } + if (setFilters == false) { + builder.setCategorizationAnalyzerConfig(CategorizationAnalyzerConfigTests.createRandomized().build()); + } + if (randomBoolean()) { + builder.setMaxUniqueTokens(randomIntBetween(1, 500)); + } + if (randomBoolean()) { + builder.setMaxMatchedTokens(randomIntBetween(1, 10)); + } + if (randomBoolean()) { + builder.setSimilarityThreshold(randomIntBetween(1, 100)); + } + if (randomBoolean()) { + builder.minDocCount(randomLongBetween(1, 100)); + } + if (randomBoolean()) { + builder.shardMinDocCount(randomLongBetween(1, 100)); + } + if (randomBoolean()) { + builder.size(randomIntBetween(1, 100)); + } + if (randomBoolean()) { + builder.shardSize(randomIntBetween(1, 100)); + } + return builder; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java new file mode 100644 index 0000000000000..95cfdcb0f8f8f --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -0,0 +1,342 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.index.mapper.TextFieldMapper; +import org.elasticsearch.indices.analysis.AnalysisModule; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; +import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram; +import org.elasticsearch.search.aggregations.metrics.Avg; +import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.Max; +import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.Min; +import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class CategorizeTextAggregatorTests extends AggregatorTestCase { + + @Override + protected AnalysisModule createAnalysisModule() throws Exception { + return new AnalysisModule( + TestEnvironment.newEnvironment( + Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() + ), + List.of(new MachineLearning(Settings.EMPTY, null)) + ); + } + + @Override + protected List getSearchPlugins() { + return List.of(new MachineLearning(Settings.EMPTY, null)); + } + + private static final String TEXT_FIELD_NAME = "text"; + private static final String NUMERIC_FIELD_NAME = "value"; + + public void testCategorizationWithoutSubAggs() throws Exception { + testCase( + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME), + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationWithSubAggs() throws Exception { + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) + ) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) result.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) result.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) result.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0)); + assertThat(((Min) result.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(2.0)); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationWithMultiBucketSubAggs() throws Exception { + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + Histogram histo = result.getBuckets().get(0).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + for (Histogram.Bucket bucket : histo.getBuckets()) { + assertThat(bucket.getDocCount(), equalTo(2L)); + } + assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5)); + assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5)); + assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + histo = result.getBuckets().get(1).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L)); + assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(1L)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0)); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationAsSubAgg() throws Exception { + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation( + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) + ) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalHistogram result) -> { + assertThat(result.getBuckets(), hasSize(3)); + + // First histo bucket + assertThat(result.getBuckets().get(0).getDocCount(), equalTo(3L)); + InternalCategorizationAggregation categorizationAggregation = result.getBuckets().get(0).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(2)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(0.5)); + + assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L)); + assertThat( + categorizationAggregation.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(0.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(0.0)); + + // Second histo bucket + assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L)); + categorizationAggregation = result.getBuckets().get(1).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(1)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + + // Third histo bucket + assertThat(result.getBuckets().get(2).getDocCount(), equalTo(3L)); + categorizationAggregation = result.getBuckets().get(2).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(2)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(4.5)); + + assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L)); + assertThat( + categorizationAggregation.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(4.0)); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationWithSubAggsManyDocs() throws Exception { + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeManyTestDocs, + (InternalCategorizationAggregation result) -> { + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(30_000L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + Histogram histo = result.getBuckets().get(0).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + for (Histogram.Bucket bucket : histo.getBuckets()) { + assertThat(bucket.getDocCount(), equalTo(10_000L)); + } + assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5)); + assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5)); + assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(10_000L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + histo = result.getBuckets().get(1).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(5_000L)); + assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L)); + assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(5_000L)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0)); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + private static void writeTestDocs(RandomIndexWriter w) throws IOException { + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 1) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField( + "_source", + new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}") + ), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField( + "_source", + new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}") + ), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 2) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 3) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 5) + ) + ); + } + + private static void writeManyTestDocs(RandomIndexWriter w) throws IOException { + for (int i = 0; i < 5_000; i++) { + writeTestDocs(w); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java new file mode 100644 index 0000000000000..e7f78a01d0130 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class InnerTreeNodeTests extends ESTestCase { + + private final CategorizationTokenTree factory = new CategorizationTokenTree(3, 4, 60); + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() { + bytesRefHash.close(); + } + + public void testAddText() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foo3", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "*", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foo", "bar", "baz", "*") + ); + } + + public void testAddTokensWithLargerIncoming() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 100, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), + getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz") + ); + assertThat( + innerTreeNode.getCategorization(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")).getCategorization(), + equalTo(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) + ); + } + + public void testCollapseTinyChildren() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") + ); + innerTreeNode.incCount(1000); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") + ); + innerTreeNode.incCount(1); + innerTreeNode.collapseTinyChildren(); + assertThat(innerTreeNode.hasChild(bytesRefHash.put(new BytesRef("foosmall"))), is(false)); + assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); + } + + public void testMergeWith() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 3); + innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); + innerTreeNode.incCount(1000); + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory); + + expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 60))); + + TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory); + innerTreeNode.incCount(1); + innerTreeNode.addText(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz"), 1, factory); + + innerTreeNode.mergeWith(mergeWith); + assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); + assertArrayEquals( + innerTreeNode.getCategorization(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz")).getCategorization(), + getTokens(bytesRefHash, "*", "bar", "baz", "biz") + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java new file mode 100644 index 0000000000000..749863b336fa2 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation; +import org.elasticsearch.test.InternalMultiBucketAggregationTestCase; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InternalCategorizationAggregationTests extends InternalMultiBucketAggregationTestCase { + + @Override + protected SearchPlugin registerPlugin() { + return new MachineLearning(Settings.EMPTY, null); + } + + @Override + protected List getNamedXContents() { + return CollectionUtils.appendToCopy( + super.getNamedXContents(), + new NamedXContentRegistry.Entry( + Aggregation.class, + new ParseField(CategorizeTextAggregationBuilder.NAME), + (p, c) -> ParsedCategorization.fromXContent(p, (String) c) + ) + ); + } + + @Override + protected void assertReduced(InternalCategorizationAggregation reduced, List inputs) { + Map reducedCounts = toCounts(reduced.getBuckets().stream()); + Map totalCounts = toCounts(inputs.stream().map(InternalCategorizationAggregation::getBuckets).flatMap(List::stream)); + + Map expectedReducedCounts = new HashMap<>(totalCounts); + expectedReducedCounts.keySet().retainAll(reducedCounts.keySet()); + assertEquals(expectedReducedCounts, reducedCounts); + } + + @Override + protected Predicate excludePathsFromXContentInsertion() { + return p -> p.contains("key"); + } + + static InternalCategorizationAggregation.BucketKey randomKey() { + int numVals = randomIntBetween(1, 50); + return new InternalCategorizationAggregation.BucketKey( + Stream.generate(() -> randomAlphaOfLength(10)).limit(numVals).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } + + @Override + protected InternalCategorizationAggregation createTestInstance( + String name, + Map metadata, + InternalAggregations aggregations + ) { + List buckets = new ArrayList<>(); + final int numBuckets = randomNumberOfBuckets(); + HashSet keys = new HashSet<>(); + for (int i = 0; i < numBuckets; ++i) { + InternalCategorizationAggregation.BucketKey key = randomValueOtherThanMany( + l -> keys.add(l) == false, + InternalCategorizationAggregationTests::randomKey + ); + int docCount = randomIntBetween(1, 100); + buckets.add(new InternalCategorizationAggregation.Bucket(key, docCount, aggregations)); + } + Collections.sort(buckets); + return new InternalCategorizationAggregation( + name, + randomIntBetween(10, 100), + randomLongBetween(1, 10), + randomIntBetween(1, 500), + randomIntBetween(1, 10), + randomIntBetween(1, 100), + metadata, + buckets + ); + } + + @Override + protected Class> implementationClass() { + return ParsedCategorization.class; + } + + private static Map toCounts(Stream buckets) { + return buckets.collect( + Collectors.toMap( + InternalCategorizationAggregation.Bucket::getKey, + InternalCategorizationAggregation.Bucket::getDocCount, + Long::sum + ) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java new file mode 100644 index 0000000000000..2bef18993e019 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; + +public class LeafTreeNodeTests extends ESTestCase { + + private final CategorizationTokenTree factory = new CategorizationTokenTree(10, 10, 60); + + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() { + bytesRefHash.close(); + } + + public void testAddGroup() { + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); + TextCategorization group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); + + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + assertThat(group.getCount(), equalTo(1L)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(1)); + long previousBytesUsed = leafTreeNode.ramBytesUsed(); + + group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"), 1, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy")); + assertThat(group.getCount(), equalTo(1L)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); + assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed)); + previousBytesUsed = leafTreeNode.ramBytesUsed(); + + group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "different"), 3, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); + assertThat(group.getCount(), equalTo(4L)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); + assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed())); + } + + public void testMergeWith() { + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); + leafTreeNode.mergeWith(null); + assertThat(leafTreeNode, equalTo(new TreeNode.LeafTreeNode(0, 60))); + + expectThrows(UnsupportedOperationException.class, () -> leafTreeNode.mergeWith(new TreeNode.InnerTreeNode(1, 2, 3))); + + leafTreeNode.incCount(5); + leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 5, factory); + + TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 60); + leafTreeNode.incCount(1); + toMerge.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory); + leafTreeNode.incCount(1); + toMerge.addText(getTokens(bytesRefHash, "foo", "bart", "bat", "built"), 1, factory); + leafTreeNode.mergeWith(toMerge); + + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); + assertThat(leafTreeNode.getCount(), equalTo(7L)); + assertArrayEquals( + leafTreeNode.getAllChildrenTextCategorizations().get(0).getCategorization(), + getTokens(bytesRefHash, "foo", "bar", "baz", "*") + ); + assertArrayEquals( + leafTreeNode.getAllChildrenTextCategorizations().get(1).getCategorization(), + getTokens(bytesRefHash, "foo", "bart", "bat", "built") + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java new file mode 100644 index 0000000000000..b554f9cfc43e1 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation; +import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +class ParsedCategorization extends ParsedMultiBucketAggregation { + + @Override + public String getType() { + return CategorizeTextAggregationBuilder.NAME; + } + + private static final ObjectParser PARSER = new ObjectParser<>( + ParsedCategorization.class.getSimpleName(), + true, + ParsedCategorization::new + ); + static { + declareMultiBucketAggregationFields(PARSER, ParsedBucket::fromXContent, ParsedBucket::fromXContent); + } + + public static ParsedCategorization fromXContent(XContentParser parser, String name) throws IOException { + ParsedCategorization aggregation = PARSER.parse(parser, null); + aggregation.setName(name); + return aggregation; + } + + @Override + public List getBuckets() { + return buckets; + } + + public static class ParsedBucket extends ParsedMultiBucketAggregation.ParsedBucket implements MultiBucketsAggregation.Bucket { + + private InternalCategorizationAggregation.BucketKey key; + + protected void setKeyAsString(String keyAsString) { + if (keyAsString == null) { + key = null; + return; + } + if (keyAsString.isEmpty()) { + key = new InternalCategorizationAggregation.BucketKey(new BytesRef[0]); + return; + } + String[] split = Strings.tokenizeToStringArray(keyAsString, " "); + key = new InternalCategorizationAggregation.BucketKey( + split == null + ? new BytesRef[] { new BytesRef(keyAsString) } + : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } + + @Override + public Object getKey() { + return key; + } + + @Override + public String getKeyAsString() { + return key.asString(); + } + + @Override + protected XContentBuilder keyToXContent(XContentBuilder builder) throws IOException { + return builder.field(CommonFields.KEY.getPreferredName(), getKey()); + } + + static InternalCategorizationAggregation.BucketKey parsedKey(final XContentParser parser) throws IOException { + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + String toSplit = parser.text(); + String[] split = Strings.tokenizeToStringArray(toSplit, " "); + return new InternalCategorizationAggregation.BucketKey( + split == null + ? new BytesRef[] { new BytesRef(toSplit) } + : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } else { + return new InternalCategorizationAggregation.BucketKey( + XContentParserUtils.parseList(parser, p -> new BytesRef(p.binaryValue())).toArray(BytesRef[]::new) + ); + } + } + + static ParsedBucket fromXContent(final XContentParser parser) throws IOException { + return ParsedMultiBucketAggregation.ParsedBucket.parseXContent( + parser, + false, + ParsedBucket::new, + (p, bucket) -> bucket.key = parsedKey(p) + ); + } + + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java new file mode 100644 index 0000000000000..59129f8801937 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class TextCategorizationTests extends ESTestCase { + + static BigArrays mockBigArrays() { + return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() throws IOException { + bytesRefHash.close(); + } + + public void testSimilarity() { + TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1); + TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens(bytesRefHash, "not", "matching", "anything", "nope")); + assertThat(sims.getSimilarity(), equalTo(0.0)); + assertThat(sims.getWildCardCount(), equalTo(0)); + + sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + assertThat(sims.getSimilarity(), equalTo(1.0)); + assertThat(sims.getWildCardCount(), equalTo(0)); + + sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "fooagain", "notbar", "biz")); + assertThat(sims.getSimilarity(), closeTo(0.5, 0.0001)); + assertThat(sims.getWildCardCount(), equalTo(0)); + } + + public void testAddTokens() { + TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1); + lg.addTokens(getTokens(bytesRefHash, "foo", "bar", "baz", "bozo"), 2); + assertThat(lg.getCount(), equalTo(3L)); + assertArrayEquals(lg.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); + } + + static int[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) { + BytesRef[] refs = new BytesRef[tokens.length]; + int i = 0; + for (String token : tokens) { + refs[i++] = new BytesRef(token); + } + return bytesRefHash.getIds(refs); + } + +} diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml new file mode 100644 index 0000000000000..c2d5e0dbf09f1 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml @@ -0,0 +1,153 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + indices.create: + index: to_categorize + body: + mappings: + properties: + kind: + type: keyword + text: + type: text + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + Content-Type: application/json + bulk: + index: to_categorize + refresh: true + body: | + {"index": {}} + {"product": "server","text": "Node 2 stopping"} + {"index": {}} + {"product": "server", "text": "Node 2 starting"} + {"index": {}} + {"product": "server", "text": "Node 4 stopping"} + {"index": {}} + {"product": "server", "text": "Node 5 stopping"} + {"index": {}} + {"product": "user", "text": "User Foo logging on"} + {"index": {}} + {"product": "user", "text": "User Foo logging on"} + {"index": {}} + {"product": "user", "text": "User Foo logging off"} + +--- +"Test categorization agg simple": + + - do: + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text" + } + } + } + } + - length: { aggregations.categories.buckets: 4} + - match: {aggregations.categories.buckets.0.doc_count: 3} + - match: {aggregations.categories.buckets.0.key: "Node stopping" } + + - do: + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "size": 10, + "max_unique_tokens": 2, + "max_matched_tokens": 1, + "similarity_threshold": 11 + } + } + } + } + + - length: { aggregations.categories.buckets: 2 } + - match: { aggregations.categories.buckets.0.doc_count: 4 } + - match: { aggregations.categories.buckets.0.key: "Node *" } + - match: { aggregations.categories.buckets.1.doc_count: 3 } + - match: { aggregations.categories.buckets.1.key: "User Foo logging *" } +--- +"Test categorization aggregation with poor settings": + + - do: + catch: /\[max_unique_tokens\] must be greater than 0 and less than \[100\]/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "max_unique_tokens": -2 + } + } + } + } + - do: + catch: /\[max_matched_tokens\] must be greater than 0 and less than \[100\]/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "max_matched_tokens": -2 + } + } + } + } + - do: + catch: /\[similarity_threshold\] must be in the range \[1, 100\]/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "similarity_threshold": 0 + } + } + } + } + + - do: + catch: /\[categorization_filters\] cannot be used with \[categorization_analyzer\]/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "categorization_filters": ["foo"], + "categorization_analyzer": "english" + } + } + } + }