Skip to content

Commit

Permalink
add config check
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy committed Oct 15, 2024
1 parent 0ec434d commit 1aa6181
Show file tree
Hide file tree
Showing 13 changed files with 253 additions and 66 deletions.
2 changes: 2 additions & 0 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ const float defaultRangeFilter = 1.0f / 0.0;

class BaseConfig : public Config {
public:
CFG_INT dim; // just used for config verify
CFG_STRING metric_type;
CFG_INT k;
CFG_INT num_build_thread;
Expand Down Expand Up @@ -535,6 +536,7 @@ class BaseConfig : public Config {
CFG_FLOAT bm25_b;
CFG_FLOAT bm25_avgdl;
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(dim).allow_empty_without_default().description("vector dim").for_train();
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
.set_default("L2")
.description("metric type")
Expand Down
14 changes: 14 additions & 0 deletions include/knowhere/index/index_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct Resource {
DEFINE_HAS_STATIC_FUNC(StaticCreateConfig)
DEFINE_HAS_STATIC_FUNC(StaticEstimateLoadResource)
DEFINE_HAS_STATIC_FUNC(StaticHasRawData)
DEFINE_HAS_STATIC_FUNC(StaticConfigCheck)

template <typename DataType>
class IndexStaticFaced {
Expand All @@ -61,6 +62,10 @@ class IndexStaticFaced {
static std::unique_ptr<BaseConfig>
CreateConfig(const knowhere::IndexType& indexType, const knowhere::IndexVersion& version);

static knowhere::Status
ConfigCheck(const knowhere::IndexType& indexType, const knowhere::IndexVersion& version,
const knowhere::Json& params, std::string& msg);

/**
* @brief estimate the memory and disk resource usage before index loading by index params
* @param indexType vector index type (HNSW, IVFFLAT, etc)
Expand Down Expand Up @@ -103,6 +108,11 @@ class IndexStaticFaced {
staticHasRawDataMap[indexType] = VecIndexNode::StaticHasRawData;
}

if constexpr (has_static_StaticConfigCheck<VecIndexNode,
decltype(IndexStaticFaced<DataType>::InternalConfigCheck)>::value) {
staticConfigCheckMap[indexType] = VecIndexNode::StaticConfigCheck;
}

return Instance();
}

Expand All @@ -117,12 +127,16 @@ class IndexStaticFaced {
static bool
InternalStaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version);

static knowhere::Status
InternalConfigCheck(const knowhere::BaseConfig& config, const IndexVersion& version, std::string& msg);

static std::unique_ptr<BaseConfig>
InternalStaticCreateConfig();

std::map<std::string, std::function<decltype(InternalStaticCreateConfig)>> staticCreateConfigMap;
std::map<std::string, std::function<decltype(InternalStaticHasRawData)>> staticHasRawDataMap;
std::map<std::string, std::function<decltype(InternalEstimateLoadResource)>> staticEstimateLoadResourceMap;
std::map<std::string, std::function<decltype(InternalConfigCheck)>> staticConfigCheckMap;
};

} // namespace knowhere
Expand Down
75 changes: 20 additions & 55 deletions src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "index/hnsw/hnsw_config.h"
#include "index/ivf/ivf_config.h"
#include "index/sparse/sparse_inverted_index_config.h"
#include "knowhere/index/index_factory.h"
#include "knowhere/log.h"

namespace knowhere {
Expand Down Expand Up @@ -86,12 +87,26 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg
if (json.find(it.first) != json.end() && json[it.first].is_string()) {
if (std::get_if<Entry<CFG_INT>>(&var)) {
std::string::size_type sz;
auto value_str = json[it.first].get<std::string>();
CFG_INT::value_type v = std::stoi(value_str.c_str(), &sz);
if (sz < value_str.length()) {
KNOWHERE_THROW_MSG(std::string("wrong data type in json ") + value_str);
auto key_str = it.first;
auto value_str = json[key_str].get<std::string>();
try {
int64_t v = std::stoll(value_str, &sz);
if (sz < value_str.length()) {
KNOWHERE_THROW_MSG("wrong data type in json, key: '" + key_str + "', value: '" + value_str +
"'");
}
if (v < std::numeric_limits<CFG_INT::value_type>::min() ||
v > std::numeric_limits<CFG_INT::value_type>::max()) {
*err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
return knowhere::Status::invalid_value_in_json;
}
json[key_str] = static_cast<CFG_INT::value_type>(v);
} catch (const std::out_of_range&) {
*err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
return knowhere::Status::invalid_value_in_json;
} catch (const std::invalid_argument&) {
KNOWHERE_THROW_MSG("invalid integer value, key: '" + key_str + "', value: '" + value_str + "'");
}
json[it.first] = v;
}
if (std::get_if<Entry<CFG_FLOAT>>(&var)) {
CFG_FLOAT::value_type v = std::stof(json[it.first].get<std::string>().c_str());
Expand Down Expand Up @@ -119,53 +134,3 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg
}

} // namespace knowhere

extern "C" __attribute__((visibility("default"))) int
CheckConfig(int index_type, char const* str, int n, int param_type);

int
CheckConfig(int index_type, const char* str, int n, int param_type) {
if (!str || n <= 0) {
return int(knowhere::Status::invalid_args);
}
knowhere::Json json = knowhere::Json::parse(str, str + n);
std::unique_ptr<knowhere::Config> cfg;

switch (index_type) {
case 0:
cfg = std::make_unique<knowhere::FlatConfig>();
break;
case 1:
cfg = std::make_unique<knowhere::DiskANNConfig>();
break;
case 2:
cfg = std::make_unique<knowhere::HnswConfig>();
break;
case 3:
cfg = std::make_unique<knowhere::IvfFlatConfig>();
break;
case 4:
cfg = std::make_unique<knowhere::IvfPqConfig>();
break;
case 5:
cfg = std::make_unique<knowhere::GpuRaftCagraConfig>();
break;
case 6:
cfg = std::make_unique<knowhere::GpuRaftIvfPqConfig>();
break;
case 7:
cfg = std::make_unique<knowhere::GpuRaftIvfFlatConfig>();
break;
case 8:
cfg = std::make_unique<knowhere::GpuRaftBruteForceConfig>();
break;
default:
return int(knowhere::Status::invalid_args);
}

auto res = knowhere::Config::FormatAndCheck(*cfg, json, nullptr);
if (res != knowhere::Status::success) {
return int(res);
}
return int(knowhere::Config::Load(*cfg, json, knowhere::PARAM_TYPE(param_type), nullptr));
}
15 changes: 14 additions & 1 deletion src/index/gpu_raft/gpu_raft_brute_force_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,20 @@

namespace knowhere {

struct GpuRaftBruteForceConfig : public BaseConfig {};
struct GpuRaftBruteForceConfig : public BaseConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
auto legal_metric_list = std::vector<std::string>{"L2", "IP"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
return Status::invalid_metric_type;
}
}
return Status::success;
}
};

[[nodiscard]] inline auto
to_raft_knowhere_config(GpuRaftBruteForceConfig const& cfg) {
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu_raft/gpu_raft_cagra_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct GpuRaftCagraConfig : public BaseConfig {
KNOWHERE_CONFIG_DECLARE_FIELD(max_queries).description("maximum batch size").set_default(0).for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(build_algo)
.description("algorithm used to build knn graph")
.set_default("IVF_PQ")
.set_default("NN_DESCENT")
.for_train();
KNOWHERE_CONFIG_DECLARE_FIELD(search_algo)
.description("algorithm used for search")
Expand Down
13 changes: 13 additions & 0 deletions src/index/gpu_raft/gpu_raft_ivf_flat_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ struct GpuRaftIvfFlatConfig : public IvfFlatConfig {
.set_default(false)
.for_train();
}

Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
auto legal_metric_list = std::vector<std::string>{"L2", "IP"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
return Status::invalid_metric_type;
}
}
return Status::success;
}
};

[[nodiscard]] inline auto
Expand Down
13 changes: 13 additions & 0 deletions src/index/gpu_raft/gpu_raft_ivf_pq_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ struct GpuRaftIvfPqConfig : public IvfPqConfig {
.set_default(1.0f)
.for_search();
}

Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
auto legal_metric_list = std::vector<std::string>{"L2", "IP"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
return Status::invalid_metric_type;
}
}
return Status::success;
}
};

[[nodiscard]] inline auto
Expand Down
27 changes: 27 additions & 0 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,33 @@ class HnswIndexNode : public IndexNode {
return Status::success;
}

static Status
StaticConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) {
auto hnsw_cfg = static_cast<const HnswConfig&>(cfg);

if (paramType == PARAM_TYPE::TRAIN) {
if constexpr (KnowhereFloatTypeCheck<DataType>::value) {
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2) ||
IsMetricType(hnsw_cfg.metric_type.value(), metric::IP) ||
IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE)) {
} else {
msg = "metric type " + hnsw_cfg.metric_type.value() +
" not found or not supported, supported: [L2 IP COSINE]";
return Status::invalid_metric_type;
}
} else {
if (IsMetricType(hnsw_cfg.metric_type.value(), metric::HAMMING) ||
IsMetricType(hnsw_cfg.metric_type.value(), metric::JACCARD)) {
} else {
msg = "metric type " + hnsw_cfg.metric_type.value() +
" not found or not supported, supported: [HAMMING JACCARD]";
return Status::invalid_metric_type;
}
}
}
return Status::success;
}

Status
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
if (!index_) {
Expand Down
26 changes: 26 additions & 0 deletions src/index/index_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ IndexStaticFaced<DataType>::CreateConfig(const IndexType& indexType, const Index
return std::make_unique<BaseConfig>();
}

template <typename DataType>
knowhere::Status
IndexStaticFaced<DataType>::ConfigCheck(const IndexType& indexType, const IndexVersion& version, const Json& params,
std::string& msg) {
auto cfg = IndexStaticFaced<DataType>::CreateConfig(indexType, version);

const Status status = LoadStaticConfig(cfg.get(), params, knowhere::PARAM_TYPE::TRAIN, "ConfigCheck", &msg);
if (status != Status::success) {
LOG_KNOWHERE_ERROR_ << "Load Config failed, msg = " << msg;
return status;
}

if (Instance().staticConfigCheckMap.find(indexType) != Instance().staticConfigCheckMap.end()) {
return Instance().staticConfigCheckMap[indexType](*cfg, version, msg);
}

return knowhere::Status::success;
}

template <typename DataType>
expected<Resource>
IndexStaticFaced<DataType>::EstimateLoadResource(const knowhere::IndexType& indexType,
Expand Down Expand Up @@ -123,6 +142,13 @@ IndexStaticFaced<DataType>::InternalStaticCreateConfig() {
return std::unique_ptr<BaseConfig>();
}

template <typename DataType>
knowhere::Status
IndexStaticFaced<DataType>::InternalConfigCheck(const BaseConfig& config, const IndexVersion& version,
std::string& msg) {
return knowhere::Status::success;
}

template class IndexStaticFaced<knowhere::fp32>;
template class IndexStaticFaced<knowhere::fp16>;
template class IndexStaticFaced<knowhere::bf16>;
Expand Down
27 changes: 27 additions & 0 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,33 @@ class IvfIndexNode : public IndexNode {
expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override;

static Status
StaticConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) {
auto ivf_cfg = static_cast<const IvfConfig&>(cfg);

if (paramType == PARAM_TYPE::TRAIN) {
if constexpr (KnowhereFloatTypeCheck<DataType>::value) {
if (IsMetricType(ivf_cfg.metric_type.value(), metric::L2) ||
IsMetricType(ivf_cfg.metric_type.value(), metric::IP) ||
IsMetricType(ivf_cfg.metric_type.value(), metric::COSINE)) {
} else {
msg = "metric type " + ivf_cfg.metric_type.value() +
" not found or not supported, supported: [L2 IP COSINE]";
return Status::invalid_metric_type;
}
} else {
if (IsMetricType(ivf_cfg.metric_type.value(), metric::HAMMING) ||
IsMetricType(ivf_cfg.metric_type.value(), metric::JACCARD)) {
} else {
msg = "metric type " + ivf_cfg.metric_type.value() +
" not found or not supported, supported: [HAMMING JACCARD]";
return Status::invalid_metric_type;
}
}
}
return Status::success;
}

static bool
CommonHasRawData() {
if constexpr (std::is_same<faiss::IndexIVFFlat, IndexType>::value) {
Expand Down
Loading

0 comments on commit 1aa6181

Please sign in to comment.