Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Replace the implementation of the categorize_text aggregation #85872

Merged
merged 18 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.ml.aggs.categorization;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -46,6 +47,13 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder

public static final String NAME = "categorize_text";

// In 8.3 the algorithm used by this aggregation was completely changed.
// Prior to 8.3 the Drain algorithm was used. From 8.3 the same algorithm
// we use in our C++ categorization code was used. As a result of this
// the aggregation will not perform well in mixed version clusters where
// some nodes are pre-8.3 and others are newer, so we throw an error in
// this situation. The aggregation was experimental at the time this change
// was made, so this is acceptable.
public static final Version ALGORITHM_CHANGED_VERSION = Version.V_8_3_0;

static final ParseField FIELD_NAME = new ParseField("field");
Expand Down Expand Up @@ -116,15 +124,18 @@ public CategorizeTextAggregationBuilder setFieldName(String fieldName) {

public CategorizeTextAggregationBuilder(StreamInput in) throws IOException {
super(in);
this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in);
this.fieldName = in.readString();
// If the coordinating node is an older version then we might still receive messages from older
// nodes. In this case we can send back results for this node created using the new algorithm.
// They won't necessarily merge well with results from other nodes, but are better than nothing.
// Disallow this aggregation in mixed version clusters that cross the algorithm change boundary.
if (in.getVersion().before(ALGORITHM_CHANGED_VERSION)) {
in.readVInt(); // maxUniqueTokens
in.readVInt(); // maxMatchedTokens
throw new ElasticsearchException(
"["
+ NAME
+ "] aggregation cannot be used in a cluster where some nodes have version ["
+ ALGORITHM_CHANGED_VERSION
+ "] or higher and others have a version before this"
);
}
this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in);
this.fieldName = in.readString();
this.similarityThreshold = in.readVInt();
this.categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new);
}
Expand Down Expand Up @@ -269,6 +280,16 @@ protected CategorizeTextAggregationBuilder(

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
// Disallow this aggregation in mixed version clusters that cross the algorithm change boundary.
if (out.getVersion().before(ALGORITHM_CHANGED_VERSION)) {
Comment on lines +283 to +284
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we MAY get this for free with the versioned named writeable stuff. But, keeping this here is cool with me.

throw new ElasticsearchException(
"["
+ NAME
+ "] aggregation cannot be used in a cluster where some nodes have version ["
+ ALGORITHM_CHANGED_VERSION
+ "] or higher and others have a version before this"
);
}
bucketCountThresholds.writeTo(out);
out.writeString(fieldName);
out.writeVInt(similarityThreshold);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.aggs.categorization;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BytesRefHash;
Expand Down Expand Up @@ -106,30 +107,35 @@ public Bucket(SerializableTokenListCategory serializableCategory, long bucketOrd
}

public Bucket(StreamInput in) throws IOException {
// Disallow this aggregation in mixed version clusters that cross the algorithm change boundary.
if (in.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) {
// This shouldn't happen because a coordinating node from after the algorithm change
// won't have sent requests to nodes from before the algorithm change. But if we get
// here then just use a dummy empty category to avoid crashing the whole aggregation.
in.readArray(StreamInput::readBytesRef, BytesRef[]::new); // key
in.readVLong(); // docCount
serializableCategory = SerializableTokenListCategory.EMPTY;
} else {
serializableCategory = new SerializableTokenListCategory(in);
throw new ElasticsearchException(
"["
+ CategorizeTextAggregationBuilder.NAME
+ "] aggregation cannot be used in a cluster where some nodes have version ["
+ CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION
+ "] or higher and others have a version before this"
);
}
serializableCategory = new SerializableTokenListCategory(in);
key = new BucketKey(serializableCategory);
bucketOrd = -1;
aggregations = InternalAggregations.readFrom(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
// Disallow this aggregation in mixed version clusters that cross the algorithm change boundary.
if (out.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) {
// Send the best results we can if the coordinating node is on an old version.
out.writeArray(StreamOutput::writeBytesRef, serializableCategory.getKeyTokens()); // key
out.writeVLong(serializableCategory.getNumMatches()); // docCount
} else {
serializableCategory.writeTo(out);
throw new ElasticsearchException(
"["
+ CategorizeTextAggregationBuilder.NAME
+ "] aggregation cannot be used in a cluster where some nodes have version ["
+ CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION
+ "] or higher and others have a version before this"
);
}
serializableCategory.writeTo(out);
aggregations.writeTo(out);
}

Expand Down Expand Up @@ -230,10 +236,15 @@ protected InternalCategorizationAggregation(

public InternalCategorizationAggregation(StreamInput in) throws IOException {
super(in);
// Disallow this aggregation in mixed version clusters that cross the algorithm change boundary.
if (in.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) {
// These are no longer used.
in.readVInt(); // maxUniqueTokens
in.readVInt(); // maxMatchedTokens
throw new ElasticsearchException(
"["
+ CategorizeTextAggregationBuilder.NAME
+ "] aggregation cannot be used in a cluster where some nodes have version ["
+ CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION
+ "] or higher and others have a version before this"
);
}
this.similarityThreshold = in.readVInt();
this.buckets = in.readList(Bucket::new);
Expand All @@ -243,10 +254,15 @@ public InternalCategorizationAggregation(StreamInput in) throws IOException {

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
// Disallow this aggregation in mixed version clusters that cross the algorithm change boundary.
if (out.getVersion().before(CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION)) {
// These were the defaults prior to the algorithm change.
out.writeVInt(50); // maxUniqueTokens
out.writeVInt(5); // maxMatchedTokens
throw new ElasticsearchException(
"["
+ CategorizeTextAggregationBuilder.NAME
+ "] aggregation cannot be used in a cluster where some nodes have version ["
+ CategorizeTextAggregationBuilder.ALGORITHM_CHANGED_VERSION
+ "] or higher and others have a version before this"
);
}
out.writeVInt(similarityThreshold);
out.writeList(buckets);
Expand Down