Skip to content

Commit

Permalink
Add group size control for V1 and V2 (#1158)
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <[email protected]>
  • Loading branch information
yhmo authored Oct 31, 2024
1 parent 4e84d69 commit c7ed095
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/main/java/io/milvus/param/Constant.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public class Constant {
public static final String REDUCE_STOP_FOR_BEST = "reduce_stop_for_best";
public static final String ITERATOR_FIELD = "iterator";
public static final String GROUP_BY_FIELD = "group_by_field";
public static final String GROUP_SIZE = "group_size";
public static final String GROUP_STRICT_SIZE = "group_strict_size";

public static final String INDEX_TYPE = "index_type";
public static final String METRIC_TYPE = "metric_type";
public static final String ROUND_DECIMAL = "round_decimal";
Expand Down
36 changes: 36 additions & 0 deletions src/main/java/io/milvus/param/ParamUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,20 @@ public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam
.setKey(Constant.GROUP_BY_FIELD)
.setValue(requestParam.getGroupByFieldName())
.build());
if (requestParam.getGroupSize() != null) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_SIZE)
.setValue(requestParam.getGroupSize().toString())
.build());
}
if (requestParam.getGroupStrictSize() != null) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_STRICT_SIZE)
.setValue(requestParam.getGroupStrictSize().toString())
.build());
}
}

if (null != requestParam.getParams() && !requestParam.getParams().isEmpty()) {
Expand Down Expand Up @@ -937,6 +951,28 @@ public static HybridSearchRequest convertHybridSearchParam(@NonNull HybridSearch
builder.addRequests(searchRequest);
}

if (!StringUtils.isEmpty(requestParam.getGroupByFieldName())) {
builder.addRankParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_BY_FIELD)
.setValue(requestParam.getGroupByFieldName())
.build());
if (requestParam.getGroupSize() != null) {
builder.addRankParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_SIZE)
.setValue(requestParam.getGroupSize().toString())
.build());
}
if (requestParam.getGroupStrictSize() != null) {
builder.addRankParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_STRICT_SIZE)
.setValue(requestParam.getGroupStrictSize().toString())
.build());
}
}

// set ranker
BaseRanker ranker = requestParam.getRanker();
Map<String, String> props = ranker.getProperties();
Expand Down
52 changes: 48 additions & 4 deletions src/main/java/io/milvus/param/dml/HybridSearchParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
import com.google.common.collect.Lists;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.exception.ParamException;
import io.milvus.param.Constant;
import io.milvus.param.MetricType;
import io.milvus.param.ParamUtils;

import io.milvus.param.dml.ranker.BaseRanker;
import lombok.Getter;
import lombok.NonNull;
import lombok.ToString;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.SortedMap;

/**
* Parameters for <code>search</code> interface.
Expand All @@ -51,6 +47,10 @@ public class HybridSearchParam {
private final int roundDecimal;
private final ConsistencyLevelEnum consistencyLevel;

private final String groupByFieldName;
private final Integer groupSize;
private final Boolean groupStrictSize;

private HybridSearchParam(@NonNull Builder builder) {
this.databaseName = builder.databaseName;
this.collectionName = builder.collectionName;
Expand All @@ -61,6 +61,9 @@ private HybridSearchParam(@NonNull Builder builder) {
this.outFields = builder.outFields;
this.roundDecimal = builder.roundDecimal;
this.consistencyLevel = builder.consistencyLevel;
this.groupByFieldName = builder.groupByFieldName;
this.groupSize = builder.groupSize;
this.groupStrictSize = builder.groupStrictSize;
}

public static Builder newBuilder() {
Expand All @@ -80,6 +83,9 @@ public static class Builder {
private final List<String> outFields = Lists.newArrayList();
private Integer roundDecimal = -1;
private ConsistencyLevelEnum consistencyLevel = null;
private String groupByFieldName = null;
private Integer groupSize = null;
private Boolean groupStrictSize = null;

Builder() {
}
Expand Down Expand Up @@ -209,6 +215,40 @@ public Builder withRoundDecimal(@NonNull Integer decimal) {
return this;
}

/**
* Groups the results by a scalar field name.
*
* @param fieldName a scalar field name
* @return <code>Builder</code>
*/
public Builder withGroupByFieldName(@NonNull String groupByFieldName) {
this.groupByFieldName = groupByFieldName;
return this;
}

/**
* Defines the max number of items for each group, the value must greater than zero.
*
* @param groupSize the max number of items
* @return <code>Builder</code>
*/
public Builder withGroupSize(@NonNull Integer groupSize) {
this.groupSize = groupSize;
return this;
}

/**
* Whether to force the number of each group to be groupSize.
* Set to false, milvus might return some groups with number of items less than groupSize.
*
* @param groupStrictSize whether to force the number of each group to be groupSize
* @return <code>Builder</code>
*/
public Builder withGroupStrictSize(@NonNull Boolean groupStrictSize) {
this.groupStrictSize = groupStrictSize;
return this;
}

/**
* Verifies parameters and creates a new {@link HybridSearchParam} instance.
*
Expand Down Expand Up @@ -238,6 +278,10 @@ public HybridSearchParam build() throws ParamException {
throw new ParamException("TopK value is illegal");
}

if (groupByFieldName != null && groupSize != null && groupSize <= 0) {
throw new ParamException("GroupSize value cannot be zero or negative");
}

return new HybridSearchParam(this);
}
}
Expand Down
33 changes: 33 additions & 0 deletions src/main/java/io/milvus/param/dml/SearchParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public class SearchParam {
private final ConsistencyLevelEnum consistencyLevel;
private final boolean ignoreGrowing;
private final String groupByFieldName;
private final Integer groupSize;
private final Boolean groupStrictSize;
private final PlaceholderType plType;
private final boolean iterator;

Expand All @@ -78,6 +80,8 @@ private SearchParam(@NonNull Builder builder) {
this.consistencyLevel = builder.consistencyLevel;
this.ignoreGrowing = builder.ignoreGrowing;
this.groupByFieldName = builder.groupByFieldName;
this.groupSize = builder.groupSize;
this.groupStrictSize = builder.groupStrictSize;
this.plType = builder.plType;
this.iterator = builder.iterator;
}
Expand Down Expand Up @@ -108,6 +112,8 @@ public static class Builder {
private ConsistencyLevelEnum consistencyLevel = null;
private Boolean ignoreGrowing = Boolean.FALSE;
private String groupByFieldName;
private Integer groupSize = null;
private Boolean groupStrictSize = null;
private Boolean iterator = Boolean.FALSE;

// plType is used to distinct vector type
Expand Down Expand Up @@ -377,6 +383,29 @@ public Builder withGroupByFieldName(@NonNull String groupByFieldName) {
return this;
}

/**
* Defines the max number of items for each group, the value must greater than zero.
*
* @param groupSize the max number of items
* @return <code>Builder</code>
*/
public Builder withGroupSize(@NonNull Integer groupSize) {
this.groupSize = groupSize;
return this;
}

/**
* Whether to force the number of each group to be groupSize.
* Set to false, milvus might return some groups with number of items less than groupSize.
*
* @param groupStrictSize whether to force the number of each group to be groupSize
* @return <code>Builder</code>
*/
public Builder withGroupStrictSize(@NonNull Boolean groupStrictSize) {
this.groupStrictSize = groupStrictSize;
return this;
}

/**
* Optimizing specifically for iterators can yield correct data results. Default is False.
*
Expand Down Expand Up @@ -411,6 +440,10 @@ public SearchParam build() throws ParamException {

SearchParam.verifyVectors(vectors);

if (groupByFieldName != null && groupSize != null && groupSize <= 0) {
throw new ParamException("GroupSize value cannot be zero or negative");
}

return new SearchParam(this);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package io.milvus.v2.service.vector.request;

import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.vector.request.data.BaseVector;
import lombok.Builder;
import lombok.Data;
Expand All @@ -35,4 +36,7 @@ public class AnnSearchReq {
private String expr = "";
private List<BaseVector> vectors;
private String params;

@Builder.Default
private IndexParam.MetricType metricType = null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@ public class HybridSearchReq
private int roundDecimal = -1;
@Builder.Default
private ConsistencyLevel consistencyLevel = null;

private String groupByFieldName;
private Integer groupSize;
private Boolean groupStrictSize;
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package io.milvus.v2.service.vector.request;

import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.vector.request.data.BaseVector;
import lombok.Builder;
import lombok.Data;
Expand All @@ -38,6 +39,7 @@ public class SearchReq {
private List<String> partitionNames = new ArrayList<>();
@Builder.Default
private String annsField = "";
private IndexParam.MetricType metricType;
private int topK;
private String filter;
@Builder.Default
Expand All @@ -57,6 +59,8 @@ public class SearchReq {
private ConsistencyLevel consistencyLevel = null;
private boolean ignoreGrowing;
private String groupByFieldName;
private Integer groupSize;
private Boolean groupStrictSize;

// Expression template, to improve expression parsing performance in complicated list
// Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]
Expand Down
52 changes: 52 additions & 0 deletions src/main/java/io/milvus/v2/utils/VectorUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {
.setValue(String.valueOf(request.getOffset()))
.build());

if (null != request.getMetricType()) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.METRIC_TYPE)
.setValue(request.getMetricType().name())
.build());
}

if (null != request.getSearchParams()) {
try {
Expand All @@ -176,6 +183,21 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {
.setKey(Constant.GROUP_BY_FIELD)
.setValue(request.getGroupByFieldName())
.build());
if (request.getGroupSize() != null) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_SIZE)
.setValue(request.getGroupSize().toString())
.build());
}

if (request.getGroupStrictSize() != null) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_STRICT_SIZE)
.setValue(request.getGroupStrictSize().toString())
.build());
}
}

if (!request.getOutputFields().isEmpty()) {
Expand Down Expand Up @@ -287,6 +309,13 @@ public static SearchRequest convertAnnSearchParam(@NonNull AnnSearchReq annSearc
.setKey(Constant.TOP_K)
.setValue(String.valueOf(annSearchReq.getTopK()))
.build());
if (annSearchReq.getMetricType() != null) {
builder.addSearchParams(
KeyValuePair.newBuilder()
.setKey(Constant.METRIC_TYPE)
.setValue(annSearchReq.getMetricType().name())
.build());
}

// params
String params = "{}";
Expand Down Expand Up @@ -347,6 +376,29 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ
propertiesList.forEach(builder::addRankParams);
}

if (request.getGroupByFieldName() != null && !request.getGroupByFieldName().isEmpty()) {
builder.addRankParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_BY_FIELD)
.setValue(request.getGroupByFieldName())
.build());
if (request.getGroupSize() != null) {
builder.addRankParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_SIZE)
.setValue(request.getGroupSize().toString().toString())
.build());
}

if (request.getGroupStrictSize() != null) {
builder.addRankParams(
KeyValuePair.newBuilder()
.setKey(Constant.GROUP_STRICT_SIZE)
.setValue(request.getGroupStrictSize().toString())
.build());
}
}

// output fields
if (request.getOutFields() != null && !request.getOutFields().isEmpty()) {
request.getOutFields().forEach(builder::addOutputFields);
Expand Down

0 comments on commit c7ed095

Please sign in to comment.