diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java index f496608e1c273..43a2d1930d0e8 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java @@ -22,6 +22,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -197,6 +198,22 @@ public long nowInMillis() { return 0; } + @Override + public Analyzer getNamedAnalyzer(String analyzer) { + return null; + } + + @Override + public Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) { + return null; + } + @Override protected IndexFieldData buildFieldData(MappedFieldType ft) { IndexFieldDataCache indexFieldDataCache = indicesFieldDataCache.buildIndexFieldDataCache(new IndexFieldDataCache.Listener() { diff --git a/docs/build.gradle b/docs/build.gradle index f84bc91cf7489..5ecaf7dc6faff 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -1071,6 +1071,39 @@ buildRestTests.setups['farequote_datafeed'] = buildRestTests.setups['farequote_j "indexes":"farequote" } ''' +buildRestTests.setups['categorize_text'] = ''' + - do: + indices.create: + index: log-messages + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + time: + type: date + message: + type: text + + - do: + bulk: + index: log-messages + refresh: true + body: | + {"index": {"_id":"1"}} + {"time":"2016-02-07T00:01:00+0000", "message": "2016-02-07T00:00:00+0000 Node 3 shutting down"} + {"index": {"_id":"2"}} + {"time":"2016-02-07T00:02:00+0000", "message": "2016-02-07T00:00:00+0000 Node 5 starting up"} + {"index": {"_id":"3"}} + {"time":"2016-02-07T00:03:00+0000", "message": "2016-02-07T00:00:00+0000 Node 4 shutting down"} + {"index": {"_id":"4"}} + {"time":"2016-02-08T00:01:00+0000", "message": "2016-02-08T00:00:00+0000 Node 5 shutting down"} + {"index": {"_id":"5"}} + {"time":"2016-02-08T00:02:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_325 logging on"} + {"index": {"_id":"6"}} + {"time":"2016-02-08T00:04:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_864 logged off"} +''' buildRestTests.setups['server_metrics_index'] = ''' - do: indices.create: diff --git a/docs/reference/aggregations/bucket.asciidoc b/docs/reference/aggregations/bucket.asciidoc index 302e196caf3ce..dfdaca18e6cfb 100644 --- a/docs/reference/aggregations/bucket.asciidoc +++ b/docs/reference/aggregations/bucket.asciidoc @@ -20,6 +20,8 @@ include::bucket/adjacency-matrix-aggregation.asciidoc[] include::bucket/autodatehistogram-aggregation.asciidoc[] +include::bucket/categorize-text-aggregation.asciidoc[] + include::bucket/children-aggregation.asciidoc[] include::bucket/composite-aggregation.asciidoc[] diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc new file mode 100644 
index 0000000000000..cc0a0e787f844
--- /dev/null
+++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc
@@ -0,0 +1,469 @@
+[[search-aggregations-bucket-categorize-text-aggregation]]
+=== Categorize text aggregation
+++++
+Categorize text
+++++
+
+experimental::[]
+
+A multi-bucket aggregation that groups semi-structured text into buckets. Each `text` field is re-analyzed
+using a custom analyzer. The resulting tokens are then categorized, creating buckets of similarly formatted
+text values. This aggregation works best with machine-generated text like system logs.
+
+NOTE: If you have considerable memory allocated to your JVM but are receiving circuit breaker exceptions from this
+      aggregation, you may be attempting to categorize text that is poorly formatted for categorization. Consider
+      adding `categorization_filters` or running under the <<search-aggregations-bucket-sampler-aggregation,sampler>> or
+      <<search-aggregations-bucket-diversified-sampler-aggregation,diversified sampler>> aggregation to explore the created categories.
+
+[[bucket-categorize-text-agg-syntax]]
+==== Parameters
+
+`field`::
+(Required, string)
+The semi-structured text field to categorize.
+
+`max_unique_tokens`::
+(Optional, integer, default: `50`)
+The maximum number of unique tokens at any position up to `max_matched_tokens`.
+Must be larger than 1. Smaller values use less memory and create fewer categories.
+Larger values use more memory and create narrower categories.
+
+`max_matched_tokens`::
+(Optional, integer, default: `5`)
+The maximum number of token positions to match on before attempting to merge categories.
+Larger values use more memory and create narrower categories.
+
+Example: a `max_matched_tokens` of 2 would disallow merging of the categories
+[`foo` `bar` `baz`] and
+[`foo` `baz` `bozo`],
+as the first 2 tokens are required to match for the category.
+
+NOTE: Once `max_unique_tokens` is reached at a given position, a new `*` token is
+added and all new tokens at that position are matched by the `*` token.
+
+`similarity_threshold`::
+(Optional, integer, default: `50`)
+The minimum percentage of tokens that must match for text to be added to the
+category bucket.
+Must be between 1 and 100. Larger values increase memory usage and create
+narrower categories.
+
+`categorization_filters`::
+(Optional, array of strings)
+This property expects an array of regular expressions. The expressions
+are used to filter out matching sequences from the categorization field values.
+You can use this functionality to fine-tune the categorization by excluding
+sequences from consideration when categories are defined. For example, you can
+exclude SQL statements that appear in your log files. This
+property cannot be used at the same time as `categorization_analyzer`. If you
+only want to define simple regular expression filters that are applied prior to
+tokenization, setting this property is the easiest method. If you also want to
+customize the tokenizer or post-tokenization filtering, use the
+`categorization_analyzer` property instead and include the filters as
+`pattern_replace` character filters.
+
+`categorization_analyzer`::
+(Optional, object or string)
+The categorization analyzer specifies how the text is analyzed and tokenized before
+being categorized. The syntax is very similar to that used to define the `analyzer` in the
+<<indices-analyze,analyze API>>. This
+property cannot be used at the same time as `categorization_filters`.
++
+The `categorization_analyzer` field can be specified either as a string or as an
+object. If it is a string, it must refer to a <<analysis-analyzers,built-in analyzer>> or
+one added by another plugin.
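++
+For example, this request selects the built-in `standard` analyzer by name. (This is an
+illustrative sketch rather than one of this page's tested examples, so it is skipped in
+the docs test suite.)
++
+[source,console]
+--------------------------------------------------
+POST log-messages/_search?filter_path=aggregations
+{
+  "aggs": {
+    "categories": {
+      "categorize_text": {
+        "field": "message",
+        "categorization_analyzer": "standard"
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[skip:illustrative sketch]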
++
+If it is an object, it has the following properties:
++
+.Properties of `categorization_analyzer`
+[%collapsible%open]
+=====
+`char_filter`::::
+(array of strings or objects)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=char-filter]
+
+`tokenizer`::::
+(string or object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tokenizer]
+
+`filter`::::
+(array of strings or objects)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=filter]
+=====
+
+`shard_size`::
+(Optional, integer)
+The number of categorization buckets to return from each shard before merging
+all the results.
+
+`size`::
+(Optional, integer, default: `10`)
+The number of buckets to return.
+
+`min_doc_count`::
+(Optional, integer)
+The minimum number of documents required for a bucket to be returned in the results.
+
+`shard_min_doc_count`::
+(Optional, integer)
+The minimum number of documents required for a bucket to be returned from a shard before
+merging.
+
+==== Basic use
+
+WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. This aggregation should be
+         used in conjunction with <<async-search,async search>>. Additionally, you may consider
+         using the aggregation as a child of either the <<search-aggregations-bucket-sampler-aggregation,sampler>> or
+         <<search-aggregations-bucket-diversified-sampler-aggregation,diversified sampler>> aggregation.
+         This will typically improve speed and memory use.
+
+Example:
+
+[source,console]
+--------------------------------------------------
+POST log-messages/_search?filter_path=aggregations
+{
+  "aggs": {
+    "categories": {
+      "categorize_text": {
+        "field": "message"
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[setup:categorize_text]
+
+Response:
+
+[source,console-result]
+--------------------------------------------------
+{
+  "aggregations" : {
+    "categories" : {
+      "buckets" : [
+        {
+          "doc_count" : 3,
+          "key" : "Node shutting down"
+        },
+        {
+          "doc_count" : 1,
+          "key" : "Node starting up"
+        },
+        {
+          "doc_count" : 1,
+          "key" : "User foo_325 logging on"
+        },
+        {
+          "doc_count" : 1,
+          "key" : "User foo_864 logged off"
+        }
+      ]
+    }
+  }
+}
+--------------------------------------------------
+
+Here is an example using `categorization_filters`:
+
+[source,console]
+--------------------------------------------------
+POST log-messages/_search?filter_path=aggregations
+{
+  "aggs": {
+    "categories": {
+      "categorize_text": {
+        "field": "message",
+        "categorization_filters": ["\\w+\\_\\d{3}"] <1>
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[setup:categorize_text]
+
+<1> The filters to apply to the analyzed tokens. It filters
+    out tokens like `bar_123`.
+
+Note how the `foo_` tokens are no longer part of the
+category results:
+
+[source,console-result]
+--------------------------------------------------
+{
+  "aggregations" : {
+    "categories" : {
+      "buckets" : [
+        {
+          "doc_count" : 3,
+          "key" : "Node shutting down"
+        },
+        {
+          "doc_count" : 1,
+          "key" : "Node starting up"
+        },
+        {
+          "doc_count" : 1,
+          "key" : "User logged off"
+        },
+        {
+          "doc_count" : 1,
+          "key" : "User logging on"
+        }
+      ]
+    }
+  }
+}
+--------------------------------------------------
+
+The default analyzer is a whitespace analyzer with a custom token filter that
+removes tokens starting with any number. However, a token may be a known highly
+variable token (formatted usernames, emails, etc.). In that case, it is good to supply
+custom `categorization_filters` to filter out those tokens for better categories. These
+filters also reduce memory usage, as fewer unique tokens are held in memory for the
+categories. The same filtering can equivalently be expressed as a `pattern_replace` char
+filter inside a custom `categorization_analyzer`, as sketched below; a full example that
+also adjusts `max_matched_tokens` and `similarity_threshold` follows it.
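+
+This sketch shows the equivalent analyzer form. It is illustrative rather than one of this
+page's tested examples: it uses a plain `whitespace` tokenizer, so it omits the default token
+filter that removes tokens starting with a number and may therefore produce slightly
+different categories:
+
+[source,console]
+--------------------------------------------------
+POST log-messages/_search?filter_path=aggregations
+{
+  "aggs": {
+    "categories": {
+      "categorize_text": {
+        "field": "message",
+        "categorization_analyzer": {
+          "char_filter": [
+            {
+              "type": "pattern_replace",
+              "pattern": "\\w+\\_\\d{3}",
+              "replacement": ""
+            }
+          ],
+          "tokenizer": "whitespace"
+        }
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[skip:illustrative sketch]
+
+Here is the full example with adjusted `max_matched_tokens` and `similarity_threshold`: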
+
+[source,console]
+--------------------------------------------------
+POST log-messages/_search?filter_path=aggregations
+{
+  "aggs": {
+    "categories": {
+      "categorize_text": {
+        "field": "message",
+        "categorization_filters": ["\\w+\\_\\d{3}"], <1>
+        "max_matched_tokens": 2, <2>
+        "similarity_threshold": 30 <3>
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[setup:categorize_text]
+<1> The filters to apply to the analyzed tokens. It filters
+out tokens like `bar_123`.
+<2> Match at most the first 2 tokens before attempting to merge categories
+<3> Require only 30% of the tokens to match before expanding a log category
+    to add a new log entry
+
+The resulting categories are now broader, matching on only the first token
+and merging the previous log groups.
+
+[source,console-result]
+--------------------------------------------------
+{
+  "aggregations" : {
+    "categories" : {
+      "buckets" : [
+        {
+          "doc_count" : 4,
+          "key" : "Node *"
+        },
+        {
+          "doc_count" : 2,
+          "key" : "User *"
+        }
+      ]
+    }
+  }
+}
+--------------------------------------------------
+
+This aggregation can have both sub-aggregations and itself be a sub-aggregation. This allows gathering the top daily categories and the
+top sample doc for each, as below.
+
+[source,console]
+--------------------------------------------------
+POST log-messages/_search?filter_path=aggregations
+{
+  "aggs": {
+    "daily": {
+      "date_histogram": {
+        "field": "time",
+        "fixed_interval": "1d"
+      },
+      "aggs": {
+        "categories": {
+          "categorize_text": {
+            "field": "message",
+            "categorization_filters": ["\\w+\\_\\d{3}"]
+          },
+          "aggs": {
+            "hit": {
+              "top_hits": {
+                "size": 1,
+                "sort": ["time"],
+                "_source": "message"
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[setup:categorize_text]
+
+[source,console-result]
+--------------------------------------------------
+{
+  "aggregations" : {
+    "daily" : {
+      "buckets" : [
+        {
+          "key_as_string" : "2016-02-07T00:00:00.000Z",
+          "key" : 1454803200000,
+          "doc_count" : 3,
+          "categories" : {
+            "buckets" : [
+              {
+                "doc_count" : 2,
+                "key" : "Node shutting down",
+                "hit" : {
+                  "hits" : {
+                    "total" : {
+                      "value" : 2,
+                      "relation" : "eq"
+                    },
+                    "max_score" : null,
+                    "hits" : [
+                      {
+                        "_index" : "log-messages",
+                        "_id" : "1",
+                        "_score" : null,
+                        "_source" : {
+                          "message" : "2016-02-07T00:00:00+0000 Node 3 shutting down"
+                        },
+                        "sort" : [
+                          1454803260000
+                        ]
+                      }
+                    ]
+                  }
+                }
+              },
+              {
+                "doc_count" : 1,
+                "key" : "Node starting up",
+                "hit" : {
+                  "hits" : {
+                    "total" : {
+                      "value" : 1,
+                      "relation" : "eq"
+                    },
+                    "max_score" : null,
+                    "hits" : [
+                      {
+                        "_index" : "log-messages",
+                        "_id" : "2",
+                        "_score" : null,
+                        "_source" : {
+                          "message" : "2016-02-07T00:00:00+0000 Node 5 starting up"
+                        },
+                        "sort" : [
+                          1454803320000
+                        ]
+                      }
+                    ]
+                  }
+                }
+              }
+            ]
+          }
+        },
+        {
+          "key_as_string" : "2016-02-08T00:00:00.000Z",
+          "key" : 1454889600000,
+          "doc_count" : 3,
+          "categories" : {
+            "buckets" : [
+              {
+                "doc_count" : 1,
+                "key" : "Node shutting down",
+                "hit" : {
+                  "hits" : {
+                    "total" : {
+                      "value" : 1,
+                      "relation" : "eq"
+                    },
+                    "max_score" : null,
+                    "hits" : [
+                      {
+                        "_index" : "log-messages",
+                        "_id" : "4",
+                        "_score" : null,
+                        "_source" : {
+                          "message" : "2016-02-08T00:00:00+0000 Node 5 shutting down"
+                        },
+                        "sort" : [
+                          1454889660000
+                        ]
+                      }
+                    ]
+                  }
+                }
+              },
+              {
+                "doc_count" : 1,
+                "key" : "User logged off",
+                "hit" : {
+                  "hits" : {
+                    "total" : {
+                      "value" : 1,
+                      "relation" : "eq"
+                    },
+                    "max_score" : null,
+                    "hits" : [
+                      {
+                        "_index" : "log-messages",
+                        "_id" : "6",
+                        "_score" : null,
+                        "_source" : {
+                          "message" : "2016-02-08T00:00:00+0000 User foo_864 logged off"
+                        },
+                        "sort" : [
+                          1454889840000
+                        ]
+                      }
+                    ]
+                  }
+                }
+              },
+              {
+                "doc_count" : 1,
+                "key" : "User logging on",
+                "hit" : {
+                  "hits" : {
+                    "total" : {
+                      "value" : 1,
+                      "relation" : "eq"
+                    },
+                    "max_score" : null,
+                    "hits" : [
+                      {
+                        "_index" : "log-messages",
+                        "_id" : "5",
+                        "_score" : null,
+                        "_source" : {
+                          "message" : "2016-02-08T00:00:00+0000 User foo_325 logging on"
+                        },
+                        "sort" : [
+                          1454889720000
+                        ]
+                      }
+                    ]
+                  }
+                }
+              }
+            ]
+          }
+        }
+      ]
+    }
+  }
+}
+--------------------------------------------------
diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java
index 01f00e1a00721..175f028435ed6 100644
--- a/server/src/main/java/org/elasticsearch/search/SearchService.java
+++ b/server/src/main/java/org/elasticsearch/search/SearchService.java
@@ -986,6 +986,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
         context.terminateAfter(source.terminateAfter());
         if (source.aggregations() != null && includeAggregations) {
             AggregationContext aggContext = new ProductionAggregationContext(
+                indicesService.getAnalysis(),
                 context.getSearchExecutionContext(),
                 bigArrays,
                 source.aggregations().bytesToPreallocate(),
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java b/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java
index 76ca0a917fb5d..48bd678ce5f80 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java
@@ -48,7 +48,7 @@ protected XContentBuilder doXContentBody(XContentBuilder builder, Params params)
         return builder;
     }

-    protected static <B extends ParsedMultiBucketAggregation<T>, T extends ParsedBucket> void declareMultiBucketAggregationFields(
+    public static <B extends ParsedMultiBucketAggregation<T>, T extends ParsedBucket> void declareMultiBucketAggregationFields(
         final ObjectParser<B, Void> objectParser,
         final CheckedFunction<XContentParser, T, IOException> bucketParser,
         final CheckedFunction<XContentParser, T, IOException> keyedBucketParser
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java
index 5008fbe08eac4..8ce69009f09e0 100644
--- a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java
@@ -18,6 +18,8 @@
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.IndexSettings;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.index.analysis.NameOrDefinition;
 import org.elasticsearch.index.analysis.NamedAnalyzer;
 import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
 import org.elasticsearch.index.fielddata.IndexFieldData;
@@ -95,6 +97,30 @@ public final FieldContext buildFieldContext(String field) {
         return new FieldContext(field, buildFieldData(ft), ft);
     }

+    /**
+     * Returns an existing registered analyzer that should NOT be closed when finished being used.
+     * @param analyzer The custom analyzer name
+     * @return The existing named analyzer.
+     */
+    public abstract Analyzer getNamedAnalyzer(String analyzer) throws IOException;
+
+    /**
+     * Creates a new custom analyzer that should be closed when finished being used.
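+     * <p>
+     * The caller owns the returned analyzer and must close it when finished (for example, in a
+     * try-with-resources block). By contrast, analyzers returned by {@link #getNamedAnalyzer}
+     * are shared and must NOT be closed.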
+     * @param indexSettings The current index settings or null
+     * @param normalizer Whether to build a normalizer instead of a regular analyzer
+     * @param tokenizer The tokenizer name or definition to use
+     * @param charFilters The char filter names or definitions to use
+     * @param tokenFilters The token filter names or definitions to use
+     * @return A new custom analyzer
+     */
+    public abstract Analyzer buildCustomAnalyzer(
+        IndexSettings indexSettings,
+        boolean normalizer,
+        NameOrDefinition tokenizer,
+        List<NameOrDefinition> charFilters,
+        List<NameOrDefinition> tokenFilters
+    ) throws IOException;
+
     /**
      * Lookup the context for an already resolved field type.
      */
@@ -277,10 +303,12 @@ public static class ProductionAggregationContext extends AggregationContext {
         private final Supplier<Boolean> isCancelled;
         private final Function<Query, Query> filterQuery;
         private final boolean enableRewriteToFilterByFilter;
+        private final AnalysisRegistry analysisRegistry;

         private final List<Releasable> releaseMe = new ArrayList<>();

         public ProductionAggregationContext(
+            AnalysisRegistry analysisRegistry,
             SearchExecutionContext context,
             BigArrays bigArrays,
             long bytesToPreallocate,
@@ -295,6 +323,7 @@ public ProductionAggregationContext(
             Function<Query, Query> filterQuery,
             boolean enableRewriteToFilterByFilter
         ) {
+            this.analysisRegistry = analysisRegistry;
             this.context = context;
             if (bytesToPreallocate == 0) {
                 /*
@@ -350,6 +379,22 @@ public long nowInMillis() {
             return context.nowInMillis();
         }

+        @Override
+        public Analyzer getNamedAnalyzer(String analyzer) throws IOException {
+            return analysisRegistry.getAnalyzer(analyzer);
+        }
+
+        @Override
+        public Analyzer buildCustomAnalyzer(
+            IndexSettings indexSettings,
+            boolean normalizer,
+            NameOrDefinition tokenizer,
+            List<NameOrDefinition> charFilters,
+            List<NameOrDefinition> tokenFilters
+        ) throws IOException {
+            return analysisRegistry.buildCustomAnalyzer(indexSettings, normalizer, tokenizer, charFilters, tokenFilters);
+        }
+
         @Override
         protected IndexFieldData<?> buildFieldData(MappedFieldType ft) {
             return context.getForField(ft);
diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java
index f341ae905ce0b..c9d1529e45430 100644
--- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java
+++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java
@@ -37,6 +37,7 @@
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.analysis.AnalyzerScope;
 import org.elasticsearch.index.analysis.IndexAnalyzers;
+import org.elasticsearch.index.analysis.NameOrDefinition;
 import org.elasticsearch.index.analysis.NamedAnalyzer;
 import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
 import org.elasticsearch.index.fielddata.IndexFieldData;
@@ -352,6 +353,22 @@ public long nowInMillis() {
                 return 0;
             }

+            @Override
+            public Analyzer getNamedAnalyzer(String analyzer) {
+                return null;
+            }
+
+            @Override
+            public Analyzer buildCustomAnalyzer(
+                IndexSettings indexSettings,
+                boolean normalizer,
+                NameOrDefinition tokenizer,
+                List<NameOrDefinition> charFilters,
+                List<NameOrDefinition> tokenFilters
+            ) {
+                return null;
+            }
+
             @Override
             public boolean isFieldMapped(String field) {
                 throw new UnsupportedOperationException();
diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java
index 577598b467792..ac684f99e97c7 100644
--- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java
+++
b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -94,6 +94,7 @@ import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesModule; +import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.plugins.SearchPlugin; @@ -137,6 +138,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; @@ -160,6 +162,7 @@ public abstract class AggregatorTestCase extends ESTestCase { private List releasables = new ArrayList<>(); protected ValuesSourceRegistry valuesSourceRegistry; + private AnalysisModule analysisModule; // A list of field types that should not be tested, or are not currently supported private static final List TYPE_TEST_BLACKLIST = List.of( @@ -179,6 +182,19 @@ public void initValuesSourceRegistry() { valuesSourceRegistry = searchModule.getValuesSourceRegistry(); } + @Before + public void initAnalysisRegistry() throws Exception { + analysisModule = createAnalysisModule(); + } + + /** + * @return a new analysis module. Tests that require a fully constructed analysis module (used to create an analysis registry) + * should override this method + */ + protected AnalysisModule createAnalysisModule() throws Exception { + return null; + } + /** * Test cases should override this if they have plugins that need to be loaded, e.g. the plugins their aggregators are in. */ @@ -284,6 +300,7 @@ public void onCache(ShardId shardId, Accountable accountable) {} MultiBucketConsumer consumer = new MultiBucketConsumer(maxBucket, breakerService.getBreaker(CircuitBreaker.REQUEST)); AggregationContext context = new ProductionAggregationContext( + Optional.ofNullable(analysisModule).map(AnalysisModule::getAnalysisRegistry).orElse(null), searchExecutionContext, new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), breakerService), bytesToPreallocate, diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 30bcfd62a8e99..73eef9442246e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -142,6 +142,14 @@ protected Collection> getPlugins() { return Collections.singletonList(TestGeoShapeFieldMapperPlugin.class); } + /** + * Allows additional plugins other than the required `TestGeoShapeFieldMapperPlugin` + * Could probably be removed when dependencies against geo_shape is decoupled + */ + protected Collection> getExtraPlugins() { + return Collections.emptyList(); + } + protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { } @@ -208,9 +216,11 @@ public void beforeTest() throws Exception { // this setup long masterSeed = SeedUtils.parseSeed(RandomizedTest.getContext().getRunnerSeedAsString()); RandomizedTest.getContext().runWithPrivateRandomness(masterSeed, (Callable) () -> { - serviceHolder = new ServiceHolder(nodeSettings, createTestIndexSettings(), getPlugins(), nowInMillis, + Collection> plugins = new ArrayList<>(getPlugins()); + plugins.addAll(getExtraPlugins()); + serviceHolder = new 
ServiceHolder(nodeSettings, createTestIndexSettings(), plugins, nowInMillis, AbstractBuilderTestCase.this, true); - serviceHolderWithNoType = new ServiceHolder(nodeSettings, createTestIndexSettings(), getPlugins(), nowInMillis, + serviceHolderWithNoType = new ServiceHolder(nodeSettings, createTestIndexSettings(), plugins, nowInMillis, AbstractBuilderTestCase.this, false); return null; }); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java index a27aa10215ef9..6e8cb3845de63 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java @@ -736,7 +736,7 @@ private void verifyCategorizationFiltersAreValidRegex() { } } - private static boolean isValidRegex(String exp) { + public static boolean isValidRegex(String exp) { try { Pattern.compile(exp); return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java index d9430762df5d7..cedc046b26674 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java @@ -86,8 +86,10 @@ public static CategorizationAnalyzerConfig buildFromXContentObject(XContentParse * * The parser is strict when parsing config and lenient when parsing cluster state. */ - static CategorizationAnalyzerConfig buildFromXContentFragment(XContentParser parser, boolean ignoreUnknownFields) throws IOException { - + public static CategorizationAnalyzerConfig buildFromXContentFragment( + XContentParser parser, + boolean ignoreUnknownFields + ) throws IOException { CategorizationAnalyzerConfig.Builder builder = new CategorizationAnalyzerConfig.Builder(); XContentParser.Token token = parser.currentToken(); diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 9c53ed96f3724..448fc8b1fc39b 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -35,6 +35,9 @@ tasks.named("yamlRestTest").configure { 'ml/calendar_crud/Test put calendar given id contains invalid chars', 'ml/calendar_crud/Test delete event from non existing calendar', 'ml/calendar_crud/Test delete job from non existing calendar', + // These are searching tests with aggregations, and do not call any ML endpoints + 'ml/categorization_agg/Test categorization agg simple', + 'ml/categorization_agg/Test categorization aggregation with poor settings', 'ml/custom_all_field/Test querying custom all field', 'ml/datafeeds_crud/Test delete datafeed with missing id', 'ml/datafeeds_crud/Test put datafeed referring to missing job_id', diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java new file mode 100644 index 0000000000000..bd85984b3619f --- /dev/null +++ 
b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.metrics.Max; +import org.elasticsearch.search.aggregations.metrics.Min; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; +import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.junit.Before; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notANumber; + +public class CategorizationAggregationIT extends BaseMlIntegTestCase { + + private static final String DATA_INDEX = "categorization-agg-data"; + + @Before + public void setupCluster() { + internalCluster().ensureAtLeastNumDataNodes(3); + ensureStableCluster(); + createSourceData(); + } + + public void testAggregation() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(3)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node started", 3); + assertCategorizationBucket( + agg.getBuckets().get(1), + "Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception", + 2 + ); + assertCategorizationBucket(agg.getBuckets().get(2), "Node stopped", 1); + } + + public void testAggregationWithOnlyOneBucket() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .size(1) + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(1)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node started", 3); + } + + public void testAggregationWithBroadCategories() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .setSimilarityThreshold(11) + .setMaxUniqueTokens(2) + .setMaxMatchedTokens(1) + .subAggregation(AggregationBuilders.max("max").field("time")) + 
.subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(2)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node *", 4); + assertCategorizationBucket( + agg.getBuckets().get(1), + "Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception", + 2 + ); + } + + private void assertCategorizationBucket(InternalCategorizationAggregation.Bucket bucket, String key, long docCount) { + assertThat(bucket.getKeyAsString(), equalTo(key)); + assertThat(bucket.getDocCount(), equalTo(docCount)); + assertThat(((Max)bucket.getAggregations().get("max")).getValue(), not(notANumber())); + assertThat(((Min)bucket.getAggregations().get("min")).getValue(), not(notANumber())); + } + + private void ensureStableCluster() { + ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60)); + } + + private void createSourceData() { + client().admin().indices().prepareCreate(DATA_INDEX) + .setMapping("time", "type=date,format=epoch_millis", + "msg", "type=text") + .get(); + + long nowMillis = System.currentTimeMillis(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + IndexRequest indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis(), + "msg", "Node 1 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis() + 1, + "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]", + "part", "shutdowns"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis() + 1, + "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 55 caused by foo exception]", + "part", "shutdowns"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(1).millis(), + "msg", "Node 2 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis, + "msg", "Node 3 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis, + "msg", "Node 3 stopped", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + + BulkResponse bulkResponse = bulkRequestBuilder + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + assertThat(bulkResponse.hasFailures(), is(false)); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index fa36eb5a3b425..970f856ce5142 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -264,6 +264,8 @@ import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder; +import 
org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; import org.elasticsearch.xpack.ml.aggs.correlation.BucketCorrelationAggregationBuilder; import org.elasticsearch.xpack.ml.aggs.correlation.CorrelationNamedContentProvider; import org.elasticsearch.xpack.ml.aggs.heuristic.PValueScore; @@ -1220,6 +1222,7 @@ public List> getExecutorBuilders(Settings settings) { return Arrays.asList(jobComms, utility, datafeed); } + @Override public Map> getCharFilters() { return MapBuilder.>newMapBuilder() .put(FirstNonBlankLineCharFilter.NAME, FirstNonBlankLineCharFilterFactory::new) @@ -1249,6 +1252,18 @@ public List> getSignificanceHeuristics() { ); } + @Override + public List getAggregations() { + return Arrays.asList( + new AggregationSpec( + CategorizeTextAggregationBuilder.NAME, + CategorizeTextAggregationBuilder::new, + CategorizeTextAggregationBuilder.PARSER + ).addResultReader(InternalCategorizationAggregation::new) + .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)) + ); + } + @Override public UnaryOperator> getIndexTemplateMetadataUpgrader() { return UnaryOperator.identity(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java new file mode 100644 index 0000000000000..6246683ddfe6e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.search.aggregations.AggregationExecutionException; + +class CategorizationBytesRefHash implements Releasable { + + /** + * Our special wild card value. + */ + static final BytesRef WILD_CARD_REF = new BytesRef("*"); + /** + * For all WILD_CARD references, the token ID is always -1 + */ + static final int WILD_CARD_ID = -1; + private final BytesRefHash bytesRefHash; + + CategorizationBytesRefHash(BytesRefHash bytesRefHash) { + this.bytesRefHash = bytesRefHash; + } + + int[] getIds(BytesRef[] tokens) { + int[] ids = new int[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + ids[i] = put(tokens[i]); + } + return ids; + } + + BytesRef[] getDeeps(int[] ids) { + BytesRef[] tokens = new BytesRef[ids.length]; + for (int i = 0; i < tokens.length; i++) { + tokens[i] = getDeep(ids[i]); + } + return tokens; + } + + BytesRef getDeep(long id) { + if (id == WILD_CARD_ID) { + return WILD_CARD_REF; + } + BytesRef shallow = bytesRefHash.get(id, new BytesRef()); + return BytesRef.deepCopyOf(shallow); + } + + int put(BytesRef bytesRef) { + if (WILD_CARD_REF.equals(bytesRef)) { + return WILD_CARD_ID; + } + long hash = bytesRefHash.add(bytesRef); + if (hash < 0) { + return (int) (-1L - hash); + } else { + if (hash > Integer.MAX_VALUE) { + throw new AggregationExecutionException( + LoggerMessageFormat.format( + "more than [{}] unique terms encountered. 
" + + "Consider restricting the documents queried or adding [{}] in the {} configuration", + Integer.MAX_VALUE, + CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(), + CategorizeTextAggregationBuilder.NAME + ) + ); + } + return (int) hash; + } + } + + @Override + public void close() { + bytesRefHash.close(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java new file mode 100644 index 0000000000000..f5b5e6daea956 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.search.aggregations.InternalAggregations; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + + +/** + * Categorized semi-structured text utilizing the drain algorithm: https://arxiv.org/pdf/1806.04356.pdf + * With the following key differences + * - This structure keeps track of the "smallest" sub-tree. So, instead of naively adding a new "*" node, the smallest sub-tree + * is transformed if the incoming token has a higher doc_count. + * - Additionally, similarities are weighted, which allows for nicer merging of existing categories + * - An optional tree reduction step is available to collapse together tiny sub-trees + * + * + * The main implementation is a fixed-sized prefix tree. + * Consequently, this assumes that splits that give us more information come earlier in the text. 
+ *
+ * Examples:
+ *
+ * Given token values:
+ *
+ *     Node is online
+ *     Node is offline
+ *
+ * With a fixed tree depth of 2 we would get the following splits:
+ *                  3 // initial root is the number of tokens
+ *                  |
+ *               "Node" // first prefix node of value "Node"
+ *                  |
+ *                 "is"
+ *                 /  \
+ * [Node is online] [Node is offline] // the individual categories for this simple case
+ *
+ * If the similarityThreshold were less than 60, the result would be a single category [Node is *]
+ */
+public class CategorizationTokenTree implements Accountable {
+
+    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CategorizationTokenTree.class);
+    private final int maxMatchTokens;
+    private final int maxUniqueTokens;
+    private final int similarityThreshold;
+    private long idGenerator;
+    private final Map<Integer, TreeNode> root = new HashMap<>();
+    private long sizeInBytes;
+
+    public CategorizationTokenTree(int maxUniqueTokens, int maxMatchTokens, int similarityThreshold) {
+        assert maxUniqueTokens > 0 && maxMatchTokens >= 0;
+        this.maxUniqueTokens = maxUniqueTokens;
+        this.maxMatchTokens = maxMatchTokens;
+        this.similarityThreshold = similarityThreshold;
+        this.sizeInBytes = SHALLOW_SIZE;
+    }
+
+    public List<InternalCategorizationAggregation.Bucket> toIntermediateBuckets(CategorizationBytesRefHash hash) {
+        return root.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).map(lg -> {
+            int[] categoryTokenIds = lg.getCategorization();
+            BytesRef[] bytesRefs = new BytesRef[categoryTokenIds.length];
+            for (int i = 0; i < categoryTokenIds.length; i++) {
+                bytesRefs[i] = hash.getDeep(categoryTokenIds[i]);
+            }
+            InternalCategorizationAggregation.Bucket bucket = new InternalCategorizationAggregation.Bucket(
+                new InternalCategorizationAggregation.BucketKey(bytesRefs),
+                lg.getCount(),
+                InternalAggregations.EMPTY
+            );
+            bucket.bucketOrd = lg.bucketOrd;
+            return bucket;
+        }).collect(Collectors.toList());
+    }
+
+    void mergeSmallestChildren() {
+        root.values().forEach(TreeNode::collapseTinyChildren);
+    }
+
+    /**
+     * This method does not mutate the underlying structure, meaning that if a matching category
+     * isn't found, it returns {@code Optional.empty()}.
+     *
+     * @param tokenIds The tokens to categorize
+     * @return The category or {@code Optional.empty()} if one doesn't exist
+     */
+    public Optional<TextCategorization> parseTokensConst(final int[] tokenIds) {
+        TreeNode currentNode = this.root.get(tokenIds.length);
+        if (currentNode == null) { // we are missing an entire sub tree. New token length found
+            return Optional.empty();
+        }
+        return Optional.ofNullable(currentNode.getCategorization(tokenIds));
+    }
+
+    /**
+     * This categorizes the passed tokens, potentially mutating the structure by expanding an existing category or adding a new one.
+     *
+     * @param tokenIds The tokens to categorize
+     * @param docCount The count of docs for the given tokens
+     * @return An existing categorization or a newly created one
+     */
+    public TextCategorization parseTokens(final int[] tokenIds, long docCount) {
+        TreeNode currentNode = this.root.get(tokenIds.length);
+        if (currentNode == null) { // we are missing an entire sub tree.
New token length found + currentNode = newNode(docCount, 0, tokenIds); + incSize(currentNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF); + this.root.put(tokenIds.length, currentNode); + } else { + currentNode.incCount(docCount); + } + return currentNode.addText(tokenIds, docCount, this); + } + + TreeNode newNode(long docCount, int tokenPos, int[] tokenIds) { + return tokenPos < maxMatchTokens - 1 && tokenPos < tokenIds.length + ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxUniqueTokens) + : new TreeNode.LeafTreeNode(docCount, similarityThreshold); + } + + TextCategorization newCategorization(long docCount, int[] tokenIds) { + return new TextCategorization(tokenIds, docCount, idGenerator++); + } + + void incSize(long size) { + sizeInBytes += size; + } + + @Override + public long ramBytesUsed() { + return sizeInBytes; + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java new file mode 100644 index 0000000000000..e84167fd1d0a9 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java @@ -0,0 +1,349 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.REQUIRED_SIZE_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME; +import static org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig.Builder.isValidRegex; + +public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder { + + static final TermsAggregator.BucketCountThresholds DEFAULT_BUCKET_COUNT_THRESHOLDS = 
new TermsAggregator.BucketCountThresholds( + 1, + 0, + 10, + -1 + ); + + static final int MAX_MAX_UNIQUE_TOKENS = 100; + static final int MAX_MAX_MATCHED_TOKENS = 100; + public static final String NAME = "categorize_text"; + + static final ParseField FIELD_NAME = new ParseField("field"); + static final ParseField MAX_UNIQUE_TOKENS = new ParseField("max_unique_tokens"); + static final ParseField SIMILARITY_THRESHOLD = new ParseField("similarity_threshold"); + static final ParseField MAX_MATCHED_TOKENS = new ParseField("max_matched_tokens"); + static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters"); + static final ParseField CATEGORIZATION_ANALYZER = new ParseField("categorization_analyzer"); + + public static final ObjectParser PARSER = ObjectParser.fromBuilder( + CategorizeTextAggregationBuilder.NAME, + CategorizeTextAggregationBuilder::new + ); + static { + PARSER.declareString(CategorizeTextAggregationBuilder::setFieldName, FIELD_NAME); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxUniqueTokens, MAX_UNIQUE_TOKENS); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxMatchedTokens, MAX_MATCHED_TOKENS); + PARSER.declareInt(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); + PARSER.declareField( + CategorizeTextAggregationBuilder::setCategorizationAnalyzerConfig, + (p, c) -> CategorizationAnalyzerConfig.buildFromXContentFragment(p, false), + CATEGORIZATION_ANALYZER, + ObjectParser.ValueType.OBJECT_OR_STRING + ); + PARSER.declareStringArray(CategorizeTextAggregationBuilder::setCategorizationFilters, CATEGORIZATION_FILTERS); + PARSER.declareInt(CategorizeTextAggregationBuilder::shardSize, TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME); + PARSER.declareLong(CategorizeTextAggregationBuilder::minDocCount, TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME); + PARSER.declareLong(CategorizeTextAggregationBuilder::shardMinDocCount, TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME); + PARSER.declareInt(CategorizeTextAggregationBuilder::size, REQUIRED_SIZE_FIELD_NAME); + } + + public static CategorizeTextAggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException { + return PARSER.parse(parser, new CategorizeTextAggregationBuilder(aggregationName), null); + } + + private TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds( + DEFAULT_BUCKET_COUNT_THRESHOLDS + ); + private CategorizationAnalyzerConfig categorizationAnalyzerConfig; + private String fieldName; + private int maxUniqueTokens = 50; + private int similarityThreshold = 50; + private int maxMatchedTokens = 5; + + private CategorizeTextAggregationBuilder(String name) { + super(name); + } + + public CategorizeTextAggregationBuilder(String name, String fieldName) { + super(name); + this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); + } + + public String getFieldName() { + return fieldName; + } + + public CategorizeTextAggregationBuilder setFieldName(String fieldName) { + this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); + return this; + } + + public CategorizeTextAggregationBuilder(StreamInput in) throws IOException { + super(in); + this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in); + this.fieldName = in.readString(); + this.maxUniqueTokens = in.readVInt(); + this.maxMatchedTokens = in.readVInt(); + this.similarityThreshold = in.readVInt(); + this.categorizationAnalyzerConfig = 
in.readOptionalWriteable(CategorizationAnalyzerConfig::new); + } + + public int getMaxUniqueTokens() { + return maxUniqueTokens; + } + + public CategorizeTextAggregationBuilder setMaxUniqueTokens(int maxUniqueTokens) { + this.maxUniqueTokens = maxUniqueTokens; + if (maxUniqueTokens <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", + MAX_UNIQUE_TOKENS.getPreferredName(), + MAX_MAX_UNIQUE_TOKENS, + maxUniqueTokens, + name + ); + } + return this; + } + + public double getSimilarityThreshold() { + return similarityThreshold; + } + + public CategorizeTextAggregationBuilder setSimilarityThreshold(int similarityThreshold) { + this.similarityThreshold = similarityThreshold; + if (similarityThreshold < 1 || similarityThreshold > 100) { + throw ExceptionsHelper.badRequestException( + "[{}] must be in the range [1, 100]. Found [{}] in [{}]", + SIMILARITY_THRESHOLD.getPreferredName(), + similarityThreshold, + name + ); + } + return this; + } + + public CategorizeTextAggregationBuilder setCategorizationAnalyzerConfig(CategorizationAnalyzerConfig categorizationAnalyzerConfig) { + this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; + return this; + } + + public CategorizeTextAggregationBuilder setCategorizationFilters(List categorizationFilters) { + if (categorizationFilters == null || categorizationFilters.isEmpty()) { + return this; + } + if (categorizationAnalyzerConfig != null) { + throw ExceptionsHelper.badRequestException( + "[{}] cannot be used with [{}] - instead specify them as pattern_replace char_filters in the analyzer", + CATEGORIZATION_FILTERS.getPreferredName(), + CATEGORIZATION_ANALYZER.getPreferredName() + ); + } + if (categorizationFilters.stream().distinct().count() != categorizationFilters.size()) { + throw ExceptionsHelper.badRequestException(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_DUPLICATES); + } + if (categorizationFilters.stream().anyMatch(String::isEmpty)) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_EMPTY)); + } + for (String filter : categorizationFilters) { + if (isValidRegex(filter) == false) { + throw ExceptionsHelper.badRequestException( + Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_INVALID_REGEX, filter) + ); + } + } + this.categorizationAnalyzerConfig = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(categorizationFilters); + return this; + } + + public int getMaxMatchedTokens() { + return maxMatchedTokens; + } + + public CategorizeTextAggregationBuilder setMaxMatchedTokens(int maxMatchedTokens) { + this.maxMatchedTokens = maxMatchedTokens; + if (maxMatchedTokens <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", + MAX_MATCHED_TOKENS.getPreferredName(), + MAX_MAX_MATCHED_TOKENS, + maxMatchedTokens, + name + ); + } + return this; + } + + /** + * @param size indicating how many buckets should be returned + */ + public CategorizeTextAggregationBuilder size(int size) { + if (size <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0. 
Found [{}] in [{}]", + REQUIRED_SIZE_FIELD_NAME.getPreferredName(), + size, + name + ); + } + bucketCountThresholds.setRequiredSize(size); + return this; + } + + /** + * @param shardSize - indicating the number of buckets each shard + * will return to the coordinating node (the node that coordinates the + * search execution). The higher the shard size is, the more accurate the + * results are. + */ + public CategorizeTextAggregationBuilder shardSize(int shardSize) { + if (shardSize <= 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0. Found [{}] in [{}]", + SHARD_SIZE_FIELD_NAME.getPreferredName(), + shardSize, + name + ); + } + bucketCountThresholds.setShardSize(shardSize); + return this; + } + + /** + * @param minDocCount the minimum document count a text category should have in order to appear in + * the response. + */ + public CategorizeTextAggregationBuilder minDocCount(long minDocCount) { + if (minDocCount < 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to 0. Found [{}] in [{}]", + MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), + minDocCount, + name + ); + } + bucketCountThresholds.setMinDocCount(minDocCount); + return this; + } + + /** + * @param shardMinDocCount the minimum document count a text category should have on the shard in order to + * appear in the response. + */ + public CategorizeTextAggregationBuilder shardMinDocCount(long shardMinDocCount) { + if (shardMinDocCount < 0) { + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to 0. Found [{}] in [{}]", + SHARD_MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), + shardMinDocCount, + name + ); + } + bucketCountThresholds.setShardMinDocCount(shardMinDocCount); + return this; + } + + protected CategorizeTextAggregationBuilder( + CategorizeTextAggregationBuilder clone, + AggregatorFactories.Builder factoriesBuilder, + Map metadata + ) { + super(clone, factoriesBuilder, metadata); + this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(clone.bucketCountThresholds); + this.fieldName = clone.fieldName; + this.maxUniqueTokens = clone.maxUniqueTokens; + this.maxMatchedTokens = clone.maxMatchedTokens; + this.similarityThreshold = clone.similarityThreshold; + this.categorizationAnalyzerConfig = clone.categorizationAnalyzerConfig; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + bucketCountThresholds.writeTo(out); + out.writeString(fieldName); + out.writeVInt(maxUniqueTokens); + out.writeVInt(maxMatchedTokens); + out.writeVInt(similarityThreshold); + out.writeOptionalWriteable(categorizationAnalyzerConfig); + } + + @Override + protected AggregatorFactory doBuild( + AggregationContext context, + AggregatorFactory parent, + AggregatorFactories.Builder subfactoriesBuilder + ) throws IOException { + return new CategorizeTextAggregatorFactory( + name, + fieldName, + maxUniqueTokens, + maxMatchedTokens, + similarityThreshold, + bucketCountThresholds, + categorizationAnalyzerConfig, + context, + parent, + subfactoriesBuilder, + metadata + ); + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + bucketCountThresholds.toXContent(builder, params); + builder.field(FIELD_NAME.getPreferredName(), fieldName); + builder.field(MAX_UNIQUE_TOKENS.getPreferredName(), maxUniqueTokens); + builder.field(MAX_MATCHED_TOKENS.getPreferredName(), maxMatchedTokens); + 
builder.field(SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold); + if (categorizationAnalyzerConfig != null) { + categorizationAnalyzerConfig.toXContent(builder, params); + } + builder.endObject(); + return builder; + } + + @Override + protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { + return new CategorizeTextAggregationBuilder(this, factoriesBuilder, metadata); + } + + @Override + public BucketCardinality bucketCardinality() { + return BucketCardinality.MANY; + } + + @Override + public String getType() { + return NAME; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java new file mode 100644 index 0000000000000..16058fbdae4f2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -0,0 +1,239 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; +import org.elasticsearch.search.aggregations.bucket.DeferableBucketAggregator; +import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.search.lookup.SourceLookup; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; + +public class CategorizeTextAggregator extends DeferableBucketAggregator { + + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + private final SourceLookup sourceLookup; + private final MappedFieldType fieldType; + private final CategorizationAnalyzer analyzer; + private final String sourceFieldName; + private ObjectArray categorizers; + private final int maxUniqueTokens; + private final int maxMatchTokens; + private final int similarityThreshold; + private final LongKeyedBucketOrds bucketOrds; + private final 
CategorizationBytesRefHash bytesRefHash; + + protected CategorizeTextAggregator( + String name, + AggregatorFactories factories, + AggregationContext context, + Aggregator parent, + String sourceFieldName, + MappedFieldType fieldType, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + CategorizationAnalyzerConfig categorizationAnalyzerConfig, + Map metadata + ) throws IOException { + super(name, factories, context, parent, metadata); + this.sourceLookup = context.lookup().source(); + this.sourceFieldName = sourceFieldName; + this.fieldType = fieldType; + CategorizationAnalyzerConfig analyzerConfig = Optional.ofNullable(categorizationAnalyzerConfig) + .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(Collections.emptyList())); + final String analyzerName = analyzerConfig.getAnalyzer(); + if (analyzerName != null) { + Analyzer globalAnalyzer = context.getNamedAnalyzer(analyzerName); + if (globalAnalyzer == null) { + throw new IllegalArgumentException("Failed to find global analyzer [" + analyzerName + "]"); + } + this.analyzer = new CategorizationAnalyzer(globalAnalyzer, false); + } else { + this.analyzer = new CategorizationAnalyzer( + context.buildCustomAnalyzer( + context.getIndexSettings(), + false, + analyzerConfig.getTokenizer(), + analyzerConfig.getCharFilters(), + analyzerConfig.getTokenFilters() + ), + true + ); + } + this.categorizers = bigArrays().newObjectArray(1); + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; + this.similarityThreshold = similarityThreshold; + this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); + this.bucketCountThresholds = bucketCountThresholds; + this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, bigArrays())); + } + + @Override + protected void doClose() { + super.doClose(); + Releasables.close(this.analyzer, this.bytesRefHash); + } + + @Override + public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOException { + InternalCategorizationAggregation.Bucket[][] topBucketsPerOrd = + new InternalCategorizationAggregation.Bucket[ordsToCollect.length][]; + for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { + final CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); + if (categorizationTokenTree == null) { + topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[0]; + continue; + } + int size = (int) Math.min(bucketOrds.bucketsInOrd(ordsToCollect[ordIdx]), bucketCountThresholds.getShardSize()); + PriorityQueue ordered = + new InternalCategorizationAggregation.BucketCountPriorityQueue(size); + for (InternalCategorizationAggregation.Bucket bucket : categorizationTokenTree.toIntermediateBuckets(bytesRefHash)) { + if (bucket.docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } + ordered.insertWithOverflow(bucket); + } + topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[ordered.size()]; + for (int i = ordered.size() - 1; i >= 0; --i) { + topBucketsPerOrd[ordIdx][i] = ordered.pop(); + } + } + buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, a) -> b.aggregations = a); + InternalAggregation[] results = new InternalAggregation[ordsToCollect.length]; + for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { + InternalCategorizationAggregation.Bucket[] bucketArray = topBucketsPerOrd[ordIdx]; + Arrays.sort(bucketArray, 
Comparator.naturalOrder()); + results[ordIdx] = new InternalCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata(), + Arrays.asList(bucketArray) + ); + } + return results; + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata() + ); + } + + @Override + protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + return new LeafBucketCollectorBase(sub, null) { + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); + CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); + if (categorizer == null) { + categorizer = new CategorizationTokenTree(maxUniqueTokens, maxMatchTokens, similarityThreshold); + addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); + categorizers.set(owningBucketOrd, categorizer); + } + collectFromSource(doc, owningBucketOrd, categorizer); + } + + private void collectFromSource(int doc, long owningBucketOrd, CategorizationTokenTree categorizer) throws IOException { + sourceLookup.setSegmentAndDocument(ctx, doc); + Iterator itr = sourceLookup.extractRawValues(sourceFieldName).stream().map(obj -> { + if (obj == null) { + return null; + } + if (obj instanceof BytesRef) { + return fieldType.valueForDisplay(obj).toString(); + } + return obj.toString(); + }).iterator(); + while (itr.hasNext()) { + TokenStream ts = analyzer.tokenStream(fieldType.name(), itr.next()); + processTokenStream(owningBucketOrd, ts, doc, categorizer); + } + } + + private void processTokenStream( + long owningBucketOrd, + TokenStream ts, + int doc, + CategorizationTokenTree categorizer + ) throws IOException { + ArrayList tokens = new ArrayList<>(); + try { + CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); + ts.reset(); + while (ts.incrementToken()) { + tokens.add(bytesRefHash.put(new BytesRef(termAtt))); + } + if (tokens.isEmpty()) { + return; + } + } finally { + ts.close(); + } + long previousSize = categorizer.ramBytesUsed(); + TextCategorization lg = categorizer.parseTokens( + tokens.stream().mapToInt(Integer::valueOf).toArray(), + docCountProvider.getDocCount(doc) + ); + long newSize = categorizer.ramBytesUsed(); + if (newSize - previousSize > 0) { + addRequestCircuitBreakerBytes(newSize - previousSize); + } + + long bucketOrd = bucketOrds.add(owningBucketOrd, lg.getId()); + if (bucketOrd < 0) { // already seen + bucketOrd = -1 - bucketOrd; + collectExistingBucket(sub, doc, bucketOrd); + } else { + lg.bucketOrd = bucketOrd; + collectBucket(sub, doc, bucketOrd); + } + } + }; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java new file mode 100644 index 0000000000000..f63b4ba1f802b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.NonCollectingAggregator; +import org.elasticsearch.search.aggregations.bucket.BucketUtils; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; + +import java.io.IOException; +import java.util.Map; + +public class CategorizeTextAggregatorFactory extends AggregatorFactory { + + private final MappedFieldType fieldType; + private final String indexedFieldName; + private final int maxUniqueTokens; + private final int maxMatchTokens; + private final int similarityThreshold; + private final CategorizationAnalyzerConfig categorizationAnalyzerConfig; + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + + public CategorizeTextAggregatorFactory( + String name, + String fieldName, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + CategorizationAnalyzerConfig categorizationAnalyzerConfig, + AggregationContext context, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, + Map metadata + ) throws IOException { + super(name, context, parent, subFactoriesBuilder, metadata); + this.fieldType = context.getFieldType(fieldName); + if (fieldType != null) { + this.indexedFieldName = fieldType.name(); + } else { + this.indexedFieldName = null; + } + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; + this.similarityThreshold = similarityThreshold; + this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; + this.bucketCountThresholds = bucketCountThresholds; + } + + protected Aggregator createUnmapped(Aggregator parent, Map metadata) throws IOException { + final InternalAggregation aggregation = new UnmappedCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata + ); + return new NonCollectingAggregator(name, context, parent, factories, metadata) { + @Override + public InternalAggregation buildEmptyAggregation() { + return aggregation; + } + }; + } + + @Override + protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map metadata) + throws IOException { + if (fieldType == null) { + return createUnmapped(parent, metadata); + } + TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); + if (bucketCountThresholds.getShardSize() == CategorizeTextAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { + // The user has not made a shardSize selection. 
Use default + // heuristic to avoid any wrong-ranking caused by distributed + // counting + // TODO significant text does a 2x here, should we as well? + bucketCountThresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(bucketCountThresholds.getRequiredSize())); + } + bucketCountThresholds.ensureValidity(); + + return new CategorizeTextAggregator( + name, + factories, + context, + parent, + indexedFieldName, + fieldType, + bucketCountThresholds, + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + categorizationAnalyzerConfig, + metadata + ); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java new file mode 100644 index 0000000000000..92c51b8d75b4c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -0,0 +1,453 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.aggregations.AggregationExecutionException; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation; +import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_REF; + + +public class InternalCategorizationAggregation extends InternalMultiBucketAggregation< + InternalCategorizationAggregation, + InternalCategorizationAggregation.Bucket> { + + // Carries state allowing for delayed reduction of the bucket + // This allows us to keep from accidentally calling "reduce" on the sub-aggs more than once + private static class DelayedCategorizationBucket { + private final BucketKey key; + private long docCount; + private final List toReduce; + + DelayedCategorizationBucket(BucketKey key, List toReduce, long docCount) { + this.key = key; + this.toReduce = new ArrayList<>(toReduce); + this.docCount = docCount; + } + + public long getDocCount() { + return docCount; + } + + public Bucket reduce(BucketKey key, ReduceContext reduceContext) { + List innerAggs = new ArrayList<>(toReduce.size()); + long docCount = 0; + for (Bucket bucket : toReduce) { + innerAggs.add(bucket.aggregations); + docCount += bucket.docCount; + } + return new Bucket(key, docCount, InternalAggregations.reduce(innerAggs, reduceContext)); + 
} + + public DelayedCategorizationBucket add(Bucket bucket) { + this.docCount += bucket.docCount; + this.toReduce.add(bucket); + return this; + } + + public DelayedCategorizationBucket add(DelayedCategorizationBucket bucket) { + this.docCount += bucket.docCount; + this.toReduce.addAll(bucket.toReduce); + return this; + } + } + + static class BucketCountPriorityQueue extends PriorityQueue { + BucketCountPriorityQueue(int size) { + super(size); + } + + @Override + protected boolean lessThan(Bucket a, Bucket b) { + return a.docCount < b.docCount; + } + } + + static class BucketKey implements ToXContentFragment, Writeable, Comparable { + + private final BytesRef[] key; + + static BucketKey withCollapsedWildcards(BytesRef[] key) { + if (key.length <= 1) { + return new BucketKey(key); + } + List collapsedWildCards = new ArrayList<>(); + boolean previousTokenWildCard = false; + for (BytesRef token : key) { + if (token.equals(WILD_CARD_REF)) { + if (previousTokenWildCard == false) { + previousTokenWildCard = true; + collapsedWildCards.add(WILD_CARD_REF); + } + } else { + previousTokenWildCard = false; + collapsedWildCards.add(token); + } + } + if (collapsedWildCards.size() == key.length) { + return new BucketKey(key); + } + return new BucketKey(collapsedWildCards.toArray(BytesRef[]::new)); + } + + BucketKey(BytesRef[] key) { + this.key = key; + } + + BucketKey(StreamInput in) throws IOException { + key = in.readArray(StreamInput::readBytesRef, BytesRef[]::new); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(asString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray(StreamOutput::writeBytesRef, key); + } + + public String asString() { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < key.length - 1; i++) { + builder.append(key[i].utf8ToString()).append(" "); + } + builder.append(key[key.length - 1].utf8ToString()); + return builder.toString(); + } + + @Override + public String toString() { + return asString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BucketKey bucketKey = (BucketKey) o; + return Arrays.equals(key, bucketKey.key); + } + + @Override + public int hashCode() { + return Arrays.hashCode(key); + } + + public BytesRef[] keyAsTokens() { + return key; + } + + @Override + public int compareTo(BucketKey o) { + return Arrays.compare(key, o.key); + } + + } + + public static class Bucket extends InternalBucket implements MultiBucketsAggregation.Bucket, Comparable { + // Used on the shard level to keep track of sub aggregations + long bucketOrd; + + final BucketKey key; + final long docCount; + InternalAggregations aggregations; + + public Bucket(BucketKey key, long docCount, InternalAggregations aggregations) { + this.key = key; + this.docCount = docCount; + this.aggregations = aggregations; + } + + public Bucket(StreamInput in) throws IOException { + key = new BucketKey(in); + docCount = in.readVLong(); + aggregations = InternalAggregations.readFrom(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + key.writeTo(out); + out.writeVLong(getDocCount()); + aggregations.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CommonFields.DOC_COUNT.getPreferredName(), docCount); + 
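// the key is rendered as a single space-joined string (see BucketKey#asString) + 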
builder.field(CommonFields.KEY.getPreferredName()); + key.toXContent(builder, params); + aggregations.toXContentInternal(builder, params); + builder.endObject(); + return builder; + } + + BucketKey getRawKey() { + return key; + } + + @Override + public Object getKey() { + return key; + } + + @Override + public String getKeyAsString() { + return key.asString(); + } + + @Override + public long getDocCount() { + return docCount; + } + + @Override + public Aggregations getAggregations() { + return aggregations; + } + + @Override + public String toString() { + return "Bucket{" + "key=" + getKeyAsString() + ", docCount=" + docCount + ", aggregations=" + aggregations.asMap() + "}\n"; + } + + @Override + public int compareTo(Bucket o) { + return key.compareTo(o.key); + } + + } + + private final List buckets; + private final int maxUniqueTokens; + private final int similarityThreshold; + private final int maxMatchTokens; + private final int requiredSize; + private final long minDocCount; + + protected InternalCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + Map metadata + ) { + this(name, requiredSize, minDocCount, maxUniqueTokens, maxMatchTokens, similarityThreshold, metadata, new ArrayList<>()); + } + + protected InternalCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxUniqueTokens, + int maxMatchTokens, + int similarityThreshold, + Map metadata, + List buckets + ) { + super(name, metadata); + this.buckets = buckets; + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; + this.similarityThreshold = similarityThreshold; + this.minDocCount = minDocCount; + this.requiredSize = requiredSize; + } + + public InternalCategorizationAggregation(StreamInput in) throws IOException { + super(in); + this.maxUniqueTokens = in.readVInt(); + this.maxMatchTokens = in.readVInt(); + this.similarityThreshold = in.readVInt(); + this.buckets = in.readList(Bucket::new); + this.requiredSize = readSize(in); + this.minDocCount = in.readVLong(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeVInt(maxUniqueTokens); + out.writeVInt(maxMatchTokens); + out.writeVInt(similarityThreshold); + out.writeList(buckets); + writeSize(requiredSize, out); + out.writeVLong(minDocCount); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.startArray(CommonFields.BUCKETS.getPreferredName()); + for (Bucket bucket : buckets) { + bucket.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + @Override + public InternalCategorizationAggregation create(List buckets) { + return new InternalCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + super.metadata, + buckets + ); + } + + @Override + public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { + return new Bucket(prototype.key, prototype.docCount, aggregations); + } + + @Override + protected Bucket reduceBucket(List buckets, ReduceContext context) { + throw new IllegalArgumentException("For optimization purposes, typical bucket path is not supported"); + } + + @Override + public List getBuckets() { + return buckets; + } + + @Override + public String getWriteableName() { + return CategorizeTextAggregationBuilder.NAME; + } + + @Override + public InternalAggregation 
reduce(List aggregations, ReduceContext reduceContext) { + try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays()))) { + CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree( + maxUniqueTokens, + maxMatchTokens, + similarityThreshold + ); + // TODO: Could we do a merge sort similar to terms? + // It would require us returning partial reductions sorted by key, not by doc_count + // First, make sure we have all the counts for equal categorizations + Map reduced = new HashMap<>(); + for (InternalAggregation aggregation : aggregations) { + InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; + for (Bucket bucket : categorizationAggregation.buckets) { + reduced.computeIfAbsent(bucket.key, key -> new DelayedCategorizationBucket(key, new ArrayList<>(1), 0L)).add(bucket); + } + } + + reduced.values() + .stream() + .sorted(Comparator.comparing(DelayedCategorizationBucket::getDocCount).reversed()) + .forEach(bucket -> + // Parse tokens takes document count into account and merging on smallest groups + categorizationTokenTree.parseTokens(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount) + ); + categorizationTokenTree.mergeSmallestChildren(); + Map mergedBuckets = new HashMap<>(); + for (DelayedCategorizationBucket delayedBucket : reduced.values()) { + TextCategorization group = categorizationTokenTree.parseTokensConst(hash.getIds(delayedBucket.key.keyAsTokens())) + .orElseThrow( + () -> new AggregationExecutionException( + "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" + ) + ); + BytesRef[] categoryTokens = hash.getDeeps(group.getCategorization()); + + BucketKey key = reduceContext.isFinalReduce() ? + BucketKey.withCollapsedWildcards(categoryTokens) : + new BucketKey(categoryTokens); + mergedBuckets.computeIfAbsent( + key, + k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L) + ).add(delayedBucket); + } + + final int size = reduceContext.isFinalReduce() == false ? 
mergedBuckets.size() : Math.min(requiredSize, mergedBuckets.size()); + final PriorityQueue pq = new BucketCountPriorityQueue(size); + for (Map.Entry keyAndBuckets : mergedBuckets.entrySet()) { + final BucketKey key = keyAndBuckets.getKey(); + DelayedCategorizationBucket bucket = keyAndBuckets.getValue(); + Bucket newBucket = bucket.reduce(key, reduceContext); + if ((newBucket.docCount >= minDocCount) || reduceContext.isFinalReduce() == false) { + Bucket removed = pq.insertWithOverflow(newBucket); + if (removed == null) { + reduceContext.consumeBucketsAndMaybeBreak(1); + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed)); + } + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(newBucket)); + } + } + Bucket[] bucketList = new Bucket[pq.size()]; + for (int i = pq.size() - 1; i >= 0; i--) { + bucketList[i] = pq.pop(); + } + // Keep the top categories top, but then sort by the key for those with duplicate counts + if (reduceContext.isFinalReduce()) { + Arrays.sort(bucketList, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); + } + return new InternalCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxUniqueTokens, + maxMatchTokens, + similarityThreshold, + metadata, + Arrays.asList(bucketList) + ); + } + } + + public int getMaxUniqueTokens() { + return maxUniqueTokens; + } + + public int getSimilarityThreshold() { + return similarityThreshold; + } + + public int getMaxMatchTokens() { + return maxMatchTokens; + } + + public int getRequiredSize() { + return requiredSize; + } + + public long getMinDocCount() { + return minDocCount; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java new file mode 100644 index 0000000000000..7ea72f489ae2d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; + +import java.util.Arrays; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; + +/** + * A text categorization group that provides methods for: + * - calculating similarity between it and a new text + * - expanding the existing categorization by adding a new array of tokens + */ +class TextCategorization implements Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TextCategorization.class); + private final long id; + private final int[] categorization; + private final long[] tokenCounts; + private long count; + + // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations + long bucketOrd; + + TextCategorization(int[] tokenIds, long count, long id) { + this.id = id; + this.categorization = tokenIds; + this.count = count; + this.tokenCounts = new long[tokenIds.length]; + Arrays.fill(this.tokenCounts, count); + } + + public long getId() { + return id; + } + + int[] getCategorization() { + return categorization; + } + + public long getCount() { + return count; + } + + Similarity calculateSimilarity(int[] tokenIds) { + assert tokenIds.length == this.categorization.length; + int eqParams = 0; + long tokenCount = 0; + long tokensKept = 0; + for (int i = 0; i < tokenIds.length; i++) { + if (tokenIds[i] == this.categorization[i]) { + tokensKept += tokenCounts[i]; + tokenCount += tokenCounts[i]; + } else if (this.categorization[i] == WILD_CARD_ID) { + eqParams++; + } else { + tokenCount += tokenCounts[i]; + } + } + return new Similarity((double) tokensKept / tokenCount, eqParams); + } + + void addTokens(int[] tokenIds, long docCount) { + assert tokenIds.length == this.categorization.length; + for (int i = 0; i < tokenIds.length; i++) { + if (tokenIds[i] != this.categorization[i]) { + this.categorization[i] = WILD_CARD_ID; + } else { + tokenCounts[i] += docCount; + } + } + this.count += docCount; + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + + RamUsageEstimator.sizeOf(categorization) // categorization token Ids + + RamUsageEstimator.sizeOf(tokenCounts); // tokenCounts + } + + static class Similarity implements Comparable { + private final double similarity; + private final int wildCardCount; + + private Similarity(double similarity, int wildCardCount) { + this.similarity = similarity; + this.wildCardCount = wildCardCount; + } + + @Override + public int compareTo(Similarity o) { + int d = Double.compare(similarity, o.similarity); + if (d != 0) { + return d; + } + return Integer.compare(wildCardCount, o.wildCardCount); + } + + public double getSimilarity() { + return similarity; + } + + public int getWildCardCount() { + return wildCardCount; + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java new file mode 100644 index 0000000000000..7b13e93d8f1ea --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.aggregations.AggregationExecutionException; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfMap; +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; + +/** + * Tree node classes for the categorization token tree. + * + * Two major node types exist: + * - Inner: which are nodes that have children token nodes + * - Leaf: Which collection multiple {@link TextCategorization} based on similarity restrictions + */ +abstract class TreeNode implements Accountable { + + private long count; + + TreeNode(long count) { + this.count = count; + } + + abstract void mergeWith(TreeNode otherNode); + + abstract boolean isLeaf(); + + final void incCount(long count) { + this.count += count; + } + + final long getCount() { + return count; + } + + // TODO add option for calculating the cost of adding the new group + abstract TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory); + + abstract TextCategorization getCategorization(int[] tokenIds); + + abstract List getAllChildrenTextCategorizations(); + + abstract void collapseTinyChildren(); + + static class LeafTreeNode extends TreeNode { + private final List textCategorizations; + private final int similarityThreshold; + + LeafTreeNode(long count, int similarityThreshold) { + super(count); + this.textCategorizations = new ArrayList<>(); + this.similarityThreshold = similarityThreshold; + if (similarityThreshold < 1 || similarityThreshold > 100) { + throw new IllegalArgumentException("similarityThreshold must be between 1 and 100"); + } + } + + @Override + public boolean isLeaf() { + return true; + } + + @Override + void mergeWith(TreeNode treeNode) { + if (treeNode == null) { + return; + } + if (treeNode.isLeaf() == false) { + throw new UnsupportedOperationException( + "cannot merge leaf node with non-leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" + ); + } + incCount(treeNode.getCount()); + LeafTreeNode otherLeaf = (LeafTreeNode) treeNode; + for (TextCategorization group : otherLeaf.textCategorizations) { + if (getAndUpdateTextCategorization(group.getCategorization(), group.getCount()).isPresent() == false) { + putNewTextCategorization(group); + } + } + } + + @Override + public long ramBytesUsed() { + return Long.BYTES // count + + NUM_BYTES_OBJECT_REF // list reference + + Integer.BYTES // similarityThreshold + + sizeOfCollection(textCategorizations); + } + + @Override + public TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory) { + return getAndUpdateTextCategorization(tokenIds, docCount).orElseGet(() -> { + // Need to update the tree if possible + TextCategorization categorization = putNewTextCategorization(treeNodeFactory.newCategorization(docCount, tokenIds)); + // Get the regular size bytes from the TextCategorization and how much it costs to 
reference it + treeNodeFactory.incSize(categorization.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF); + return categorization; + }); + } + + @Override + List getAllChildrenTextCategorizations() { + return textCategorizations; + } + + @Override + void collapseTinyChildren() {} + + private Optional getAndUpdateTextCategorization(int[] tokenIds, long docCount) { + return getBestCategorization(tokenIds).map(bestGroupAndSimilarity -> { + if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { + bestGroupAndSimilarity.v1().addTokens(tokenIds, docCount); + return bestGroupAndSimilarity.v1(); + } + return null; + }); + } + + TextCategorization putNewTextCategorization(TextCategorization categorization) { + textCategorizations.add(categorization); + return categorization; + } + + private Optional> getBestCategorization(int[] tokenIds) { + if (textCategorizations.isEmpty()) { + return Optional.empty(); + } + if (textCategorizations.size() == 1) { + return Optional.of( + new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity(tokenIds).getSimilarity()) + ); + } + TextCategorization.Similarity maxSimilarity = null; + TextCategorization bestGroup = null; + for (TextCategorization textCategorization : this.textCategorizations) { + TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity(tokenIds); + if (maxSimilarity == null || groupSimilarity.compareTo(maxSimilarity) > 0) { + maxSimilarity = groupSimilarity; + bestGroup = textCategorization; + } + } + return Optional.of(new Tuple<>(bestGroup, maxSimilarity.getSimilarity())); + } + + @Override + public TextCategorization getCategorization(final int[] tokenIds) { + return getBestCategorization(tokenIds).map(Tuple::v1).orElse(null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LeafTreeNode that = (LeafTreeNode) o; + return that.similarityThreshold == similarityThreshold + && Objects.equals(textCategorizations, that.textCategorizations); + } + + @Override + public int hashCode() { + return Objects.hash(textCategorizations, similarityThreshold); + } + } + + static class InnerTreeNode extends TreeNode { + + // TODO: Change to LongObjectMap? 
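+ // Children are keyed by the token id seen at this node's position; once + // maxChildren distinct ids exist, the rarest children are folded into a + // WILD_CARD_ID child (see addChild and collapseTinyChildren below).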
+ private final Map children; + private final int childrenTokenPos; + private final int maxChildren; + private final PriorityQueue smallestChild; + + InnerTreeNode(long count, int childrenTokenPos, int maxChildren) { + super(count); + children = new HashMap<>(); + this.childrenTokenPos = childrenTokenPos; + this.maxChildren = maxChildren; + this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(NativeIntLongPair::count)); + } + + @Override + boolean isLeaf() { + return false; + } + + @Override + public TextCategorization getCategorization(final int[] tokenIds) { + return getChild(tokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) + .map(node -> node.getCategorization(tokenIds)) + .orElse(null); + } + + @Override + public long ramBytesUsed() { + return Long.BYTES // count + + NUM_BYTES_OBJECT_REF // children reference + + Integer.BYTES // childrenTokenPos + + Integer.BYTES // maxChildren + + NUM_BYTES_OBJECT_REF // smallestChildReference + + sizeOfMap(children, NUM_BYTES_OBJECT_REF) // children, + // Number of items in the queue, reference to tuple, and then the tuple references + + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + Integer.BYTES + Long.BYTES); + } + + @Override + public TextCategorization addText(final int[] tokenIds, final long docCount, final CategorizationTokenTree treeNodeFactory) { + final int currentToken = tokenIds[childrenTokenPos]; + TreeNode child = getChild(currentToken).map(node -> { + node.incCount(docCount); + if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == currentToken) { + smallestChild.add(smallestChild.poll()); + } + return node; + }).orElseGet(() -> { + TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, tokenIds); + // The size of the node + entry (since it is a map entry) + extra reference for priority queue + treeNodeFactory.incSize( + newNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF + ); + return addChild(currentToken, newNode); + }); + return child.addText(tokenIds, docCount, treeNodeFactory); + } + + @Override + void collapseTinyChildren() { + if (this.isLeaf()) { + return; + } + if (children.size() <= 1) { + return; + } + Optional maybeWildChild = getChild(WILD_CARD_ID).or(() -> { + if ((double) smallestChild.peek().count / this.getCount() <= 1.0 / maxChildren) { + TreeNode tinyChild = children.remove(smallestChild.poll().tokenId); + return Optional.of(addChild(WILD_CARD_ID, tinyChild)); + } + return Optional.empty(); + }); + if (maybeWildChild.isPresent()) { + TreeNode wildChild = maybeWildChild.get(); + NativeIntLongPair tinyNode; + while ((tinyNode = smallestChild.poll()) != null) { + // If we have no more tiny nodes, stop iterating over them + if ((double) tinyNode.count / this.getCount() > 1.0 / maxChildren) { + smallestChild.add(tinyNode); + break; + } else { + wildChild.mergeWith(children.remove(tinyNode.tokenId)); + } + } + } + children.values().forEach(TreeNode::collapseTinyChildren); + } + + @Override + void mergeWith(TreeNode treeNode) { + if (treeNode == null) { + return; + } + incCount(treeNode.count); + if (treeNode.isLeaf()) { + throw new UnsupportedOperationException( + "cannot merge non-leaf node with leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" + ); + } + InnerTreeNode innerTreeNode = (InnerTreeNode) treeNode; + TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD_ID); + addChild(WILD_CARD_ID, siblingWildChild); + NativeIntLongPair siblingChild; + 
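// drain the sibling's remaining children, merging each one into this node's matching child + 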
while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) { + TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.tokenId); + addChild(siblingChild.tokenId, nephewNode); + } + } + + private TreeNode addChild(int tokenId, TreeNode node) { + if (node == null) { + return null; + } + Optional existingChild = getChild(tokenId).map(existingNode -> { + existingNode.mergeWith(node); + if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == tokenId) { + smallestChild.poll(); + smallestChild.add(NativeIntLongPair.of(tokenId, existingNode.getCount())); + } + return existingNode; + }); + if (existingChild.isPresent()) { + return existingChild.get(); + } + if (children.size() == maxChildren) { + return getChild(WILD_CARD_ID).map(wildChild -> { + final TreeNode toMerge; + final TreeNode toReturn; + if (smallestChild.isEmpty() == false && node.getCount() > smallestChild.peek().count) { + toMerge = children.remove(smallestChild.poll().tokenId); + addChildAndUpdateSmallest(tokenId, node); + toReturn = node; + } else { + toMerge = node; + toReturn = wildChild; + } + wildChild.mergeWith(toMerge); + return toReturn; + }).orElseThrow(() -> new AggregationExecutionException("Missing wild_card child even though maximum children reached")); + } + // we are about to hit the limit, add a wild card if we need to and then add the new child as appropriate + if (children.size() == maxChildren - 1) { + // If we already have a wild token, simply adding the new token is acceptable as we won't breach our limit + if (children.containsKey(WILD_CARD_ID)) { + addChildAndUpdateSmallest(tokenId, node); + } else { // if we don't have a wild card child, we need to add one now + if (tokenId == WILD_CARD_ID) { + addChildAndUpdateSmallest(tokenId, node); + } else { + if (smallestChild.isEmpty() == false && node.count > smallestChild.peek().count) { + addChildAndUpdateSmallest(WILD_CARD_ID, children.remove(smallestChild.poll().tokenId)); + addChildAndUpdateSmallest(tokenId, node); + } else { + addChildAndUpdateSmallest(WILD_CARD_ID, node); + } + } + } + } else { + addChildAndUpdateSmallest(tokenId, node); + } + return node; + } + + private void addChildAndUpdateSmallest(int tokenId, TreeNode node) { + children.put(tokenId, node); + if (tokenId != WILD_CARD_ID) { + smallestChild.add(NativeIntLongPair.of(tokenId, node.count)); + } + } + + private Optional getChild(int tokenId) { + return Optional.ofNullable(children.get(tokenId)); + } + + public List getAllChildrenTextCategorizations() { + return children.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).collect(Collectors.toList()); + } + + boolean hasChild(int tokenId) { + return children.containsKey(tokenId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InnerTreeNode treeNode = (InnerTreeNode) o; + return childrenTokenPos == treeNode.childrenTokenPos + && getCount() == treeNode.getCount() + && Objects.equals(children, treeNode.children) + && Objects.equals(smallestChild, treeNode.smallestChild); + } + + @Override + public int hashCode() { + return Objects.hash(children, childrenTokenPos, smallestChild, getCount()); + } + } + + private static class NativeIntLongPair { + private final int tokenId; + private final long count; + + static NativeIntLongPair of(int tokenId, long count) { + return new NativeIntLongPair(tokenId, count); + } + + NativeIntLongPair(int tokenId, long count) { + this.tokenId = tokenId; + 
this.count = count; + } + + public long count() { + return count; + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java new file mode 100644 index 0000000000000..ae1081f66d09f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; + +import java.util.List; +import java.util.Map; + + +class UnmappedCategorizationAggregation extends InternalCategorizationAggregation { + protected UnmappedCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxChildren, + int maxDepth, + int similarityThreshold, + Map metadata + ) { + super(name, requiredSize, minDocCount, maxChildren, maxDepth, similarityThreshold, metadata); + } + + @Override + public InternalCategorizationAggregation create(List buckets) { + return new UnmappedCategorizationAggregation( + name, + getRequiredSize(), + getMinDocCount(), + getMaxUniqueTokens(), + getMaxMatchTokens(), + getSimilarityThreshold(), + super.metadata + ); + } + + @Override + public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { + throw new UnsupportedOperationException("not supported for UnmappedCategorizationAggregation"); + } + + @Override + public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + return new UnmappedCategorizationAggregation( + name, + getRequiredSize(), + getMinDocCount(), + getMaxUniqueTokens(), + getMaxMatchTokens(), + getSimilarityThreshold(), + super.metadata + ); + } + + @Override + public boolean isMapped() { + return false; + } + + @Override + public List getBuckets() { + return List.of(); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java index 4fc99e1502851..6147bc0256ca5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java @@ -10,11 +10,11 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; -import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -25,7 +25,7 @@ * Converts messages to lists of tokens that will be fed to the ML categorization algorithm. 
* */ -public class CategorizationAnalyzer implements Closeable { +public class CategorizationAnalyzer implements Releasable { private final Analyzer analyzer; private final boolean closeAnalyzer; @@ -38,6 +38,16 @@ public CategorizationAnalyzer(AnalysisRegistry analysisRegistry, closeAnalyzer = tuple.v2(); } + public CategorizationAnalyzer(Analyzer analyzer, boolean closeAnalyzer) { + this.analyzer = analyzer; + this.closeAnalyzer = closeAnalyzer; + } + + public final TokenStream tokenStream(final String fieldName, + final String text) { + return analyzer.tokenStream(fieldName, text); + } + /** * Release resources held by the analyzer (unless it's global). */ diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index 750f36863a141..d6f147f831df7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -17,6 +17,9 @@ import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.analysis.CharFilterFactory; +import org.elasticsearch.index.analysis.TokenizerFactory; +import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.license.LicenseService; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.ActionPlugin; @@ -33,6 +36,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.ml.MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME; @@ -84,6 +88,20 @@ public void cleanUpFeature( mlPlugin.cleanUpFeature(clusterService, client, finalListener); } + @Override + public List getAggregations() { + return mlPlugin.getAggregations(); + } + + @Override + public Map> getCharFilters() { + return mlPlugin.getCharFilters(); + } + + @Override + public Map> getTokenizers() { + return mlPlugin.getTokenizers(); + } /** * This is only required as we now have to have the GetRollupIndexCapsAction as a valid action in our node. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java new file mode 100644 index 0000000000000..7b907ea3ecd29 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.aggregations.BaseAggregationTestCase; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.config.CategorizationAnalyzerConfigTests; + +import java.util.Collection; +import java.util.Collections; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class CategorizeTextAggregationBuilderTests extends BaseAggregationTestCase { + + @Override + protected Collection> getExtraPlugins() { + return Collections.singletonList(MachineLearning.class); + } + + @Override + protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() { + CategorizeTextAggregationBuilder builder = new CategorizeTextAggregationBuilder(randomAlphaOfLength(10), randomAlphaOfLength(10)); + final boolean setFilters = randomBoolean(); + if (setFilters) { + builder.setCategorizationFilters(Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList())); + } + if (setFilters == false) { + builder.setCategorizationAnalyzerConfig(CategorizationAnalyzerConfigTests.createRandomized().build()); + } + if (randomBoolean()) { + builder.setMaxUniqueTokens(randomIntBetween(1, 500)); + } + if (randomBoolean()) { + builder.setMaxMatchedTokens(randomIntBetween(1, 10)); + } + if (randomBoolean()) { + builder.setSimilarityThreshold(randomIntBetween(1, 100)); + } + if (randomBoolean()) { + builder.minDocCount(randomLongBetween(1, 100)); + } + if (randomBoolean()) { + builder.shardMinDocCount(randomLongBetween(1, 100)); + } + if (randomBoolean()) { + builder.size(randomIntBetween(1, 100)); + } + if (randomBoolean()) { + builder.shardSize(randomIntBetween(1, 100)); + } + return builder; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java new file mode 100644 index 0000000000000..95cfdcb0f8f8f --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -0,0 +1,342 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.index.mapper.TextFieldMapper; +import org.elasticsearch.indices.analysis.AnalysisModule; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; +import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram; +import org.elasticsearch.search.aggregations.metrics.Avg; +import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.Max; +import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.Min; +import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class CategorizeTextAggregatorTests extends AggregatorTestCase { + + @Override + protected AnalysisModule createAnalysisModule() throws Exception { + return new AnalysisModule( + TestEnvironment.newEnvironment( + Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() + ), + List.of(new MachineLearning(Settings.EMPTY, null)) + ); + } + + @Override + protected List getSearchPlugins() { + return List.of(new MachineLearning(Settings.EMPTY, null)); + } + + private static final String TEXT_FIELD_NAME = "text"; + private static final String NUMERIC_FIELD_NAME = "value"; + + public void testCategorizationWithoutSubAggs() throws Exception { + testCase( + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME), + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationWithSubAggs() throws Exception { + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) + ) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { + 
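// after the final reduce, buckets are sorted by descending doc count + 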
assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) result.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) result.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) result.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0)); + assertThat(((Min) result.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(2.0)); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationWithMultiBucketSubAggs() throws Exception { + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + Histogram histo = result.getBuckets().get(0).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + for (Histogram.Bucket bucket : histo.getBuckets()) { + assertThat(bucket.getDocCount(), equalTo(2L)); + } + assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5)); + assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5)); + assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + histo = result.getBuckets().get(1).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(1L)); + 
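+                // Editor's note: the two "Failed to shutdown" docs carry values 0 and 4, so the
+                // interval-2 histogram is expected to leave its middle bucket empty, as checked next.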
assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L)); + assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(1L)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0)); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationAsSubAgg() throws Exception { + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation( + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) + ) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalHistogram result) -> { + assertThat(result.getBuckets(), hasSize(3)); + + // First histo bucket + assertThat(result.getBuckets().get(0).getDocCount(), equalTo(3L)); + InternalCategorizationAggregation categorizationAggregation = result.getBuckets().get(0).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(2)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(0.5)); + + assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L)); + assertThat( + categorizationAggregation.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(0.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(0.0)); + + // Second histo bucket + assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L)); + categorizationAggregation = result.getBuckets().get(1).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(1)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + + // Third histo bucket + assertThat(result.getBuckets().get(2).getDocCount(), equalTo(3L)); + categorizationAggregation = result.getBuckets().get(2).getAggregations().get("my_agg"); + 
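+                // Editor's note: each parent histogram bucket gets its own categorization, so the
+                // third bucket (values 4 and 5) is expected to surface both categories again below.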
+                assertThat(categorizationAggregation.getBuckets(), hasSize(2));
+                assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L));
+                assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
+                assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(5.0));
+                assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(4.0));
+                assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(4.5));
+
+                assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L));
+                assertThat(
+                    categorizationAggregation.getBuckets().get(1).getKeyAsString(),
+                    equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
+                );
+                assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0));
+                assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(4.0));
+                assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(4.0));
+            },
+            new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME),
+            longField(NUMERIC_FIELD_NAME)
+        );
+    }
+
+    public void testCategorizationWithSubAggsManyDocs() throws Exception {
+        CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation(
+            new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME)
+                .interval(2)
+                .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME))
+                .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME))
+                .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME))
+        );
+        testCase(
+            aggBuilder,
+            new MatchAllDocsQuery(),
+            CategorizeTextAggregatorTests::writeManyTestDocs,
+            (InternalCategorizationAggregation result) -> {
+                assertThat(result.getBuckets(), hasSize(2));
+                assertThat(result.getBuckets().get(0).docCount, equalTo(30_000L));
+                assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started"));
+                Histogram histo = result.getBuckets().get(0).aggregations.get("histo");
+                assertThat(histo.getBuckets(), hasSize(3));
+                for (Histogram.Bucket bucket : histo.getBuckets()) {
+                    assertThat(bucket.getDocCount(), equalTo(10_000L));
+                }
+                assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).getValue(), equalTo(1.0));
+                assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).getValue(), equalTo(0.0));
+                assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5));
+                assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).getValue(), equalTo(3.0));
+                assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).getValue(), equalTo(2.0));
+                assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5));
+                assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).getValue(), equalTo(5.0));
+                assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).getValue(), equalTo(4.0));
+                assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5));
+
+                assertThat(result.getBuckets().get(1).docCount, equalTo(10_000L));
+                assertThat(
+                    result.getBuckets().get(1).getKeyAsString(),
+                    equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception")
+                );
+                histo = result.getBuckets().get(1).aggregations.get("histo");
+                assertThat(histo.getBuckets(), hasSize(3));
+                assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(5_000L));
+                assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L));
+                assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(5_000L));
+                assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0));
+                assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0));
+            },
+            new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME),
+            longField(NUMERIC_FIELD_NAME)
+        );
+    }
+
+    private static void writeTestDocs(RandomIndexWriter w) throws IOException {
+        w.addDocument(
+            Arrays.asList(
+                new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0)
+            )
+        );
+        w.addDocument(
+            Arrays.asList(
+                new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 1)
+            )
+        );
+        w.addDocument(
+            Arrays.asList(
+                new StoredField(
+                    "_source",
+                    new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}")
+                ),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0)
+            )
+        );
+        w.addDocument(
+            Arrays.asList(
+                new StoredField(
+                    "_source",
+                    new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}")
+                ),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4)
+            )
+        );
+        w.addDocument(
+            Arrays.asList(
+                new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 2)
+            )
+        );
+        w.addDocument(
+            Arrays.asList(
+                new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 3)
+            )
+        );
+        w.addDocument(
+            Arrays.asList(
+                new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4)
+            )
+        );
+        w.addDocument(
+            Arrays.asList(
+                new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")),
+                new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 5)
+            )
+        );
+    }
+
+    private static void writeManyTestDocs(RandomIndexWriter w) throws IOException {
+        for (int i = 0; i < 5_000; i++) {
+            writeTestDocs(w);
+        }
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java
new file mode 100644
index 0000000000000..e7f78a01d0130
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class InnerTreeNodeTests extends ESTestCase { + + private final CategorizationTokenTree factory = new CategorizationTokenTree(3, 4, 60); + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() { + bytesRefHash.close(); + } + + public void testAddText() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foo3", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "*", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foo", "bar", "baz", "*") + ); + } + + public void testAddTokensWithLargerIncoming() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 100, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") + ); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), + getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz") + ); + assertThat( + innerTreeNode.getCategorization(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")).getCategorization(), + equalTo(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) + ); + } + + public void testCollapseTinyChildren() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); + 
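+        // Editor's note: addText returns the TextCategorization the tokens were grouped
+        // into; a high-count child like this one should survive collapseTinyChildren().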
assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") + ); + innerTreeNode.incCount(1000); + assertArrayEquals( + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") + ); + innerTreeNode.incCount(1); + innerTreeNode.collapseTinyChildren(); + assertThat(innerTreeNode.hasChild(bytesRefHash.put(new BytesRef("foosmall"))), is(false)); + assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); + } + + public void testMergeWith() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 3); + innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); + innerTreeNode.incCount(1000); + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory); + + expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 60))); + + TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory); + innerTreeNode.incCount(1); + innerTreeNode.addText(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz"), 1, factory); + + innerTreeNode.mergeWith(mergeWith); + assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); + assertArrayEquals( + innerTreeNode.getCategorization(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz")).getCategorization(), + getTokens(bytesRefHash, "*", "bar", "baz", "biz") + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java new file mode 100644 index 0000000000000..749863b336fa2 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.ml.aggs.categorization;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.CollectionUtils;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.aggregations.Aggregation;
+import org.elasticsearch.search.aggregations.InternalAggregations;
+import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation;
+import org.elasticsearch.test.InternalMultiBucketAggregationTestCase;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class InternalCategorizationAggregationTests extends InternalMultiBucketAggregationTestCase<InternalCategorizationAggregation> {
+
+    @Override
+    protected SearchPlugin registerPlugin() {
+        return new MachineLearning(Settings.EMPTY, null);
+    }
+
+    @Override
+    protected List<NamedXContentRegistry.Entry> getNamedXContents() {
+        return CollectionUtils.appendToCopy(
+            super.getNamedXContents(),
+            new NamedXContentRegistry.Entry(
+                Aggregation.class,
+                new ParseField(CategorizeTextAggregationBuilder.NAME),
+                (p, c) -> ParsedCategorization.fromXContent(p, (String) c)
+            )
+        );
+    }
+
+    @Override
+    protected void assertReduced(InternalCategorizationAggregation reduced, List<InternalCategorizationAggregation> inputs) {
+        Map<Object, Long> reducedCounts = toCounts(reduced.getBuckets().stream());
+        Map<Object, Long> totalCounts = toCounts(inputs.stream().map(InternalCategorizationAggregation::getBuckets).flatMap(List::stream));
+
+        Map<Object, Long> expectedReducedCounts = new HashMap<>(totalCounts);
+        expectedReducedCounts.keySet().retainAll(reducedCounts.keySet());
+        assertEquals(expectedReducedCounts, reducedCounts);
+    }
+
+    @Override
+    protected Predicate<String> excludePathsFromXContentInsertion() {
+        return p -> p.contains("key");
+    }
+
+    static InternalCategorizationAggregation.BucketKey randomKey() {
+        int numVals = randomIntBetween(1, 50);
+        return new InternalCategorizationAggregation.BucketKey(
+            Stream.generate(() -> randomAlphaOfLength(10)).limit(numVals).map(BytesRef::new).toArray(BytesRef[]::new)
+        );
+    }
+
+    @Override
+    protected InternalCategorizationAggregation createTestInstance(
+        String name,
+        Map<String, Object> metadata,
+        InternalAggregations aggregations
+    ) {
+        List<InternalCategorizationAggregation.Bucket> buckets = new ArrayList<>();
+        final int numBuckets = randomNumberOfBuckets();
+        HashSet<InternalCategorizationAggregation.BucketKey> keys = new HashSet<>();
+        for (int i = 0; i < numBuckets; ++i) {
+            InternalCategorizationAggregation.BucketKey key = randomValueOtherThanMany(
+                l -> keys.add(l) == false,
+                InternalCategorizationAggregationTests::randomKey
+            );
+            int docCount = randomIntBetween(1, 100);
+            buckets.add(new InternalCategorizationAggregation.Bucket(key, docCount, aggregations));
+        }
+        Collections.sort(buckets);
+        return new InternalCategorizationAggregation(
+            name,
+            randomIntBetween(10, 100),
+            randomLongBetween(1, 10),
+            randomIntBetween(1, 500),
+            randomIntBetween(1, 10),
+            randomIntBetween(1, 100),
+            metadata,
+            buckets
+        );
+    }
+
+    @Override
+    protected Class<? extends ParsedMultiBucketAggregation> implementationClass() {
+        return ParsedCategorization.class;
+    }
+
+    private static Map<Object, Long> toCounts(Stream<InternalCategorizationAggregation.Bucket> buckets) {
+        return buckets.collect(
+            Collectors.toMap(
+                InternalCategorizationAggregation.Bucket::getKey,
+                InternalCategorizationAggregation.Bucket::getDocCount,
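+                // duplicate keys across input buckets merge by summing doc counts (merge fn below)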
Long::sum + ) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java new file mode 100644 index 0000000000000..2bef18993e019 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; + +public class LeafTreeNodeTests extends ESTestCase { + + private final CategorizationTokenTree factory = new CategorizationTokenTree(10, 10, 60); + + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() { + bytesRefHash.close(); + } + + public void testAddGroup() { + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); + TextCategorization group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); + + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); + assertThat(group.getCount(), equalTo(1L)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(1)); + long previousBytesUsed = leafTreeNode.ramBytesUsed(); + + group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"), 1, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy")); + assertThat(group.getCount(), equalTo(1L)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); + assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed)); + previousBytesUsed = leafTreeNode.ramBytesUsed(); + + group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "different"), 3, factory); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); + assertThat(group.getCount(), equalTo(4L)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); + assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed())); + } + + public void testMergeWith() { + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); + leafTreeNode.mergeWith(null); + assertThat(leafTreeNode, equalTo(new TreeNode.LeafTreeNode(0, 60))); + + expectThrows(UnsupportedOperationException.class, () -> leafTreeNode.mergeWith(new TreeNode.InnerTreeNode(1, 2, 3))); + + leafTreeNode.incCount(5); + leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 5, factory); + + TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 60); + 
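+        // Editor's note: counts are bumped on the surviving node while the texts go into
+        // toMerge, which appears to mirror how the production merge accounts for documents.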
+        leafTreeNode.incCount(1);
+        toMerge.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory);
+        leafTreeNode.incCount(1);
+        toMerge.addText(getTokens(bytesRefHash, "foo", "bart", "bat", "built"), 1, factory);
+        leafTreeNode.mergeWith(toMerge);
+
+        assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2));
+        assertThat(leafTreeNode.getCount(), equalTo(7L));
+        assertArrayEquals(
+            leafTreeNode.getAllChildrenTextCategorizations().get(0).getCategorization(),
+            getTokens(bytesRefHash, "foo", "bar", "baz", "*")
+        );
+        assertArrayEquals(
+            leafTreeNode.getAllChildrenTextCategorizations().get(1).getCategorization(),
+            getTokens(bytesRefHash, "foo", "bart", "bat", "built")
+        );
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java
new file mode 100644
index 0000000000000..b554f9cfc43e1
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.categorization;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation;
+import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+class ParsedCategorization extends ParsedMultiBucketAggregation<ParsedCategorization.ParsedBucket> {
+
+    @Override
+    public String getType() {
+        return CategorizeTextAggregationBuilder.NAME;
+    }
+
+    private static final ObjectParser<ParsedCategorization, Void> PARSER = new ObjectParser<>(
+        ParsedCategorization.class.getSimpleName(),
+        true,
+        ParsedCategorization::new
+    );
+    static {
+        declareMultiBucketAggregationFields(PARSER, ParsedBucket::fromXContent, ParsedBucket::fromXContent);
+    }
+
+    public static ParsedCategorization fromXContent(XContentParser parser, String name) throws IOException {
+        ParsedCategorization aggregation = PARSER.parse(parser, null);
+        aggregation.setName(name);
+        return aggregation;
+    }
+
+    @Override
+    public List<ParsedBucket> getBuckets() {
+        return buckets;
+    }
+
+    public static class ParsedBucket extends ParsedMultiBucketAggregation.ParsedBucket implements MultiBucketsAggregation.Bucket {
+
+        private InternalCategorizationAggregation.BucketKey key;
+
+        protected void setKeyAsString(String keyAsString) {
+            if (keyAsString == null) {
+                key = null;
+                return;
+            }
+            if (keyAsString.isEmpty()) {
+                key = new InternalCategorizationAggregation.BucketKey(new BytesRef[0]);
+                return;
+            }
+            String[] split = Strings.tokenizeToStringArray(keyAsString, " ");
+            key = new InternalCategorizationAggregation.BucketKey(
+                split == null
+                    ?
new BytesRef[] { new BytesRef(keyAsString) } + : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } + + @Override + public Object getKey() { + return key; + } + + @Override + public String getKeyAsString() { + return key.asString(); + } + + @Override + protected XContentBuilder keyToXContent(XContentBuilder builder) throws IOException { + return builder.field(CommonFields.KEY.getPreferredName(), getKey()); + } + + static InternalCategorizationAggregation.BucketKey parsedKey(final XContentParser parser) throws IOException { + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + String toSplit = parser.text(); + String[] split = Strings.tokenizeToStringArray(toSplit, " "); + return new InternalCategorizationAggregation.BucketKey( + split == null + ? new BytesRef[] { new BytesRef(toSplit) } + : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } else { + return new InternalCategorizationAggregation.BucketKey( + XContentParserUtils.parseList(parser, p -> new BytesRef(p.binaryValue())).toArray(BytesRef[]::new) + ); + } + } + + static ParsedBucket fromXContent(final XContentParser parser) throws IOException { + return ParsedMultiBucketAggregation.ParsedBucket.parseXContent( + parser, + false, + ParsedBucket::new, + (p, bucket) -> bucket.key = parsedKey(p) + ); + } + + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java new file mode 100644 index 0000000000000..59129f8801937 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.ml.aggs.categorization;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.MockPageCacheRecycler;
+import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
+
+public class TextCategorizationTests extends ESTestCase {
+
+    static BigArrays mockBigArrays() {
+        return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
+    }
+
+    private CategorizationBytesRefHash bytesRefHash;
+
+    @Before
+    public void createRefHash() {
+        bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays()));
+    }
+
+    @After
+    public void closeRefHash() throws IOException {
+        bytesRefHash.close();
+    }
+
+    public void testSimilarity() {
+        TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1);
+        TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens(bytesRefHash, "not", "matching", "anything", "nope"));
+        assertThat(sims.getSimilarity(), equalTo(0.0));
+        assertThat(sims.getWildCardCount(), equalTo(0));
+
+        sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"));
+        assertThat(sims.getSimilarity(), equalTo(1.0));
+        assertThat(sims.getWildCardCount(), equalTo(0));
+
+        sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "fooagain", "notbar", "biz"));
+        assertThat(sims.getSimilarity(), closeTo(0.5, 0.0001));
+        assertThat(sims.getWildCardCount(), equalTo(0));
+    }
+
+    public void testAddTokens() {
+        TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1);
+        lg.addTokens(getTokens(bytesRefHash, "foo", "bar", "baz", "bozo"), 2);
+        assertThat(lg.getCount(), equalTo(3L));
+        assertArrayEquals(lg.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*"));
+    }
+
+    static int[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) {
+        BytesRef[] refs = new BytesRef[tokens.length];
+        int i = 0;
+        for (String token : tokens) {
+            refs[i++] = new BytesRef(token);
+        }
+        return bytesRefHash.getIds(refs);
+    }
+
+}
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml
new file mode 100644
index 0000000000000..c2d5e0dbf09f1
--- /dev/null
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml
@@ -0,0 +1,153 @@
+setup:
+  - skip:
+      features: headers
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      indices.create:
+        index: to_categorize
+        body:
+          mappings:
+            properties:
+              product:
+                type: keyword
+              text:
+                type: text
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+        Content-Type: application/json
+      bulk:
+        index: to_categorize
+        refresh: true
+        body: |
+          {"index": {}}
+          {"product": "server", "text": "Node 2 stopping"}
+          {"index": {}}
+          {"product": "server", "text": "Node 2 starting"}
+          {"index": {}}
+          {"product": "server", "text": "Node 4 stopping"}
+          {"index": {}}
+          {"product": "server", "text": "Node 5 stopping"}
+          {"index": {}}
+          {"product": "user", "text": "User Foo logging on"}
+          {"index": {}}
+          {"product": "user", "text": "User Foo logging on"}
+          {"index": {}}
+          {"product": "user", "text": "User Foo logging off"}
+
+---
+"Test categorization agg simple":
+
+  - do:
+      search:
+        index: to_categorize
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "categories": {
+                "categorize_text": {
+                  "field": "text"
+                }
+              }
+            }
+          }
+  - length: { aggregations.categories.buckets: 4 }
+  - match: { aggregations.categories.buckets.0.doc_count: 3 }
+  - match: { aggregations.categories.buckets.0.key: "Node stopping" }
+
+  - do:
+      search:
+        index: to_categorize
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "categories": {
+                "categorize_text": {
+                  "field": "text",
+                  "size": 10,
+                  "max_unique_tokens": 2,
+                  "max_matched_tokens": 1,
+                  "similarity_threshold": 11
+                }
+              }
+            }
+          }
+
+  - length: { aggregations.categories.buckets: 2 }
+  - match: { aggregations.categories.buckets.0.doc_count: 4 }
+  - match: { aggregations.categories.buckets.0.key: "Node *" }
+  - match: { aggregations.categories.buckets.1.doc_count: 3 }
+  - match: { aggregations.categories.buckets.1.key: "User Foo logging *" }
+---
+"Test categorization aggregation with poor settings":
+
+  - do:
+      catch: /\[max_unique_tokens\] must be greater than 0 and less than \[100\]/
+      search:
+        index: to_categorize
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "categories": {
+                "categorize_text": {
+                  "field": "text",
+                  "max_unique_tokens": -2
+                }
+              }
+            }
+          }
+  - do:
+      catch: /\[max_matched_tokens\] must be greater than 0 and less than \[100\]/
+      search:
+        index: to_categorize
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "categories": {
+                "categorize_text": {
+                  "field": "text",
+                  "max_matched_tokens": -2
+                }
+              }
+            }
+          }
+  - do:
+      catch: /\[similarity_threshold\] must be in the range \[1, 100\]/
+      search:
+        index: to_categorize
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "categories": {
+                "categorize_text": {
+                  "field": "text",
+                  "similarity_threshold": 0
+                }
+              }
+            }
+          }
+
+  - do:
+      catch: /\[categorization_filters\] cannot be used with \[categorization_analyzer\]/
+      search:
+        index: to_categorize
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "categories": {
+                "categorize_text": {
+                  "field": "text",
+                  "categorization_filters": ["foo"],
+                  "categorization_analyzer": "english"
+                }
+              }
+            }
+          }
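Editor's note: for readers who want to exercise the new aggregation from Java rather than through the REST layer, the sketch below shows how the builder used in CategorizeTextAggregatorTests might be attached to a search request. The constructor and subAggregation call are taken directly from the tests above; the three setter names are assumptions inferred from the request parameters validated in categorization_agg.yml, so treat this as a sketch under those assumptions, not a definitive API reference.

    import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder;
    import org.elasticsearch.search.builder.SearchSourceBuilder;
    import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder;

    public class CategorizeTextRequestSketch {
        public static void main(String[] args) {
            // Group the semi-structured "message" field into categories and track the
            // latest timestamp seen in each category via a max sub-aggregation.
            CategorizeTextAggregationBuilder categories = new CategorizeTextAggregationBuilder("categories", "message")
                .setSimilarityThreshold(50)  // assumed setter; mirrors the similarity_threshold parameter
                .setMaxUniqueTokens(50)      // assumed setter; mirrors max_unique_tokens (default shown)
                .setMaxMatchedTokens(5);     // assumed setter; mirrors max_matched_tokens (default shown)
            categories.subAggregation(new MaxAggregationBuilder("latest").field("time"));

            // size(0) skips hits and returns only the aggregation, as in the YAML tests above.
            SearchSourceBuilder source = new SearchSourceBuilder().size(0).aggregation(categories);
            System.out.println(source); // prints the JSON search body
        }
    }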