[ML] A text categorization aggregation that works like ML categorization #80867
Merged

Changes from all commits (24 commits, all by droberts195):
c0f1237 [ML] A text categorization aggregation that works like ML categorization
e1e58d4 Adding dictionary weighting
c79c1b7 Merge branch 'master' into categorize_text2
874edd4 Improve memory accounting
f79a52b Merge branch 'master' into categorize_text2
5e75308 Renaming class since it only works on token lists in Java
6ed5dcb Merge branch 'master' into categorize_text2
885103a Merge branch 'master' into categorize_text2
c10dd22 Merge branch 'master' into categorize_text2
081e04d Merge branch 'master' into categorize_text2
d5dae2b Fix compilation
65c9d38 Merge branch 'master' into categorize_text2
66f325c Bring up-to-date and optimize performance
fce2f1d Remove redundant parameters
8e3967b Don't repeatedly reinitialize PoS dictionary
d985601 Bug fixes and tests for memory tracking
930f9b8 Avoid shallowSizeOf() as security manager can make it stall for seconds
aca43fa Add a multi-node test to prove that post-serialization merges work
67d8359 Part-of-speech dictionary lookup speed improvements
d7887b0 Update docs/changelog/80867.yaml
a9e9837 Merge branch 'master' into categorize_text2
5590a47 Address a few review comments
c267019 Merge branch 'master' into categorize_text2
f9b18b8 Address another review comment
docs/changelog/80867.yaml (5 additions):

pr: 80867
summary: A text categorization aggregation that works like ML categorization
area: Machine Learning
type: enhancement
issues: []
...lClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizeTextDistributedIT.java (107 additions):
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.action.admin.indices.create.CreateIndexRequest;
import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
import org.elasticsearch.action.admin.indices.stats.ShardStats;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xpack.ml.aggs.categorization2.CategorizeTextAggregationBuilder;
import org.elasticsearch.xpack.ml.aggs.categorization2.InternalCategorizationAggregation;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;

public class CategorizeTextDistributedIT extends BaseMlIntegTestCase {

    /**
     * When categorizing text in a multi-node cluster the categorize_text2 aggregation has
     * a harder job than in a single node cluster. The categories must be serialized between
     * nodes and then merged appropriately on the receiving node. This test ensures that
     * this serialization and subsequent merging works in the same way that merging would work
     * on a single node.
     */
    public void testDistributedCategorizeText() {
        internalCluster().ensureAtLeastNumDataNodes(3);
        ensureStableCluster();

        // System indices may affect the distribution of shards of this index,
        // but it has so many that it should have shards on all the nodes
        String indexName = "data";
        CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(
            Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, "9").put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, "0")
        );
        client().admin().indices().create(createIndexRequest).actionGet();

        // Spread 10000 documents in 4 categories across the shards
        for (int i = 0; i < 10; ++i) {
            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
            for (int j = 0; j < 250; ++j) {
                IndexRequestBuilder indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol destroy"));
                bulkRequestBuilder.add(indexRequestBuilder);
                indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol init"));
                bulkRequestBuilder.add(indexRequestBuilder);
                indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol start"));
                bulkRequestBuilder.add(indexRequestBuilder);
                indexRequestBuilder = client().prepareIndex(indexName)
                    .setSource(Map.of("message", "Aug 29, 2019 2:02:51 PM org.apache.coyote.http11.Http11BaseProtocol stop"));
                bulkRequestBuilder.add(indexRequestBuilder);
            }
            bulkRequestBuilder.execute().actionGet();
        }
        client().admin().indices().prepareRefresh(indexName).execute().actionGet();

        // Confirm the theory that all 3 nodes will have a shard on
        IndicesStatsResponse indicesStatsResponse = client().admin().indices().prepareStats(indexName).execute().actionGet();
        Set<String> nodesWithShards = Arrays.stream(indicesStatsResponse.getShards())
            .map(ShardStats::getShardRouting)
            .map(ShardRouting::currentNodeId)
            .collect(Collectors.toSet());
        assertThat(nodesWithShards, hasSize(internalCluster().size()));

        SearchResponse searchResponse = client().prepareSearch(indexName)
            .addAggregation(new CategorizeTextAggregationBuilder("categories", "message"))
            .setSize(0)
            .execute()
            .actionGet();

        InternalCategorizationAggregation aggregation = searchResponse.getAggregations().get("categories");
        assertThat(aggregation, notNullValue());

        // We should have created 4 categories, one for each of the distinct messages we indexed, all with counts of 2500 (= 10000/4)
        List<InternalCategorizationAggregation.Bucket> buckets = aggregation.getBuckets();
        assertThat(buckets, notNullValue());
        assertThat(buckets, hasSize(4));
        Set<String> expectedLastTokens = new HashSet<>(List.of("destroy", "init", "start", "stop"));
        for (InternalCategorizationAggregation.Bucket bucket : buckets) {
            assertThat(bucket.getDocCount(), is(2500L));
            String[] tokens = bucket.getKeyAsString().split(" ");
            String lastToken = tokens[tokens.length - 1];
            assertThat(lastToken + " not found in " + expectedLastTokens, expectedLastTokens.remove(lastToken), is(true));
        }
        assertThat("Some expected last tokens not found " + expectedLastTokens, expectedLastTokens, empty());
    }
}
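The behaviour this test guards is, at its core, a reduce of per-shard category buckets into cluster-wide categories after they have been serialized to the coordinating node. The following is only a conceptual sketch of that merge (group buckets by category key and sum document counts) under the simplifying assumption that equal keys mean equal categories; the real InternalCategorizationAggregation reduce also merges similar-but-not-identical categories and tracks memory, and the ShardBucket/mergeShardBuckets names here are hypothetical, not part of the PR.

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class CategoryMergeSketch {

    // Hypothetical stand-in for one serialized per-shard category bucket.
    record ShardBucket(String key, long docCount) {}

    // Conceptual reduce: buckets arriving from different shards/nodes that share a
    // category key are combined by summing their document counts.
    static Map<String, Long> mergeShardBuckets(List<ShardBucket> shardBuckets) {
        Map<String, Long> merged = new LinkedHashMap<>();
        for (ShardBucket bucket : shardBuckets) {
            merged.merge(bucket.key(), bucket.docCount(), Long::sum);
        }
        return merged;
    }

    public static void main(String[] args) {
        // Two shards report the same "destroy" category (illustrative keys); after the
        // reduce there is a single category with the combined count, which is what the
        // 4 x 2500 assertion in the test above relies on.
        Map<String, Long> merged = mergeShardBuckets(
            List.of(
                new ShardBucket("org.apache.coyote.http11.Http11BaseProtocol destroy", 1250L),
                new ShardBucket("org.apache.coyote.http11.Http11BaseProtocol destroy", 1250L),
                new ShardBucket("org.apache.coyote.http11.Http11BaseProtocol init", 2500L)
            )
        );
        merged.forEach((key, count) -> System.out.println(count + " <- " + key));
    }
}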
...main/java/org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationBytesRefHash.java (69 additions):
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.aggs.categorization2;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.search.aggregations.AggregationExecutionException;

class CategorizationBytesRefHash implements Releasable {

    private final BytesRefHash bytesRefHash;

    CategorizationBytesRefHash(BytesRefHash bytesRefHash) {
        this.bytesRefHash = bytesRefHash;
    }

    int[] getIds(BytesRef[] tokens) {
        int[] ids = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            ids[i] = put(tokens[i]);
        }
        return ids;
    }

    BytesRef[] getDeeps(int[] ids) {
        BytesRef[] tokens = new BytesRef[ids.length];
        for (int i = 0; i < tokens.length; i++) {
            tokens[i] = getDeep(ids[i]);
        }
        return tokens;
    }

    BytesRef getDeep(long id) {
        BytesRef shallow = bytesRefHash.get(id, new BytesRef());
        return BytesRef.deepCopyOf(shallow);
    }

    int put(BytesRef bytesRef) {
        long hash = bytesRefHash.add(bytesRef);
        if (hash < 0) {
            // BytesRefHash returns -1 - hash if the entry already existed, but we just want to return the hash
            return (int) (-1L - hash);
        }
        if (hash > Integer.MAX_VALUE) {
            throw new AggregationExecutionException(
                LoggerMessageFormat.format(
                    "more than [{}] unique terms encountered. "
                        + "Consider restricting the documents queried or adding [{}] in the {} configuration",
                    Integer.MAX_VALUE,
                    CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(),
                    CategorizeTextAggregationBuilder.NAME
                )
            );
        }
        return (int) hash;
    }

    @Override
    public void close() {
        bytesRefHash.close();
    }
}
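As a reading aid, the sketch below shows how this wrapper round-trips tokens: org.elasticsearch.common.util.BytesRefHash hands out a stable id per distinct BytesRef, put() narrows that id to an int, and getDeep()/getDeeps() return deep copies of the stored bytes. It is a hypothetical usage example rather than code from the PR; it assumes the standard BytesRefHash(long, BigArrays) constructor and has to sit in the same package because the methods are package-private.

package org.elasticsearch.xpack.ml.aggs.categorization2;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BytesRefHash;

final class CategorizationBytesRefHashUsageSketch {
    public static void main(String[] args) {
        // Non-recycling BigArrays keeps the sketch self-contained; production code would
        // typically use circuit-breaker-aware BigArrays from the aggregation context.
        try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, BigArrays.NON_RECYCLING_INSTANCE))) {
            BytesRef[] tokens = { new BytesRef("Http11BaseProtocol"), new BytesRef("destroy"), new BytesRef("destroy") };
            int[] ids = hash.getIds(tokens);              // duplicate tokens map to the same id
            BytesRef[] roundTripped = hash.getDeeps(ids); // deep copies of the stored bytes
            for (int i = 0; i < ids.length; i++) {
                System.out.println(ids[i] + " -> " + roundTripped[i].utf8ToString());
            }
        }
    }
}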
...org/elasticsearch/xpack/ml/aggs/categorization2/CategorizationPartOfSpeechDictionary.java (158 additions):
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.aggs.categorization2;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Port of the C++ class <a href="https://github.com/elastic/ml-cpp/blob/main/include/core/CWordDictionary.h">
 * <code>CWordDictionary</code></a>.
 */
public class CategorizationPartOfSpeechDictionary {

    static final String DICTIONARY_FILE_PATH = "/org/elasticsearch/xpack/ml/aggs/categorization2/ml-en.dict";

    static final String PART_OF_SPEECH_SEPARATOR = "@";

    public enum PartOfSpeech {
        NOT_IN_DICTIONARY('\0'),
        UNKNOWN('?'),
        NOUN('N'),
        PLURAL('p'),
        VERB('V'),
        ADJECTIVE('A'),
        ADVERB('v'),
        CONJUNCTION('C'),
        PREPOSITION('P'),
        INTERJECTION('!'),
        PRONOUN('r'),
        DEFINITE_ARTICLE('D'),
        INDEFINITE_ARTICLE('I');

        private final char code;

        PartOfSpeech(char code) {
            this.code = code;
        }

        char getCode() {
            return code;
        }

        private static final Map<Character, PartOfSpeech> CODE_MAPPING =
            // 'h', 'o', 't', and 'i' are codes for specialist types of noun and verb that we don't distinguish
            Stream.concat(
                Map.of('h', NOUN, 'o', NOUN, 't', VERB, 'i', VERB).entrySet().stream(),
                Stream.of(PartOfSpeech.values()).collect(Collectors.toMap(PartOfSpeech::getCode, Function.identity())).entrySet().stream()
            )
                .collect(
                    Collectors.toUnmodifiableMap(Map.Entry<Character, PartOfSpeech>::getKey, Map.Entry<Character, PartOfSpeech>::getValue)
                );

        static PartOfSpeech fromCode(char partOfSpeechCode) {
            PartOfSpeech pos = CODE_MAPPING.get(partOfSpeechCode);
            if (pos == null) {
                throw new IllegalArgumentException("Unknown part-of-speech code [" + partOfSpeechCode + "]");
            }
            return pos;
        }
    }

    /**
     * Lazy loaded singleton instance to avoid loading the dictionary repeatedly.
     */
    private static CategorizationPartOfSpeechDictionary instance;
    private static final Object INIT_LOCK = new Object();

    /**
     * Keys are lower case.
     */
    private final Map<String, PartOfSpeech> partOfSpeechDictionary = new HashMap<>();
    private final int maxDictionaryWordLength;

    CategorizationPartOfSpeechDictionary(InputStream is) throws IOException {

        int maxLength = 0;
        BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8));
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] split = line.split(PART_OF_SPEECH_SEPARATOR);
            if (split.length != 2) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: expected one [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            if (split[0].isEmpty()) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: nothing preceding [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            if (split[1].isEmpty()) {
                throw new IllegalArgumentException(
                    "Unexpected format in line [" + line + "]: nothing following [" + PART_OF_SPEECH_SEPARATOR + "] separator"
                );
            }
            String lowerCaseWord = split[0].toLowerCase(Locale.ROOT);
            partOfSpeechDictionary.put(lowerCaseWord, PartOfSpeech.fromCode(split[1].charAt(0)));
            maxLength = Math.max(maxLength, lowerCaseWord.length());
        }
        maxDictionaryWordLength = maxLength;
    }

    // TODO: now we have this in Java, perform this operation in Java for anomaly detection categorization instead of in C++.
    // (It could maybe be incorporated into the categorization analyzer and then shared between aggregation and anomaly detection.)
    /**
     * Find the part of speech (noun, verb, adjective, etc.) for a supplied word.
     * @return Which part of speech does the supplied word represent? {@link PartOfSpeech#NOT_IN_DICTIONARY} is returned
     *         for words that aren't in the dictionary at all.
     */
    public PartOfSpeech getPartOfSpeech(String word) {
        if (word.length() > maxDictionaryWordLength) {
            return PartOfSpeech.NOT_IN_DICTIONARY;
        }
        // This is quite slow as it creates a new string for every lookup. However, experiments show
        // that trying to do case-insensitive comparisons instead of creating a lower case string is
        // even slower.
        return partOfSpeechDictionary.getOrDefault(word.toLowerCase(Locale.ROOT), PartOfSpeech.NOT_IN_DICTIONARY);
    }

    /**
     * @return Is the supplied word in the dictionary?
     */
    public boolean isInDictionary(String word) {
        return getPartOfSpeech(word) != PartOfSpeech.NOT_IN_DICTIONARY;
    }

    public static CategorizationPartOfSpeechDictionary getInstance() throws IOException {
        if (instance != null) {
            return instance;
        }
        synchronized (INIT_LOCK) {
            if (instance == null) {
                try (InputStream is = CategorizationPartOfSpeechDictionary.class.getResourceAsStream(DICTIONARY_FILE_PATH)) {
                    instance = new CategorizationPartOfSpeechDictionary(is);
                }
            }
            return instance;
        }
    }
}
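To make the word@code file format and the lookup behaviour concrete, here is a small illustrative sketch (not part of the PR) that feeds the package-private constructor an in-memory dictionary instead of the bundled ml-en.dict, whose contents are not shown in this diff. The three entries and the sketch class name are made up; the codes ('N', 'V', 'A') are the PartOfSpeech codes defined above.

package org.elasticsearch.xpack.ml.aggs.categorization2;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

final class PartOfSpeechDictionarySketch {
    public static void main(String[] args) throws IOException {
        // Each line is word@code, matching PART_OF_SPEECH_SEPARATOR and the PartOfSpeech codes.
        String tinyDictionary = "service@N\ndestroy@V\nfast@A\n";
        CategorizationPartOfSpeechDictionary dictionary = new CategorizationPartOfSpeechDictionary(
            new ByteArrayInputStream(tinyDictionary.getBytes(StandardCharsets.UTF_8))
        );

        System.out.println(dictionary.getPartOfSpeech("Destroy")); // VERB - lookups are case-insensitive
        System.out.println(dictionary.getPartOfSpeech("coyote"));  // NOT_IN_DICTIONARY
        System.out.println(dictionary.isInDictionary("service"));  // true

        // Production code goes through the shared singleton so the real dictionary is parsed only once per node:
        // CategorizationPartOfSpeechDictionary shared = CategorizationPartOfSpeechDictionary.getInstance();
    }
}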
Review comment:

++

I am not sure whether we need to add bytes to the circuit breaker for this or not. I would say if it is near a MB we may want to. Basically, getInstance could take the circuit breaker and add bytes if it loads the dictionary, ignoring it if not (since the bytes would have already been added). And those bytes just stay for the lifetime of the node.
Reply:

I think it's best not to add it to the same circuit breaker used by the rest of the aggregation.

Although it's large, it's effectively static data, so it would make most sense to include it with the "Accounting requests circuit breaker" rather than the "Request circuit breaker". But if indices.breaker.total.use_real_memory is set to true, which it is by default, then that "memory usage of things held in memory that are not released when a request is completed" will take it into account automatically.

I guess we could try to explicitly add it into the "Accounting requests circuit breaker" for the case where real memory circuit breaking is disabled. But this would be messy within the code, as the code is written on the basis that what the docs refer to as "memory usage of things held in memory that are not released when a request is completed" is actually field data related to Lucene indices.

The docs also say this about the total memory all circuit breakers can use: "Defaults to 70% of JVM heap if indices.breaker.total.use_real_memory is false. If indices.breaker.total.use_real_memory is true, defaults to 95% of the JVM heap." That implies that if you don't use the real memory circuit breaker to measure fixed overheads, then you have to allow some space for unmeasured fixed overheads. So I think this dictionary can be treated as one of those fixed overheads that either gets captured by the real memory circuit breaker or is covered by implicitly reserving a percentage of memory.
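For readers following the thread, the reviewer's suggestion above would look roughly like the hypothetical sketch below: getInstance takes a CircuitBreaker and charges an estimate exactly once, when the dictionary is first loaded. The byte estimate, class name, and error handling are invented for illustration, and per the reply this is not the approach the PR takes; the dictionary is instead treated as a fixed overhead covered by the real-memory circuit breaker or by the implicit heap headroom.

package org.elasticsearch.xpack.ml.aggs.categorization2;

import java.io.IOException;
import java.io.InputStream;

import org.elasticsearch.common.breaker.CircuitBreaker;

// Hypothetical variant of getInstance() that charges the dictionary's footprint to a
// circuit breaker exactly once, when the singleton is first loaded.
final class DictionaryLoadingSketch {

    private static final Object INIT_LOCK = new Object();
    private static volatile CategorizationPartOfSpeechDictionary instance;

    // Rough, made-up estimate of the parsed dictionary's heap footprint.
    private static final long ESTIMATED_BYTES = 5 * 1024 * 1024;

    static CategorizationPartOfSpeechDictionary getInstance(CircuitBreaker breaker) throws IOException {
        CategorizationPartOfSpeechDictionary loaded = instance;
        if (loaded != null) {
            return loaded; // already loaded and already accounted for
        }
        synchronized (INIT_LOCK) {
            if (instance == null) {
                // Charge the breaker first so loading fails cleanly if memory is tight.
                breaker.addEstimateBytesAndMaybeBreak(ESTIMATED_BYTES, "categorization_pos_dictionary");
                try (
                    InputStream is = CategorizationPartOfSpeechDictionary.class.getResourceAsStream(
                        CategorizationPartOfSpeechDictionary.DICTIONARY_FILE_PATH
                    )
                ) {
                    instance = new CategorizationPartOfSpeechDictionary(is);
                } catch (IOException e) {
                    breaker.addWithoutBreaking(-ESTIMATED_BYTES); // undo the charge on failure
                    throw e;
                }
            }
            return instance;
        }
    }
}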