Skip to content

Commit

Permalink
Add counted_terms agg (WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
danielmitterdorfer committed Nov 3, 2023
1 parent 692572a commit 6664b1c
Show file tree
Hide file tree
Showing 6 changed files with 924 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.index.mapper.extras;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.search.aggregations.support.ValuesSourceType;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

public class CountedTermsAggregationBuilder extends ValuesSourceAggregationBuilder<CountedTermsAggregationBuilder> {
public static final String NAME = "counted_terms";
public static final ValuesSourceRegistry.RegistryKey<CountedTermsAggregatorSupplier> REGISTRY_KEY =
new ValuesSourceRegistry.RegistryKey<>(NAME, CountedTermsAggregatorSupplier.class);

public static final ParseField REQUIRED_SIZE_FIELD_NAME = new ParseField("size");

public static final ObjectParser<CountedTermsAggregationBuilder, String> PARSER = ObjectParser.fromBuilder(
NAME,
CountedTermsAggregationBuilder::new
);
static {
ValuesSourceAggregationBuilder.declareFields(PARSER, true, true, false);

PARSER.declareInt(CountedTermsAggregationBuilder::size, REQUIRED_SIZE_FIELD_NAME);
}

// see TermsAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS
private final TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(1, 0, 10, -1);

protected CountedTermsAggregationBuilder(String name) {
super(name);
}

protected CountedTermsAggregationBuilder(
ValuesSourceAggregationBuilder<CountedTermsAggregationBuilder> clone,
AggregatorFactories.Builder factoriesBuilder,
Map<String, Object> metadata
) {
super(clone, factoriesBuilder, metadata);
}

protected CountedTermsAggregationBuilder(StreamInput in) throws IOException {
super(in);
}

public static void registerAggregators(ValuesSourceRegistry.Builder builder) {
CountedTermsAggregatorFactory.registerAggregators(builder);
}

public CountedTermsAggregationBuilder size(int size) {
if (size <= 0) {
throw new IllegalArgumentException("[size] must be greater than 0. Found [" + size + "] in [" + name + "]");
}
bucketCountThresholds.setRequiredSize(size);
return this;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
// TODO: Create a new transport version and use that
return TransportVersions.ZERO;
}

@Override
protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metadata) {
return new CountedTermsAggregationBuilder(this, factoriesBuilder, metadata);
}

@Override
public BucketCardinality bucketCardinality() {
return BucketCardinality.MANY;
}

@Override
public String getType() {
return NAME;
}

@Override
protected void innerWriteTo(StreamOutput out) throws IOException {
bucketCountThresholds.writeTo(out);
}

@Override
protected ValuesSourceRegistry.RegistryKey<?> getRegistryKey() {
// method is unused - no need to implement it
return null;
}

@Override
protected ValuesSourceType defaultValueSourceType() {
// TODO: Do we need a new source type now or is this good enough?
return CoreValuesSourceType.KEYWORD;
}

@Override
protected ValuesSourceAggregatorFactory innerBuild(
AggregationContext context,
ValuesSourceConfig config,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder
) throws IOException {
CountedTermsAggregatorSupplier aggregatorSupplier = context.getValuesSourceRegistry().getAggregator(REGISTRY_KEY, config);
return new CountedTermsAggregatorFactory(
name,
config,
bucketCountThresholds,
context,
parent,
subFactoriesBuilder,
metadata,
aggregatorSupplier
);
}

@Override
protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
bucketCountThresholds.toXContent(builder, params);
return builder;
}
}
Loading

0 comments on commit 6664b1c

Please sign in to comment.