diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp
index fddef5851834e..b2b58e7929454 100644
--- a/internal/core/src/index/StringIndexMarisa.cpp
+++ b/internal/core/src/index/StringIndexMarisa.cpp
@@ -35,6 +35,7 @@
 #include "index/StringIndexMarisa.h"
 #include "index/Utils.h"
 #include "index/Index.h"
+#include "marisa/base.h"
 #include "storage/Util.h"
 
 namespace milvus::index {
@@ -92,7 +93,7 @@ StringIndexMarisa::BuildWithFieldData(
         }
         total_num_rows += slice_num;
     }
-    trie_.build(keyset);
+    trie_.build(keyset, MARISA_LABEL_ORDER);
 
     // fill str_ids_
     str_ids_.resize(total_num_rows, MARISA_NULL_KEY_ID);
@@ -130,7 +131,7 @@ StringIndexMarisa::Build(size_t n, const std::string* values) {
         }
     }
 
-    trie_.build(keyset);
+    trie_.build(keyset, MARISA_LABEL_ORDER);
     fill_str_ids(n, values);
     fill_offsets();
 
@@ -309,44 +310,103 @@ StringIndexMarisa::Range(std::string value, OpType op) {
     TargetBitmap bitset(count);
     std::vector ids;
     marisa::Agent agent;
+    bool in_lexico_order = in_lexicographic_order();
     switch (op) {
         case OpType::GreaterThan: {
-            while (trie_.predictive_search(agent)) {
-                auto key = std::string(agent.key().ptr(), agent.key().length());
-                if (key > value) {
+            if (in_lexico_order) {
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key > value) {
+                        ids.push_back(agent.key().id());
+                        break;
+                    }
+                };
+                // keys come in lexicographic order, so every following key is greater than value
+                while (trie_.predictive_search(agent)) {
                     ids.push_back(agent.key().id());
                 }
-            };
+            } else {
+                // lexicographic order is not guaranteed, check all values
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key > value) {
+                        ids.push_back(agent.key().id());
+                    }
+                };
+            }
             break;
         }
         case OpType::GreaterEqual: {
-            while (trie_.predictive_search(agent)) {
-                auto key = std::string(agent.key().ptr(), agent.key().length());
-                if (key >= value) {
+            if (in_lexico_order) {
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key >= value) {
+                        ids.push_back(agent.key().id());
+                        break;
+                    }
+                };
+                // keys come in lexicographic order, so every following key is >= value
+                while (trie_.predictive_search(agent)) {
                     ids.push_back(agent.key().id());
                 }
+            } else {
+                // lexicographic order is not guaranteed, check all values
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key >= value) {
+                        ids.push_back(agent.key().id());
+                    }
+                };
             }
             break;
         }
         case OpType::LessThan: {
-            while (trie_.predictive_search(agent)) {
-                auto key = std::string(agent.key().ptr(), agent.key().length());
-                if (key >= value) {
-                    break;
+            if (in_lexico_order) {
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key >= value) {
+                        break;
+                    }
+                    ids.push_back(agent.key().id());
                 }
-                ids.push_back(agent.key().id());
+            } else {
+                // lexicographic order is not guaranteed, check all values
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key < value) {
+                        ids.push_back(agent.key().id());
+                    }
+                };
             }
             break;
         }
         case OpType::LessEqual: {
-            while (trie_.predictive_search(agent)) {
-                auto key = std::string(agent.key().ptr(), agent.key().length());
-                if (key > value) {
-                    break;
+            if (in_lexico_order) {
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key > value) {
+                        break;
+                    }
+                    ids.push_back(agent.key().id());
                 }
-                ids.push_back(agent.key().id());
+            } else {
+                // lexicographic order is not guaranteed, check all values
+                while (trie_.predictive_search(agent)) {
+                    auto key =
+                        std::string(agent.key().ptr(), agent.key().length());
+                    if (key <= value) {
+                        ids.push_back(agent.key().id());
+                    }
+                };
             }
             break;
         }
         default:
             PanicInfo(
@@ -376,6 +436,8 @@ StringIndexMarisa::Range(std::string lower_bound_value,
         return bitset;
     }
 
+    bool in_lexico_order = in_lexicographic_order();
+
     auto common_prefix = GetCommonPrefix(lower_bound_value, upper_bound_value);
     marisa::Agent agent;
     agent.set_query(common_prefix.c_str());
@@ -385,7 +447,12 @@ StringIndexMarisa::Range(std::string lower_bound_value,
             std::string_view(agent.key().ptr(), agent.key().length());
         if (val > upper_bound_value ||
             (!ub_inclusive && val == upper_bound_value)) {
-            break;
+            // stopping early is only valid when the trie is in lexicographic order
+            if (in_lexico_order) {
+                break;
+            } else {
+                continue;
+            }
         }
 
         if (val < lower_bound_value ||
@@ -477,4 +544,12 @@ StringIndexMarisa::Reverse_Lookup(size_t offset) const {
     return std::string(agent.key().ptr(), agent.key().length());
 }
 
+// Reports whether `predictive_search` visits keys in lexicographic order.
+// By default marisa builds tries in `MARISA_WEIGHT_ORDER`, which gives no such
+// guarantee; tries are now built with `MARISA_LABEL_ORDER`, while indexes built
+// earlier may still be in weight order and must be scanned in full.
+bool
+StringIndexMarisa::in_lexicographic_order() const {
+    return trie_.node_order() == MARISA_LABEL_ORDER;
+}
 }  // namespace milvus::index
diff --git a/internal/core/src/index/StringIndexMarisa.h b/internal/core/src/index/StringIndexMarisa.h
index a1227414a3845..72913d6675987 100644
--- a/internal/core/src/index/StringIndexMarisa.h
+++ b/internal/core/src/index/StringIndexMarisa.h
@@ -112,6 +112,11 @@ class StringIndexMarisa : public StringIndex {
     std::vector
     prefix_match(const std::string_view prefix);
 
+    // True when the trie's node order is `MARISA_LABEL_ORDER`, i.e. range
+    // scans may rely on keys being visited in lexicographic order.
+    bool
+    in_lexicographic_order() const;
+
     void
     LoadWithoutAssemble(const BinarySet& binary_set,
                         const Config& config) override;