[ML] A text categorization aggregation that works like ML categorization #80867

Merged: 24 commits into master from categorize_text2, Apr 13, 2022
24 commits
c0f1237 [ML] A text categorization aggregation that works like ML categorization (droberts195, Nov 19, 2021)
e1e58d4 Adding dictionary weighting (droberts195, Nov 22, 2021)
c79c1b7 Merge branch 'master' into categorize_text2 (droberts195, Nov 22, 2021)
874edd4 Improve memory accounting (droberts195, Nov 23, 2021)
f79a52b Merge branch 'master' into categorize_text2 (droberts195, Nov 23, 2021)
5e75308 Renaming class since it only works on token lists in Java (droberts195, Nov 25, 2021)
6ed5dcb Merge branch 'master' into categorize_text2 (droberts195, Nov 25, 2021)
885103a Merge branch 'master' into categorize_text2 (droberts195, Nov 29, 2021)
c10dd22 Merge branch 'master' into categorize_text2 (droberts195, Dec 6, 2021)
081e04d Merge branch 'master' into categorize_text2 (droberts195, Jan 19, 2022)
d5dae2b Fix compilation (droberts195, Jan 19, 2022)
65c9d38 Merge branch 'master' into categorize_text2 (droberts195, Apr 4, 2022)
66f325c Bring up-to-date and optimize performance (droberts195, Apr 4, 2022)
fce2f1d Remove redundant parameters (droberts195, Apr 5, 2022)
8e3967b Don't repeatedly reinitialize PoS dictionary (droberts195, Apr 5, 2022)
d985601 Bug fixes and tests for memory tracking (droberts195, Apr 5, 2022)
930f9b8 Avoid shallowSizeOf() as security manager can make it stall for seconds (droberts195, Apr 5, 2022)
aca43fa Add a multi-node test to prove that post-serialization merges work (droberts195, Apr 6, 2022)
67d8359 Part-of-speech dictionary lookup speed improvements (droberts195, Apr 6, 2022)
d7887b0 Update docs/changelog/80867.yaml (droberts195, Apr 11, 2022)
a9e9837 Merge branch 'master' into categorize_text2 (droberts195, Apr 11, 2022)
5590a47 Address a few review comments (droberts195, Apr 11, 2022)
c267019 Merge branch 'master' into categorize_text2 (droberts195, Apr 12, 2022)
f9b18b8 Address another review comment (droberts195, Apr 12, 2022)
5 changes: 5 additions & 0 deletions docs/changelog/80867.yaml
@@ -0,0 +1,5 @@
pr: 80867
summary: A text categorization aggregation that works like ML categorization
area: Machine Learning
type: enhancement
issues: []
107 changes: 107 additions & 0 deletions CategorizeTextDistributedIT.java
@@ -0,0 +1,107 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.action.admin.indices.create.CreateIndexRequest;
import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
import org.elasticsearch.action.admin.indices.stats.ShardStats;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder;
import org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;

public class CategorizeTextDistributedIT extends BaseMlIntegTestCase {

    /**
     * When categorizing text in a multi-node cluster the categorize_text2 aggregation has
     * a harder job than in a single-node cluster. The categories must be serialized between
     * nodes and then merged appropriately on the receiving node. This test ensures that
     * this serialization and subsequent merging work in the same way that merging would work
     * on a single node.
     */
    public void testDistributedCategorizeText() {
        internalCluster().ensureAtLeastNumDataNodes(3);
        ensureStableCluster();

        // System indices may affect how this index's shards are distributed across nodes,
        // but it has so many shards that every node should end up with at least one
        String indexName = "data";
        CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(
            Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, "9").put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, "0")
        );
        client().admin().indices().create(createIndexRequest).actionGet();

        // Spread 10000 documents in 4 categories across the shards
        for (int i = 0; i < 10; ++i) {
            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
            for (int j = 0; j < 250; ++j) {
                IndexRequestBuilder indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol destroy"));
                bulkRequestBuilder.add(indexRequestBuilder);
                indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol init"));
                bulkRequestBuilder.add(indexRequestBuilder);
                indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol start"));
                bulkRequestBuilder.add(indexRequestBuilder);
                indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol stop"));
                bulkRequestBuilder.add(indexRequestBuilder);
            }
            bulkRequestBuilder.execute().actionGet();
        }
        client().admin().indices().prepareRefresh(indexName).execute().actionGet();

        // Confirm the assumption that all 3 nodes have at least one shard of the index
        IndicesStatsResponse indicesStatsResponse = client().admin().indices().prepareStats(indexName).execute().actionGet();
        Set<String> nodesWithShards = Arrays.stream(indicesStatsResponse.getShards())
            .map(ShardStats::getShardRouting)
            .map(ShardRouting::currentNodeId)
            .collect(Collectors.toSet());
        assertThat(nodesWithShards, hasSize(internalCluster().size()));

        SearchResponse searchResponse = client().prepareSearch(indexName)
            .addAggregation(new CategorizeTextAggregationBuilder("categories", "message"))
            .setSize(0)
            .execute()
            .actionGet();

        InternalCategorizationAggregation aggregation = searchResponse.getAggregations().get("categories");
        assertThat(aggregation, notNullValue());

        // We should have created 4 categories, one for each of the distinct messages we indexed, all with counts of 2500 (= 10000/4)
        List<InternalCategorizationAggregation.Bucket> buckets = aggregation.getBuckets();
        assertThat(buckets, notNullValue());
        assertThat(buckets, hasSize(4));
        Set<String> expectedLastTokens = new HashSet<>(List.of("destroy", "init", "start", "stop"));
        for (InternalCategorizationAggregation.Bucket bucket : buckets) {
            assertThat(bucket.getDocCount(), is(2500L));
            String[] tokens = bucket.getKeyAsString().split(" ");
            String lastToken = tokens[tokens.length - 1];
            assertThat(lastToken + " not found in " + expectedLastTokens, expectedLastTokens.remove(lastToken), is(true));
        }
        assertThat("Some expected last tokens not found " + expectedLastTokens, expectedLastTokens, empty());
    }
}
MachineLearning.java
@@ -1417,7 +1417,16 @@ public List<AggregationSpec> getAggregations() {
             CategorizeTextAggregationBuilder::new,
             CategorizeTextAggregationBuilder.PARSER
         ).addResultReader(InternalCategorizationAggregation::new)
-            .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME))
+            .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)),
+        // TODO: in the long term only keep one or other of these categorization aggregations
+        new AggregationSpec(
+            org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.NAME,
+            org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder::new,
+            org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.PARSER
+        ).addResultReader(org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation::new)
+            .setAggregatorRegistrar(
+                s -> s.registerUsage(org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder.NAME)
+            )
     );
 }

69 changes: 69 additions & 0 deletions CategorizationBytesRefHash.java
@@ -0,0 +1,69 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.aggs.categorization2;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.search.aggregations.AggregationExecutionException;

class CategorizationBytesRefHash implements Releasable {

    private final BytesRefHash bytesRefHash;

    CategorizationBytesRefHash(BytesRefHash bytesRefHash) {
        this.bytesRefHash = bytesRefHash;
    }

    int[] getIds(BytesRef[] tokens) {
        int[] ids = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            ids[i] = put(tokens[i]);
        }
        return ids;
    }

    BytesRef[] getDeeps(int[] ids) {
        BytesRef[] tokens = new BytesRef[ids.length];
        for (int i = 0; i < tokens.length; i++) {
            tokens[i] = getDeep(ids[i]);
        }
        return tokens;
    }

    BytesRef getDeep(long id) {
        BytesRef shallow = bytesRefHash.get(id, new BytesRef());
        return BytesRef.deepCopyOf(shallow);
    }

    int put(BytesRef bytesRef) {
        long hash = bytesRefHash.add(bytesRef);
        if (hash < 0) {
            // BytesRefHash returns -1 - hash if the entry already existed, but we just want to return the hash
            return (int) (-1L - hash);
        }
        if (hash > Integer.MAX_VALUE) {
            throw new AggregationExecutionException(
                LoggerMessageFormat.format(
                    "more than [{}] unique terms encountered. "
                        + "Consider restricting the documents queried or adding [{}] in the {} configuration",
                    Integer.MAX_VALUE,
                    CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(),
                    CategorizeTextAggregationBuilder.NAME
                )
            );
        }
        return (int) hash;
    }

    @Override
    public void close() {
        bytesRefHash.close();
    }
}
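
To make the interning contract concrete, here is a minimal usage sketch (hypothetical: in the real aggregator the underlying BytesRefHash is created with the search context's BigArrays so its allocations are circuit-breaker tracked; exception handling is elided):

// Hypothetical stand-alone usage of CategorizationBytesRefHash. Assumed imports:
// org.apache.lucene.util.BytesRef, org.elasticsearch.common.util.BigArrays,
// org.elasticsearch.common.util.BytesRefHash
try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(2048, BigArrays.NON_RECYCLING_INSTANCE))) {
    int id1 = hash.put(new BytesRef("destroy"));
    int id2 = hash.put(new BytesRef("destroy"));
    assert id1 == id2; // putting an existing token returns its original ID

    // getDeep() deep-copies the bytes out of the hash's backing storage, so the
    // returned BytesRef stays valid independently of the hash's paged arrays
    BytesRef token = hash.getDeep(id1);
    assert token.utf8ToString().equals("destroy");
}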
158 changes: 158 additions & 0 deletions CategorizationPartOfSpeechDictionary.java
@@ -0,0 +1,158 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.aggs.categorization2;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Port of the C++ class <a href="https://github.com/elastic/ml-cpp/blob/main/include/core/CWordDictionary.h">
 * <code>CWordDictionary</code></a>.
 */
public class CategorizationPartOfSpeechDictionary {

    static final String DICTIONARY_FILE_PATH = "/org/elasticsearch/xpack/ml/aggs/categorization2/ml-en.dict";

    static final String PART_OF_SPEECH_SEPARATOR = "@";

    public enum PartOfSpeech {
        NOT_IN_DICTIONARY('\0'),
        UNKNOWN('?'),
        NOUN('N'),
        PLURAL('p'),
        VERB('V'),
        ADJECTIVE('A'),
        ADVERB('v'),
        CONJUNCTION('C'),
        PREPOSITION('P'),
        INTERJECTION('!'),
        PRONOUN('r'),
        DEFINITE_ARTICLE('D'),
        INDEFINITE_ARTICLE('I');

        private final char code;

        PartOfSpeech(char code) {
            this.code = code;
        }

        char getCode() {
            return code;
        }

        private static final Map<Character, PartOfSpeech> CODE_MAPPING =
            // 'h', 'o', 't', and 'i' are codes for specialist types of noun and verb that we don't distinguish
            Stream.concat(
                Map.of('h', NOUN, 'o', NOUN, 't', VERB, 'i', VERB).entrySet().stream(),
                Stream.of(PartOfSpeech.values()).collect(Collectors.toMap(PartOfSpeech::getCode, Function.identity())).entrySet().stream()
            )
                .collect(
                    Collectors.toUnmodifiableMap(Map.Entry<Character, PartOfSpeech>::getKey, Map.Entry<Character, PartOfSpeech>::getValue)
                );

        static PartOfSpeech fromCode(char partOfSpeechCode) {
            PartOfSpeech pos = CODE_MAPPING.get(partOfSpeechCode);
            if (pos == null) {
                throw new IllegalArgumentException("Unknown part-of-speech code [" + partOfSpeechCode + "]");
            }
            return pos;
        }
    }

    /**
     * Lazily loaded singleton instance to avoid loading the dictionary repeatedly.
     */
    private static CategorizationPartOfSpeechDictionary instance;
    private static final Object INIT_LOCK = new Object();

    /**
     * Keys are lower case.
     */
    private final Map<String, PartOfSpeech> partOfSpeechDictionary = new HashMap<>();
    private final int maxDictionaryWordLength;

    CategorizationPartOfSpeechDictionary(InputStream is) throws IOException {
        int maxLength = 0;
        BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8));
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] split = line.split(PART_OF_SPEECH_SEPARATOR);
            if (split.length != 2) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: expected one [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            if (split[0].isEmpty()) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: nothing preceding [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            if (split[1].isEmpty()) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: nothing following [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            String lowerCaseWord = split[0].toLowerCase(Locale.ROOT);
            partOfSpeechDictionary.put(lowerCaseWord, PartOfSpeech.fromCode(split[1].charAt(0)));
            maxLength = Math.max(maxLength, lowerCaseWord.length());
        }
        maxDictionaryWordLength = maxLength;
    }

    // TODO: now that we have this in Java, perform this operation in Java for anomaly detection categorization instead of in C++.
    // (It could maybe be incorporated into the categorization analyzer and then shared between aggregation and anomaly detection.)
    /**
     * Find the part of speech (noun, verb, adjective, etc.) for a supplied word.
     * @return Which part of speech does the supplied word represent? {@link PartOfSpeech#NOT_IN_DICTIONARY} is returned
     *         for words that aren't in the dictionary at all.
     */
    public PartOfSpeech getPartOfSpeech(String word) {
        if (word.length() > maxDictionaryWordLength) {
            return PartOfSpeech.NOT_IN_DICTIONARY;
        }
        // This is quite slow as it creates a new string for every lookup. However, experiments show
        // that trying to do case-insensitive comparisons instead of creating a lower case string is
        // even slower.
        return partOfSpeechDictionary.getOrDefault(word.toLowerCase(Locale.ROOT), PartOfSpeech.NOT_IN_DICTIONARY);
    }

    /**
     * @return Is the supplied word in the dictionary?
     */
    public boolean isInDictionary(String word) {
        return getPartOfSpeech(word) != PartOfSpeech.NOT_IN_DICTIONARY;
    }

    public static CategorizationPartOfSpeechDictionary getInstance() throws IOException {
        if (instance != null) {
            return instance;
        }
        synchronized (INIT_LOCK) {
            if (instance == null) {
                try (InputStream is = CategorizationPartOfSpeechDictionary.class.getResourceAsStream(DICTIONARY_FILE_PATH)) {
                    instance = new CategorizationPartOfSpeechDictionary(is);
                }
            }
            return instance;
        }
    }
}
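
For illustration, a hedged sketch of how the dictionary is consumed. The word@code lines here are made-up examples rather than entries from the bundled ml-en.dict, and since the constructor is package-private this would live in a test in the same package (IOException handling elided):

// Hypothetical dictionary content; codes follow the PartOfSpeech enum above
// ('N' noun, 'V' verb, 'v' adverb)
String lines = "house@N\nrun@V\nquickly@v\n";
CategorizationPartOfSpeechDictionary dictionary = new CategorizationPartOfSpeechDictionary(
    new java.io.ByteArrayInputStream(lines.getBytes(java.nio.charset.StandardCharsets.UTF_8))
);
// Lookups lower-case the word first, so matching is case-insensitive
assert dictionary.getPartOfSpeech("HOUSE") == CategorizationPartOfSpeechDictionary.PartOfSpeech.NOUN;
assert dictionary.getPartOfSpeech("quickly") == CategorizationPartOfSpeechDictionary.PartOfSpeech.ADVERB;
assert dictionary.isInDictionary("flibble") == false; // not in dictionary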
Comment on lines +145 to +158
Member

++

I am not sure whether we need to add bytes to the circuit breaker for this. I would say that if it is near a MB, we may want to.

Basically, getInstance could take the circuit breaker and add bytes when the dictionary is first loaded, skipping it on later calls (since the bytes would already have been added). Those bytes would then stay for the lifetime of the node.
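
A minimal sketch of what that suggestion could look like (hypothetical: the PR does not implement it, ESTIMATED_SIZE_IN_BYTES stands in for a real size estimate, and the reply below explains why this approach was not taken):

    // Hypothetical variant of getInstance(): charge the dictionary's footprint to a
    // circuit breaker the first time it is loaded; the bytes are never released because
    // the singleton lives for the lifetime of the node.
    public static CategorizationPartOfSpeechDictionary getInstance(CircuitBreaker breaker) throws IOException {
        if (instance != null) {
            return instance;
        }
        synchronized (INIT_LOCK) {
            if (instance == null) {
                try (InputStream is = CategorizationPartOfSpeechDictionary.class.getResourceAsStream(DICTIONARY_FILE_PATH)) {
                    instance = new CategorizationPartOfSpeechDictionary(is);
                }
                // ESTIMATED_SIZE_IN_BYTES is a hypothetical constant approximating the
                // in-memory size of the loaded dictionary
                breaker.addEstimateBytesAndMaybeBreak(ESTIMATED_SIZE_IN_BYTES, "categorization_pos_dictionary");
            }
            return instance;
        }
    }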

Contributor Author

I think it's best not to add it to the same circuit breaker used by the rest of the aggregation.

Although it's large, it's effectively static data, so it would make most sense to include it in the "Accounting requests circuit breaker" rather than the "Request circuit breaker". But if indices.breaker.total.use_real_memory is set to true, which it is by default, then the "memory usage of things held in memory that are not released when a request is completed" is taken into account automatically.

I guess we could try to explicitly add it to the "Accounting requests circuit breaker" for the case where real-memory circuit breaking is disabled. But this would be messy within the code, because the code is written on the basis that what the docs refer to as "memory usage of things held in memory that are not released when a request is completed" is actually field data related to Lucene indices.

The docs also say this about the total memory all circuit breakers can use: "Defaults to 70% of JVM heap if indices.breaker.total.use_real_memory is false. If indices.breaker.total.use_real_memory is true, defaults to 95% of the JVM heap." That implies that if you don't use the real-memory circuit breaker to measure fixed overheads, you have to allow some space for unmeasured fixed overheads. So I think this dictionary can be treated as one of those fixed overheads that either gets captured by the real-memory circuit breaker or is implicitly covered by the reserved percentage of memory.
