From c7ed095fdfd5b239d44f89f59265d42252433f88 Mon Sep 17 00:00:00 2001 From: groot Date: Thu, 31 Oct 2024 10:10:17 +0800 Subject: [PATCH] Add group size control for V1 and V2 (#1158) Signed-off-by: yhmo --- src/main/java/io/milvus/param/Constant.java | 3 ++ src/main/java/io/milvus/param/ParamUtils.java | 36 +++++++++++++ .../milvus/param/dml/HybridSearchParam.java | 52 +++++++++++++++++-- .../java/io/milvus/param/dml/SearchParam.java | 33 ++++++++++++ .../service/vector/request/AnnSearchReq.java | 4 ++ .../vector/request/HybridSearchReq.java | 4 ++ .../v2/service/vector/request/SearchReq.java | 4 ++ .../java/io/milvus/v2/utils/VectorUtils.java | 52 +++++++++++++++++++ 8 files changed, 184 insertions(+), 4 deletions(-) diff --git a/src/main/java/io/milvus/param/Constant.java b/src/main/java/io/milvus/param/Constant.java index 5e566ef7e..9a0c26a9c 100644 --- a/src/main/java/io/milvus/param/Constant.java +++ b/src/main/java/io/milvus/param/Constant.java @@ -34,6 +34,9 @@ public class Constant { public static final String REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"; public static final String ITERATOR_FIELD = "iterator"; public static final String GROUP_BY_FIELD = "group_by_field"; + public static final String GROUP_SIZE = "group_size"; + public static final String GROUP_STRICT_SIZE = "group_strict_size"; + public static final String INDEX_TYPE = "index_type"; public static final String METRIC_TYPE = "metric_type"; public static final String ROUND_DECIMAL = "round_decimal"; diff --git a/src/main/java/io/milvus/param/ParamUtils.java b/src/main/java/io/milvus/param/ParamUtils.java index 7bdfbcff9..911d37593 100644 --- a/src/main/java/io/milvus/param/ParamUtils.java +++ b/src/main/java/io/milvus/param/ParamUtils.java @@ -824,6 +824,20 @@ public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam .setKey(Constant.GROUP_BY_FIELD) .setValue(requestParam.getGroupByFieldName()) .build()); + if (requestParam.getGroupSize() != null) { + 
builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_SIZE) + .setValue(requestParam.getGroupSize().toString()) + .build()); + } + if (requestParam.getGroupStrictSize() != null) { + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_STRICT_SIZE) + .setValue(requestParam.getGroupStrictSize().toString()) + .build()); + } } if (null != requestParam.getParams() && !requestParam.getParams().isEmpty()) { @@ -937,6 +951,28 @@ public static HybridSearchRequest convertHybridSearchParam(@NonNull HybridSearch builder.addRequests(searchRequest); } + if (!StringUtils.isEmpty(requestParam.getGroupByFieldName())) { + builder.addRankParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_BY_FIELD) + .setValue(requestParam.getGroupByFieldName()) + .build()); + if (requestParam.getGroupSize() != null) { + builder.addRankParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_SIZE) + .setValue(requestParam.getGroupSize().toString()) + .build()); + } + if (requestParam.getGroupStrictSize() != null) { + builder.addRankParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_STRICT_SIZE) + .setValue(requestParam.getGroupStrictSize().toString()) + .build()); + } + } + // set ranker BaseRanker ranker = requestParam.getRanker(); Map props = ranker.getProperties(); diff --git a/src/main/java/io/milvus/param/dml/HybridSearchParam.java b/src/main/java/io/milvus/param/dml/HybridSearchParam.java index d063656c6..8b8fc5d38 100644 --- a/src/main/java/io/milvus/param/dml/HybridSearchParam.java +++ b/src/main/java/io/milvus/param/dml/HybridSearchParam.java @@ -22,8 +22,6 @@ import com.google.common.collect.Lists; import io.milvus.common.clientenum.ConsistencyLevelEnum; import io.milvus.exception.ParamException; -import io.milvus.param.Constant; -import io.milvus.param.MetricType; import io.milvus.param.ParamUtils; import io.milvus.param.dml.ranker.BaseRanker; @@ -31,9 +29,7 @@ import lombok.NonNull; import lombok.ToString; -import 
java.nio.ByteBuffer; import java.util.List; -import java.util.SortedMap; /** * Parameters for search interface. @@ -51,6 +47,10 @@ public class HybridSearchParam { private final int roundDecimal; private final ConsistencyLevelEnum consistencyLevel; + private final String groupByFieldName; + private final Integer groupSize; + private final Boolean groupStrictSize; + private HybridSearchParam(@NonNull Builder builder) { this.databaseName = builder.databaseName; this.collectionName = builder.collectionName; @@ -61,6 +61,9 @@ private HybridSearchParam(@NonNull Builder builder) { this.outFields = builder.outFields; this.roundDecimal = builder.roundDecimal; this.consistencyLevel = builder.consistencyLevel; + this.groupByFieldName = builder.groupByFieldName; + this.groupSize = builder.groupSize; + this.groupStrictSize = builder.groupStrictSize; } public static Builder newBuilder() { @@ -80,6 +83,9 @@ public static class Builder { private final List outFields = Lists.newArrayList(); private Integer roundDecimal = -1; private ConsistencyLevelEnum consistencyLevel = null; + private String groupByFieldName = null; + private Integer groupSize = null; + private Boolean groupStrictSize = null; Builder() { } @@ -209,6 +215,40 @@ public Builder withRoundDecimal(@NonNull Integer decimal) { return this; } + /** + * Groups the results by a scalar field name. + * + * @param groupByFieldName a scalar field name + * @return Builder + */ + public Builder withGroupByFieldName(@NonNull String groupByFieldName) { + this.groupByFieldName = groupByFieldName; + return this; + } + + /** + * Defines the max number of items for each group, the value must be greater than zero. + * + * @param groupSize the max number of items + * @return Builder + */ + public Builder withGroupSize(@NonNull Integer groupSize) { + this.groupSize = groupSize; + return this; + } + + /** + * Whether to force the number of each group to be groupSize. 
+ * If set to false, Milvus might return some groups with fewer items than groupSize. + * + * @param groupStrictSize whether to force the number of each group to be groupSize + * @return Builder + */ + public Builder withGroupStrictSize(@NonNull Boolean groupStrictSize) { + this.groupStrictSize = groupStrictSize; + return this; + } + /** * Verifies parameters and creates a new {@link HybridSearchParam} instance. * @@ -238,6 +278,10 @@ public HybridSearchParam build() throws ParamException { throw new ParamException("TopK value is illegal"); } + if (groupByFieldName != null && groupSize != null && groupSize <= 0) { + throw new ParamException("GroupSize value cannot be zero or negative"); + } + return new HybridSearchParam(this); } } diff --git a/src/main/java/io/milvus/param/dml/SearchParam.java b/src/main/java/io/milvus/param/dml/SearchParam.java index 18eebe9bf..09293b107 100644 --- a/src/main/java/io/milvus/param/dml/SearchParam.java +++ b/src/main/java/io/milvus/param/dml/SearchParam.java @@ -56,6 +56,8 @@ public class SearchParam { private final ConsistencyLevelEnum consistencyLevel; private final boolean ignoreGrowing; private final String groupByFieldName; + private final Integer groupSize; + private final Boolean groupStrictSize; private final PlaceholderType plType; private final boolean iterator; @@ -78,6 +80,8 @@ private SearchParam(@NonNull Builder builder) { this.consistencyLevel = builder.consistencyLevel; this.ignoreGrowing = builder.ignoreGrowing; this.groupByFieldName = builder.groupByFieldName; + this.groupSize = builder.groupSize; + this.groupStrictSize = builder.groupStrictSize; this.plType = builder.plType; this.iterator = builder.iterator; } @@ -108,6 +112,8 @@ public static class Builder { private ConsistencyLevelEnum consistencyLevel = null; private Boolean ignoreGrowing = Boolean.FALSE; private String groupByFieldName; + private Integer groupSize = null; + private Boolean groupStrictSize = null; private Boolean iterator = 
Boolean.FALSE; // plType is used to distinct vector type @@ -377,6 +383,29 @@ public Builder withGroupByFieldName(@NonNull String groupByFieldName) { return this; } + /** + * Defines the max number of items for each group, the value must be greater than zero. + * + * @param groupSize the max number of items + * @return Builder + */ + public Builder withGroupSize(@NonNull Integer groupSize) { + this.groupSize = groupSize; + return this; + } + + /** + * Whether to force the number of each group to be groupSize. + * If set to false, Milvus might return some groups with fewer items than groupSize. + * + * @param groupStrictSize whether to force the number of each group to be groupSize + * @return Builder + */ + public Builder withGroupStrictSize(@NonNull Boolean groupStrictSize) { + this.groupStrictSize = groupStrictSize; + return this; + } + /** * Optimizing specifically for iterators can yield correct data results. Default is False. * @@ -411,6 +440,10 @@ public SearchParam build() throws ParamException { SearchParam.verifyVectors(vectors); + if (groupByFieldName != null && groupSize != null && groupSize <= 0) { + throw new ParamException("GroupSize value cannot be zero or negative"); + } + return new SearchParam(this); } } diff --git a/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java b/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java index 9e7a78034..a0ac05068 100644 --- a/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java +++ b/src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java @@ -19,6 +19,7 @@ package io.milvus.v2.service.vector.request; +import io.milvus.v2.common.IndexParam; import io.milvus.v2.service.vector.request.data.BaseVector; import lombok.Builder; import lombok.Data; @@ -35,4 +36,7 @@ public class AnnSearchReq { private String expr = ""; private List vectors; private String params; + + @Builder.Default + private IndexParam.MetricType metricType = null; } diff --git 
a/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java b/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java index cf3b35f24..8eabf2bd8 100644 --- a/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java +++ b/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java @@ -42,4 +42,8 @@ public class HybridSearchReq private int roundDecimal = -1; @Builder.Default private ConsistencyLevel consistencyLevel = null; + + private String groupByFieldName; + private Integer groupSize; + private Boolean groupStrictSize; } diff --git a/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java b/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java index b120c1924..6c8559f50 100644 --- a/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java +++ b/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java @@ -20,6 +20,7 @@ package io.milvus.v2.service.vector.request; import io.milvus.v2.common.ConsistencyLevel; +import io.milvus.v2.common.IndexParam; import io.milvus.v2.service.vector.request.data.BaseVector; import lombok.Builder; import lombok.Data; @@ -38,6 +39,7 @@ public class SearchReq { private List partitionNames = new ArrayList<>(); @Builder.Default private String annsField = ""; + private IndexParam.MetricType metricType; private int topK; private String filter; @Builder.Default @@ -57,6 +59,8 @@ public class SearchReq { private ConsistencyLevel consistencyLevel = null; private boolean ignoreGrowing; private String groupByFieldName; + private Integer groupSize; + private Boolean groupStrictSize; // Expression template, to improve expression parsing performance in complicated list // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......] 
diff --git a/src/main/java/io/milvus/v2/utils/VectorUtils.java b/src/main/java/io/milvus/v2/utils/VectorUtils.java index fb591272e..1564f8140 100644 --- a/src/main/java/io/milvus/v2/utils/VectorUtils.java +++ b/src/main/java/io/milvus/v2/utils/VectorUtils.java @@ -156,6 +156,13 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) { .setValue(String.valueOf(request.getOffset())) .build()); + if (null != request.getMetricType()) { + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.METRIC_TYPE) + .setValue(request.getMetricType().name()) + .build()); + } if (null != request.getSearchParams()) { try { @@ -176,6 +183,21 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) { .setKey(Constant.GROUP_BY_FIELD) .setValue(request.getGroupByFieldName()) .build()); + if (request.getGroupSize() != null) { + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_SIZE) + .setValue(request.getGroupSize().toString()) + .build()); + } + + if (request.getGroupStrictSize() != null) { + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_STRICT_SIZE) + .setValue(request.getGroupStrictSize().toString()) + .build()); + } } if (!request.getOutputFields().isEmpty()) { @@ -287,6 +309,13 @@ public static SearchRequest convertAnnSearchParam(@NonNull AnnSearchReq annSearc .setKey(Constant.TOP_K) .setValue(String.valueOf(annSearchReq.getTopK())) .build()); + if (annSearchReq.getMetricType() != null) { + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.METRIC_TYPE) + .setValue(annSearchReq.getMetricType().name()) + .build()); + } // params String params = "{}"; @@ -347,6 +376,29 @@ public HybridSearchRequest ConvertToGrpcHybridSearchRequest(HybridSearchReq requ propertiesList.forEach(builder::addRankParams); } + if (request.getGroupByFieldName() != null && !request.getGroupByFieldName().isEmpty()) { + builder.addRankParams( + KeyValuePair.newBuilder() + 
.setKey(Constant.GROUP_BY_FIELD) + .setValue(request.getGroupByFieldName()) + .build()); + if (request.getGroupSize() != null) { + builder.addRankParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_SIZE) + .setValue(request.getGroupSize().toString()) + .build()); + } + + if (request.getGroupStrictSize() != null) { + builder.addRankParams( + KeyValuePair.newBuilder() + .setKey(Constant.GROUP_STRICT_SIZE) + .setValue(request.getGroupStrictSize().toString()) + .build()); + } + } + // output fields if (request.getOutFields() != null && !request.getOutFields().isEmpty()) { request.getOutFields().forEach(builder::addOutputFields);