diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 9d6de7e00cb2e..07a59acd88dc4 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -1066,3 +1066,15 @@ streaming: backoffMultiplier: 2 # The multiplier of balance task trigger backoff, 2 by default txn: defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default + +# Any configuration related to the knowhere vector search engine +knowhere: + enable: true # When enable this configuration, the index parameters defined following will be automatically populated as index parameters, without requiring user input. + DISKANN: # Index parameters for diskann + build: # Diskann build params + max_degree: 56 # Maximum degree of the Vamana graph + search_list_size: 100 # Size of the candidate list during building graph + pq_code_budget_gb_ratio: 0.125 # Size limit on the PQ code (compared with raw data) + search_cache_budget_gb_ratio: 0.1 # Ratio of cached node numbers to raw data + search: # Diskann search params + beam_width_ratio: 4.0 # Ratio between the maximum number of IO requests per search iteration and CPU number. \ No newline at end of file diff --git a/internal/core/src/common/Json.h b/internal/core/src/common/Json.h index 297dbcbdcca77..0570bdb56dd2c 100644 --- a/internal/core/src/common/Json.h +++ b/internal/core/src/common/Json.h @@ -133,6 +133,9 @@ class Json { value_result doc() const { + if (data_.size() == 0) { + return {}; + } thread_local simdjson::ondemand::parser parser; // it's always safe to add the padding, @@ -148,6 +151,9 @@ class Json { value_result dom_doc() const { + if (data_.size() == 0) { + return {}; + } thread_local simdjson::dom::parser parser; // it's always safe to add the padding, diff --git a/internal/core/src/storage/LocalChunkManager.cpp b/internal/core/src/storage/LocalChunkManager.cpp index 2b6870cd11893..7d093c7720d40 100644 --- a/internal/core/src/storage/LocalChunkManager.cpp +++ b/internal/core/src/storage/LocalChunkManager.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include "LocalChunkManager.h" +#include "log/Log.h" #include #include @@ -232,7 +233,17 @@ LocalChunkManager::GetSizeOfDir(const std::string& dir) { it != v.end(); ++it) { if (boost::filesystem::is_regular_file(it->path())) { - total_file_size += boost::filesystem::file_size(it->path()); + boost::system::error_code ec; + auto file_size = boost::filesystem::file_size(it->path(), ec); + if (ec) { + // The file may be removed concurrently by other threads. + // So the file size cannot be obtained, just ignore it. + LOG_INFO("size of file {} cannot be obtained with error: {}", + it->path().string(), + ec.message()); + continue; + } + total_file_size += file_size; } if (boost::filesystem::is_directory(it->path())) { total_file_size += GetSizeOfDir(it->path().string()); diff --git a/internal/datacoord/task_index.go b/internal/datacoord/task_index.go index a72cd0019e610..efba15570b27d 100644 --- a/internal/datacoord/task_index.go +++ b/internal/datacoord/task_index.go @@ -29,10 +29,12 @@ import ( "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -159,6 +161,19 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID) binlogIDs := getBinLogIDs(segment, fieldID) + + // When new index parameters are added, these parameters need to be updated to ensure they are included during the index-building process. + if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() { + var ret error + indexParams, ret = Params.KnowhereConfig.UpdateIndexParams(GetIndexType(indexParams), paramtable.BuildStage, indexParams) + + if ret != nil { + log.Ctx(ctx).Warn("failed to update index build params defined in yaml", zap.Int64("taskID", it.taskID), zap.Error(ret)) + it.SetState(indexpb.JobState_JobStateInit, ret.Error()) + return false + } + } + if isDiskANNIndex(GetIndexType(indexParams)) { var err error indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams) diff --git a/internal/indexnode/task_index.go b/internal/indexnode/task_index.go index 7808ac43f0250..87e0f44be8684 100644 --- a/internal/indexnode/task_index.go +++ b/internal/indexnode/task_index.go @@ -210,6 +210,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) indexType := it.newIndexParams[common.IndexTypeKey] + var fieldDataSize uint64 if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) { // check index node support disk index if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { @@ -225,7 +226,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { log.Warn("IndexNode get local used size failed") return err } - fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) + fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) if err != nil { log.Warn("IndexNode get local used size failed") return err @@ -247,6 +248,11 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { } } + // system resource-related parameters, such as memory limits, CPU limits, and disk limits, are appended here to the parameter list + if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() { + it.newIndexParams, _ = Params.KnowhereConfig.MergeResourceParams(fieldDataSize, paramtable.BuildStage, it.newIndexParams) + } + storageConfig := &indexcgopb.StorageConfig{ Address: it.req.GetStorageConfig().GetAddress(), AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index 24f437feb7d52..ffc4f13623459 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -545,8 +545,7 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} { } else { elementValue := valueExpr.GetValue() if elementValue == nil { - return fmt.Errorf( - "contains_any operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) + return fmt.Errorf("value '%s' in list cannot be a non-const expression", ctx.Expr(1).GetText()) } if !IsArray(elementValue) { @@ -662,12 +661,12 @@ func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} { lowerValue := lowerValueExpr.GetValue() upperValue := upperValueExpr.GetValue() if !isTemplateExpr(lowerValueExpr) { - if err = checkRangeCompared(fieldDataType, lowerValue); err != nil { + if lowerValue, err = castRangeValue(fieldDataType, lowerValue); err != nil { return err } } if !isTemplateExpr(upperValueExpr) { - if err = checkRangeCompared(fieldDataType, upperValue); err != nil { + if upperValue, err = castRangeValue(fieldDataType, upperValue); err != nil { return err } } @@ -744,12 +743,12 @@ func (v *ParserVisitor) VisitReverseRange(ctx *parser.ReverseRangeContext) inter lowerValue := lowerValueExpr.GetValue() upperValue := upperValueExpr.GetValue() if !isTemplateExpr(lowerValueExpr) { - if err = checkRangeCompared(fieldDataType, lowerValue); err != nil { + if lowerValue, err = castRangeValue(fieldDataType, lowerValue); err != nil { return err } } if !isTemplateExpr(upperValueExpr) { - if err = checkRangeCompared(fieldDataType, upperValue); err != nil { + if upperValue, err = castRangeValue(fieldDataType, upperValue); err != nil { return err } } diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index d3adb5b36577c..17cca040e0ffa 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -274,6 +274,28 @@ func TestExpr_BinaryRange(t *testing.T) { } } +func TestExpr_castValue(t *testing.T) { + schema := newTestSchema() + helper, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + exprStr := `Int64Field + 1.1 == 2.1` + expr, err := ParseExpr(helper, exprStr, nil) + assert.NoError(t, err, exprStr) + assert.NotNil(t, expr, exprStr) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetRightOperand().GetFloatVal()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetValue().GetFloatVal()) + + exprStr = `FloatField +1 == 2` + expr, err = ParseExpr(helper, exprStr, nil) + assert.NoError(t, err, exprStr) + assert.NotNil(t, expr, exprStr) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetRightOperand().GetFloatVal()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetValue().GetFloatVal()) +} + func TestExpr_BinaryArith(t *testing.T) { schema := newTestSchema() helper, err := typeutil.CreateSchemaHelper(schema) @@ -283,7 +305,6 @@ func TestExpr_BinaryArith(t *testing.T) { `Int64Field % 10 == 9`, `Int64Field % 10 != 9`, `FloatField + 1.1 == 2.1`, - `Int64Field + 1.1 == 2.1`, `A % 10 != 2`, `Int8Field + 1 < 2`, `Int16Field - 3 <= 4`, diff --git a/internal/parser/planparserv2/utils.go b/internal/parser/planparserv2/utils.go index e61bbd237c9bc..4faef470dd7c1 100644 --- a/internal/parser/planparserv2/utils.go +++ b/internal/parser/planparserv2/utils.go @@ -241,13 +241,22 @@ func castValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb. return nil, fmt.Errorf("cannot cast value to %s, value: %s", dataType.String(), value) } -func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, operandExpr, valueExpr *planpb.ValueExpr) *planpb.Expr { +func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, arithExprDataType schemapb.DataType, columnInfo *planpb.ColumnInfo, operandExpr, valueExpr *planpb.ValueExpr) (*planpb.Expr, error) { + var err error + operand := operandExpr.GetValue() + if !isTemplateExpr(operandExpr) { + operand, err = castValue(arithExprDataType, operand) + if err != nil { + return nil, err + } + } + return &planpb.Expr{ Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{ BinaryArithOpEvalRangeExpr: &planpb.BinaryArithOpEvalRangeExpr{ ColumnInfo: columnInfo, ArithOp: arithOp, - RightOperand: operandExpr.GetValue(), + RightOperand: operand, Op: op, Value: valueExpr.GetValue(), OperandTemplateVariableName: operandExpr.GetTemplateVariableName(), @@ -255,7 +264,7 @@ func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, column }, }, IsTemplate: isTemplateExpr(operandExpr) || isTemplateExpr(valueExpr), - } + }, nil } func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, valueExpr *planpb.ValueExpr) (*planpb.Expr, error) { @@ -297,7 +306,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, // a * 2 == 3 // a / 2 == 3 // a % 2 == 3 - return combineBinaryArithExpr(op, arithOp, leftExpr.GetInfo(), rightValue, valueExpr), nil + return combineBinaryArithExpr(op, arithOp, arithExprDataType, leftExpr.GetInfo(), rightValue, valueExpr) } else if rightExpr != nil && leftValue != nil { // 2 + a == 3 // 2 - a == 3 @@ -307,7 +316,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, switch arithExpr.GetOp() { case planpb.ArithOpType_Add, planpb.ArithOpType_Mul: - return combineBinaryArithExpr(op, arithOp, rightExpr.GetInfo(), leftValue, valueExpr), nil + return combineBinaryArithExpr(op, arithOp, arithExprDataType, rightExpr.GetInfo(), leftValue, valueExpr) default: return nil, fmt.Errorf("module field is not yet supported") } @@ -625,24 +634,27 @@ func checkValidModArith(tokenType planpb.ArithOpType, leftType, leftElementType, return nil } -func checkRangeCompared(dataType schemapb.DataType, value *planpb.GenericValue) error { +func castRangeValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb.GenericValue, error) { switch dataType { case schemapb.DataType_String, schemapb.DataType_VarChar: if !IsString(value) { - return fmt.Errorf("invalid range operations") + return nil, fmt.Errorf("invalid range operations") } case schemapb.DataType_Bool: - return fmt.Errorf("invalid range operations on boolean expr") + return nil, fmt.Errorf("invalid range operations on boolean expr") case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64: if !IsInteger(value) { - return fmt.Errorf("invalid range operations") + return nil, fmt.Errorf("invalid range operations") } case schemapb.DataType_Float, schemapb.DataType_Double: if !IsNumber(value) { - return fmt.Errorf("invalid range operations") + return nil, fmt.Errorf("invalid range operations") + } + if IsInteger(value) { + return NewFloat(float64(value.GetInt64Val())), nil } } - return nil + return value, nil } func checkContainsElement(columnExpr *ExprWithType, op planpb.JSONContainsExpr_JSONOp, elementValue *planpb.GenericValue) error { diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index f904a708de65b..3b69f59d4a4bb 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -341,6 +341,14 @@ func (cit *createIndexTask) parseIndexParams() error { if !exist { return fmt.Errorf("IndexType not specified") } + // index parameters defined in the YAML file are merged with the user-provided parameters during create stage + if Params.KnowhereConfig.Enable.GetAsBool() { + var err error + indexParamsMap, err = Params.KnowhereConfig.MergeIndexParams(indexType, paramtable.BuildStage, indexParamsMap) + if err != nil { + return err + } + } if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) { err := indexparams.FillDiskIndexParams(Params, indexParamsMap) if err != nil { diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 8d056fadb069b..b5aebd59d595d 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -1053,6 +1054,43 @@ func Test_parseIndexParams(t *testing.T) { err := cit.parseIndexParams() assert.Error(t, err) }) + + t.Run("verify merge params with yaml", func(t *testing.T) { + paramtable.Init() + Params.Save("knowhere.HNSW.build.M", "3000") + Params.Save("knowhere.HNSW.build.efConstruction", "120") + defer Params.Reset("knowhere.HNSW.build.M") + defer Params.Reset("knowhere.HNSW.build.efConstruction") + + cit := &createIndexTask{ + Condition: nil, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "HNSW", + }, + { + Key: common.MetricTypeKey, + Value: metric.L2, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldVector", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "768"}, + }, + }, + } + err := cit.parseIndexParams() + // Out of range in json: param 'M' (3000) should be in range [2, 2048] + assert.Error(t, err) + }) } func Test_wrapUserIndexParams(t *testing.T) { diff --git a/pkg/config/config.go b/pkg/config/config.go index fc93c086f74d4..8b9f3cfe108c0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/cockroachdb/errors" + "gopkg.in/yaml.v3" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -30,6 +31,10 @@ var ( ErrKeyNotFound = errors.New("key not found") ) +const ( + NotFormatPrefix = "knowhere." +) + func Init(opts ...Option) (*Manager, error) { o := &Options{} for _, opt := range opts { @@ -55,7 +60,17 @@ func Init(opts ...Option) (*Manager, error) { var formattedKeys = typeutil.NewConcurrentMap[string, string]() +func lowerKey(key string) string { + if strings.HasPrefix(key, NotFormatPrefix) { + return key + } + return strings.ToLower(key) +} + func formatKey(key string) string { + if strings.HasPrefix(key, NotFormatPrefix) { + return key + } cached, ok := formattedKeys.Get(key) if ok { return cached @@ -64,3 +79,43 @@ func formatKey(key string) string { formattedKeys.Insert(key, result) return result } + +func flattenNode(node *yaml.Node, parentKey string, result map[string]string) { + // The content of the node should contain key-value pairs in a MappingNode + if node.Kind == yaml.MappingNode { + for i := 0; i < len(node.Content); i += 2 { + keyNode := node.Content[i] + valueNode := node.Content[i+1] + + key := keyNode.Value + // Construct the full key with parent hierarchy + fullKey := key + if parentKey != "" { + fullKey = parentKey + "." + key + } + + switch valueNode.Kind { + case yaml.ScalarNode: + // Scalar value, store it as a string + result[lowerKey(fullKey)] = valueNode.Value + result[formatKey(fullKey)] = valueNode.Value + case yaml.MappingNode: + // Nested map, process recursively + flattenNode(valueNode, fullKey, result) + case yaml.SequenceNode: + // List (sequence), process elements + var listStr string + for j, item := range valueNode.Content { + if j > 0 { + listStr += "," + } + if item.Kind == yaml.ScalarNode { + listStr += item.Value + } + } + result[lowerKey(fullKey)] = listStr + result[formatKey(fullKey)] = listStr + } + } + } +} diff --git a/pkg/config/file_source.go b/pkg/config/file_source.go index e8402efe6b6ad..4eace878d890f 100644 --- a/pkg/config/file_source.go +++ b/pkg/config/file_source.go @@ -17,14 +17,16 @@ package config import ( + "bytes" + "fmt" "os" + "path/filepath" "sync" "github.com/cockroachdb/errors" "github.com/samber/lo" - "github.com/spf13/cast" - "github.com/spf13/viper" "go.uber.org/zap" + "gopkg.in/yaml.v3" "github.com/milvus-io/milvus/pkg/log" ) @@ -115,7 +117,6 @@ func (fs *FileSource) UpdateOptions(opts Options) { } func (fs *FileSource) loadFromFile() error { - yamlReader := viper.New() newConfig := make(map[string]string) var configFiles []string @@ -128,37 +129,35 @@ func (fs *FileSource) loadFromFile() error { continue } - yamlReader.SetConfigFile(configFile) - if err := yamlReader.ReadInConfig(); err != nil { + ext := filepath.Ext(configFile) + if len(ext) == 0 || ext[1:] != "yaml" { + return fmt.Errorf("Unsupported Config Type: " + ext) + } + + data, err := os.ReadFile(configFile) + if err != nil { return errors.Wrap(err, "Read config failed: "+configFile) } - for _, key := range yamlReader.AllKeys() { - val := yamlReader.Get(key) - str, err := cast.ToStringE(val) - if err != nil { - switch val := val.(type) { - case []any: - str = str[:0] - for _, v := range val { - ss, err := cast.ToStringE(v) - if err != nil { - log.Warn("cast to string failed", zap.Any("value", v)) - } - if str == "" { - str = ss - } else { - str = str + "," + ss - } - } - - default: - log.Warn("val is not a slice", zap.Any("value", val)) - continue - } - } - newConfig[key] = str - newConfig[formatKey(key)] = str + // handle empty file + if len(data) == 0 { + continue + } + + var node yaml.Node + decoder := yaml.NewDecoder(bytes.NewReader(data)) + if err := decoder.Decode(&node); err != nil { + return errors.Wrap(err, "YAML unmarshal failed: "+configFile) + } + + if node.Kind == yaml.DocumentNode && len(node.Content) > 0 { + // Get the content of the Document Node + contentNode := node.Content[0] + + // Recursively process the content of the Document Node + flattenNode(contentNode, "", newConfig) + } else if node.Kind == yaml.MappingNode { + flattenNode(&node, "", newConfig) } } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index c8ffa4babd23c..4dcda7cf28c1b 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -69,17 +69,18 @@ type ComponentParam struct { GpuConfig gpuConfig TraceCfg traceConfig - RootCoordCfg rootCoordConfig - ProxyCfg proxyConfig - QueryCoordCfg queryCoordConfig - QueryNodeCfg queryNodeConfig - DataCoordCfg dataCoordConfig - DataNodeCfg dataNodeConfig - IndexNodeCfg indexNodeConfig - HTTPCfg httpConfig - LogCfg logConfig - RoleCfg roleConfig - StreamingCfg streamingConfig + RootCoordCfg rootCoordConfig + ProxyCfg proxyConfig + QueryCoordCfg queryCoordConfig + QueryNodeCfg queryNodeConfig + DataCoordCfg dataCoordConfig + DataNodeCfg dataNodeConfig + IndexNodeCfg indexNodeConfig + KnowhereConfig knowhereConfig + HTTPCfg httpConfig + LogCfg logConfig + RoleCfg roleConfig + StreamingCfg streamingConfig RootCoordGrpcServerCfg GrpcServerConfig ProxyGrpcServerCfg GrpcServerConfig @@ -134,6 +135,7 @@ func (p *ComponentParam) init(bt *BaseTable) { p.LogCfg.init(bt) p.RoleCfg.init(bt) p.GpuConfig.init(bt) + p.KnowhereConfig.init(bt) p.RootCoordGrpcServerCfg.Init("rootCoord", bt) p.ProxyGrpcServerCfg.Init("proxy", bt) diff --git a/pkg/util/paramtable/knowhere_param.go b/pkg/util/paramtable/knowhere_param.go new file mode 100644 index 0000000000000..035631fac5153 --- /dev/null +++ b/pkg/util/paramtable/knowhere_param.go @@ -0,0 +1,118 @@ +package paramtable + +import ( + "fmt" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/pkg/util/hardware" +) + +type knowhereConfig struct { + Enable ParamItem `refreshable:"true"` + IndexParam ParamGroup `refreshable:"true"` +} + +const ( + BuildStage = "build" + LoadStage = "load" + SearchStage = "search" +) + +const ( + BuildDramBudgetKey = "build_dram_budget_gb" + NumBuildThreadKey = "num_build_thread" + VecFieldSizeKey = "vec_field_size_gb" +) + +func (p *knowhereConfig) init(base *BaseTable) { + p.IndexParam = ParamGroup{ + KeyPrefix: "knowhere.", + Version: "2.5.0", + } + p.IndexParam.Init(base.mgr) + + p.Enable = ParamItem{ + Key: "knowhere.enable", + Version: "2.5.0", + DefaultValue: "true", + } + p.Enable.Init(base.mgr) +} + +func (p *knowhereConfig) getIndexParam(indexType string, stage string) map[string]string { + matchedParam := make(map[string]string) + + params := p.IndexParam.GetValue() + prefix := indexType + "." + stage + "." + + for k, v := range params { + if strings.HasPrefix(k, prefix) { + matchedParam[strings.TrimPrefix(k, prefix)] = v + } + } + + return matchedParam +} + +func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string { + for _, param := range indexParams { + if param.Key == key { + return param.Value + } + } + return "" +} + +func (p *knowhereConfig) GetRuntimeParameter(stage string) (map[string]string, error) { + params := make(map[string]string) + + if stage == BuildStage { + params[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30)) + params[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum()))) + } + + return params, nil +} + +func (p *knowhereConfig) UpdateIndexParams(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + if GetKeyFromSlice(indexParams, key) == "" { + indexParams = append(indexParams, + &commonpb.KeyValuePair{ + Key: key, + Value: val, + }) + } + } + + return indexParams, nil +} + +func (p *knowhereConfig) MergeIndexParams(indexType string, stage string, indexParam map[string]string) (map[string]string, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + _, existed := indexParam[key] + if !existed { + indexParam[key] = val + } + } + + return indexParam, nil +} + +func (p *knowhereConfig) MergeResourceParams(vecFieldSize uint64, stage string, indexParam map[string]string) (map[string]string, error) { + param, _ := p.GetRuntimeParameter(stage) + + for key, val := range param { + indexParam[key] = val + } + + indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30)) + + return indexParam, nil +} diff --git a/pkg/util/paramtable/knowhere_param_test.go b/pkg/util/paramtable/knowhere_param_test.go new file mode 100644 index 0000000000000..87cf771b20346 --- /dev/null +++ b/pkg/util/paramtable/knowhere_param_test.go @@ -0,0 +1,243 @@ +package paramtable + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +func TestKnowhereConfig_GetIndexParam(t *testing.T) { + bt := NewBaseTable(SkipRemote(true)) + cfg := &knowhereConfig{} + cfg.init(bt) + + // Set some initial config + indexParams := map[string]interface{}{ + "knowhere.IVF_FLAT.build.nlist": 1024, + "knowhere.HNSW.build.efConstruction": 360, + "knowhere.DISKANN.search.search_list": 100, + } + + for key, val := range indexParams { + valStr, _ := json.Marshal(val) + bt.Save(key, string(valStr)) + } + + tests := []struct { + name string + indexType string + stage string + expectedKey string + expectedValue string + }{ + { + name: "IVF_FLAT Build", + indexType: "IVF_FLAT", + stage: BuildStage, + expectedKey: "nlist", + expectedValue: "1024", + }, + { + name: "HNSW Build", + indexType: "HNSW", + stage: BuildStage, + expectedKey: "efConstruction", + expectedValue: "360", + }, + { + name: "DISKANN Search", + indexType: "DISKANN", + stage: SearchStage, + expectedKey: "search_list", + expectedValue: "100", + }, + { + name: "Non-existent", + indexType: "NON_EXISTENT", + stage: BuildStage, + expectedKey: "", + expectedValue: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cfg.getIndexParam(tt.indexType, tt.stage) + if tt.expectedKey != "" { + assert.Contains(t, result, tt.expectedKey, "The result should contain the expected key") + assert.Equal(t, tt.expectedValue, result[tt.expectedKey], "The value for the key should match the expected value") + } else { + assert.Empty(t, result, "The result should be empty for non-existent index type") + } + }) + } +} + +func TestKnowhereConfig_GetRuntimeParameter(t *testing.T) { + cfg := &knowhereConfig{} + + params, err := cfg.GetRuntimeParameter(BuildStage) + assert.NoError(t, err) + assert.Contains(t, params, BuildDramBudgetKey) + assert.Contains(t, params, NumBuildThreadKey) + + params, err = cfg.GetRuntimeParameter(SearchStage) + assert.NoError(t, err) + assert.Empty(t, params) +} + +func TestKnowhereConfig_UpdateParameter(t *testing.T) { + bt := NewBaseTable(SkipRemote(true)) + cfg := &knowhereConfig{} + cfg.init(bt) + + // Set some initial config + indexParams := map[string]interface{}{ + "knowhere.IVF_FLAT.build.nlist": 1024, + "knowhere.IVF_FLAT.build.num_build_thread": 12, + } + + for key, val := range indexParams { + valStr, _ := json.Marshal(val) + bt.Save(key, string(valStr)) + } + + tests := []struct { + name string + indexType string + stage string + inputParams []*commonpb.KeyValuePair + expectedParams map[string]string + }{ + { + name: "IVF_FLAT Build", + indexType: "IVF_FLAT", + stage: BuildStage, + inputParams: []*commonpb.KeyValuePair{ + {Key: "nlist", Value: "128"}, + {Key: "existing_key", Value: "existing_value"}, + }, + expectedParams: map[string]string{ + "existing_key": "existing_value", + "nlist": "128", + "num_build_thread": "12", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := cfg.UpdateIndexParams(tt.indexType, tt.stage, tt.inputParams) + assert.NoError(t, err) + + for key, expectedValue := range tt.expectedParams { + assert.Equal(t, expectedValue, GetKeyFromSlice(result, key), "The value for key %s should match the expected value", key) + } + }) + } +} + +func TestKnowhereConfig_MergeParameter(t *testing.T) { + bt := NewBaseTable(SkipRemote(true)) + cfg := &knowhereConfig{} + cfg.init(bt) + + indexParams := map[string]interface{}{ + "knowhere.IVF_FLAT.build.nlist": 1024, + "knowhere.IVF_FLAT.build.num_build_thread": 12, + } + + for key, val := range indexParams { + valStr, _ := json.Marshal(val) + bt.Save(key, string(valStr)) + } + + tests := []struct { + name string + indexType string + stage string + inputParams map[string]string + expectedParams map[string]string + }{ + { + name: "IVF_FLAT Build", + indexType: "IVF_FLAT", + stage: BuildStage, + inputParams: map[string]string{ + "nlist": "128", + "existing_key": "existing_value", + }, + expectedParams: map[string]string{ + "existing_key": "existing_value", + "nlist": "128", + "num_build_thread": "12", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := cfg.MergeIndexParams(tt.indexType, tt.stage, tt.inputParams) + assert.NoError(t, err) + + for key, expectedValue := range tt.expectedParams { + assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key) + } + }) + } +} + +func TestKnowhereConfig_MergeWithResource(t *testing.T) { + cfg := &knowhereConfig{} + + tests := []struct { + name string + vecFieldSize uint64 + inputParams map[string]string + expectedParams map[string]string + }{ + { + name: "Merge with resource", + vecFieldSize: 1024 * 1024 * 1024, + inputParams: map[string]string{ + "existing_key": "existing_value", + }, + expectedParams: map[string]string{ + "existing_key": "existing_value", + BuildDramBudgetKey: "", // We can't predict the exact value, but it should exist + NumBuildThreadKey: "", // We can't predict the exact value, but it should exist + VecFieldSizeKey: "1.000000", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := cfg.MergeResourceParams(tt.vecFieldSize, BuildStage, tt.inputParams) + assert.NoError(t, err) + + for key, expectedValue := range tt.expectedParams { + if expectedValue != "" { + assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key) + } else { + assert.Contains(t, result, key, "The result should contain the key %s", key) + assert.NotEmpty(t, result[key], "The value for key %s should not be empty", key) + } + } + }) + } +} + +func TestGetKeyFromSlice(t *testing.T) { + indexParams := []*commonpb.KeyValuePair{ + {Key: "key1", Value: "value1"}, + {Key: "key2", Value: "value2"}, + } + + assert.Equal(t, "value1", GetKeyFromSlice(indexParams, "key1")) + assert.Equal(t, "value2", GetKeyFromSlice(indexParams, "key2")) + assert.Equal(t, "", GetKeyFromSlice(indexParams, "non_existent_key")) +}