diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc new file mode 100644 index 00000000000..7f7439c72fb --- /dev/null +++ b/src/commands/cmd_cms.cc @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include +#include + +#include "commander.h" +#include "commands/command_parser.h" +#include "parse_util.h" +#include "server/redis_reply.h" +#include "server/server.h" + +namespace redis { + +/// CMS.INCRBY key item increment [item increment ...] +/// +/// The `key` should be an existing Count-Min Sketch key, +/// otherwise, the command will return an error. +/// +/// The output should be an array of integers, each integer +/// means the counter value after the increment. If the increment +/// overflows, the return value will be `"CMS: INCRBY overflow"`. +class CommandCMSIncrBy final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + if ((args_.size() - 2) % 2 != 0) { + return {Status::RedisParseErr, errWrongNumOfArguments}; + } + redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); + rocksdb::Status s; + // pairs + std::vector elements; + elements.reserve((args_.size() - 2) / 2); + for (size_t i = 2; i < args_.size(); i += 2) { + std::string_view key = args_[i]; + auto parse_result = ParseInt(args_[i + 1]); + if (!parse_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + int64_t value = *parse_result; + elements.emplace_back(redis::CMS::IncrByPair{key, value}); + } + std::vector counters; + s = cms.IncrBy(ctx, args_[1], elements, &counters); + if (s.IsNotFound()) { + return {Status::RedisExecErr, "Key not found"}; + } + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + std::vector strs; + for (uint32_t counter : counters) { + if (counter == std::numeric_limits::max()) { + strs.push_back(redis::Error({Status::RedisExecErr, "CMS: INCRBY overflow"})); + } else { + strs.push_back(redis::Integer(counter)); + } + } + *output = redis::Array(strs); + return Status::OK(); + } +}; + +/// CMS.INFO key +class CommandCMSInfo final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); + rocksdb::Status s; + CMSketch::CMSInfo ret{}; + + s = cms.Info(ctx, args_[1], &ret); + + if (s.IsNotFound()) { + return {Status::RedisExecErr, "Key not found"}; + } + + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + *output = redis::Array({redis::BulkString("width"), redis::Integer(ret.width), redis::BulkString("depth"), + redis::Integer(ret.depth), redis::BulkString("count"), redis::Integer(ret.count)}); + + return Status::OK(); + } +}; + +/// CMS.INITBYDIM key width depth +/// +/// If the key already exists, the command will return an error. +class CommandCMSInitByDim final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); + rocksdb::Status s; + auto width_result = ParseInt(this->args_[2]); + if (!width_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + uint32_t width = *width_result; + + auto depth_result = ParseInt(this->args_[3]); + if (!depth_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + uint32_t depth = *depth_result; + + s = cms.InitByDim(ctx, args_[1], width, depth); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + *output = redis::SimpleString("OK"); + return Status::OK(); + } +}; + +/// CMS.INITBYPROB key error probability +/// +/// If the key already exists, the command will return an error. +class CommandCMSInitByProb final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); + rocksdb::Status s; + + auto error_result = ParseFloat(args_[2]); + if (!error_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + double error = *error_result; + + auto delta_result = ParseFloat(args_[3]); + if (!delta_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + double delta = *delta_result; + + s = cms.InitByProb(ctx, args_[1], error, delta); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + *output = redis::SimpleString("OK"); + return Status::OK(); + } +}; + +/// CMS.MERGE destination numKeys source [source ...] [WEIGHTS weight [weight ...]] +class CommandCMSMerge final : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 2); + destination_ = args[1]; + + StatusOr num_key_result = parser.TakeInt(); + if (!num_key_result || *num_key_result <= 0) { + return {Status::RedisParseErr, "invalid number of source keys"}; + } + num_keys_ = *num_key_result; + + src_keys_.reserve(num_keys_); + for (int i = 0; i < num_keys_; i++) { + auto result = parser.TakeStr(); + if (!result) { + return {Status::RedisParseErr, "Error parsing source key"}; + } + src_keys_.emplace_back(std::move(*result)); + } + + bool weights_found = false; + while (parser.Good()) { + // Parse "WEIGHTS" if exists. + if (parser.EatEqICase("WEIGHTS")) { + if (weights_found) { + return {Status::RedisParseErr, "WEIGHTS option cannot be specified multiple times"}; + } + src_weights_.reserve(num_keys_); + for (int i = 0; i < num_keys_; i++) { + StatusOr weight_result = parser.TakeInt(); + if (!weight_result || *weight_result == 0) { + return {Status::RedisParseErr, "invalid weight value"}; + } + src_weights_.emplace_back(*weight_result); + } + weights_found = true; + } else { + return {Status::RedisParseErr, "Syntax error: unexpected token"}; + } + } + + if (!weights_found) { + src_weights_.resize(num_keys_, 1); + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); + + rocksdb::Status s = cms.MergeUserKeys(ctx, destination_, src_keys_, src_weights_); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + *output = redis::SimpleString("OK"); + return Status::OK(); + } + + private: + Slice destination_; + int num_keys_{0}; + std::vector src_keys_; + std::vector src_weights_; +}; + +/// CMS.QUERY key item [item ...] +class CommandCMSQuery final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); + + std::vector counters{}; + std::vector elements; + + for (size_t i = 2; i < args_.size(); ++i) { + elements.emplace_back(args_[i]); + } + + rocksdb::Status s = cms.Query(ctx, args_[1], elements, counters); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + std::vector output_values; + output_values.reserve(counters.size()); + for (const auto &counter : counters) { + output_values.emplace_back(Integer(counter)); + } + *output = redis::Array(output_values); + + return Status::OK(); + } +}; + +REDIS_REGISTER_COMMANDS(CMS, MakeCmdAttr("cms.incrby", -4, "write", 0, 0, 0), + MakeCmdAttr("cms.info", 2, "read-only", 0, 0, 0), + MakeCmdAttr("cms.initbydim", 4, "write", 0, 0, 0), + MakeCmdAttr("cms.initbyprob", 4, "write", 0, 0, 0), + MakeCmdAttr("cms.merge", -4, "write", 0, 0, 0), + MakeCmdAttr("cms.query", -3, "read-only", 0, 0, 0), ); +} // namespace redis \ No newline at end of file diff --git a/src/commands/commander.h b/src/commands/commander.h index 9d8dd23932d..1bb5bc6079e 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -73,6 +73,7 @@ enum class CommandCategory : uint8_t { Bit, BloomFilter, Cluster, + CMS, Function, Geo, Hash, diff --git a/src/storage/redis_metadata.cc b/src/storage/redis_metadata.cc index 76403faaef3..07ae5703e14 100644 --- a/src/storage/redis_metadata.cc +++ b/src/storage/redis_metadata.cc @@ -326,7 +326,9 @@ bool Metadata::ExpireAt(uint64_t expired_ts) const { return expire < expired_ts; } -bool Metadata::IsSingleKVType() const { return Type() == kRedisString || Type() == kRedisJson; } +bool Metadata::IsSingleKVType() const { + return Type() == kRedisString || Type() == kRedisJson || Type() == kRedisCountMinSketch; +} bool Metadata::IsEmptyableType() const { return IsSingleKVType() || Type() == kRedisStream || Type() == kRedisBloomFilter || Type() == kRedisHyperLogLog; @@ -495,3 +497,41 @@ rocksdb::Status HyperLogLogMetadata::Decode(Slice *input) { return rocksdb::Status::OK(); } + +void CountMinSketchMetadata::Encode(std::string *dst) const { + Metadata::Encode(dst); + PutFixed32(dst, width); + PutFixed32(dst, depth); + PutFixed64(dst, counter); + for (const auto &count : array) { + PutFixed32(dst, count); + } +} + +rocksdb::Status CountMinSketchMetadata::Decode(Slice *input) { + if (auto s = Metadata::Decode(input); !s.ok()) { + return s; + } + if (!GetFixed32(input, &width)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + if (!GetFixed32(input, &depth)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + if (!GetFixed64(input, &counter)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + + size_t array_size = width * depth; + array.resize(array_size); + + for (size_t i = 0; i < array_size; ++i) { + uint32_t count = 0; + if (!GetFixed32(input, &count)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + array[i] = count; + } + + return rocksdb::Status::OK(); +} diff --git a/src/storage/redis_metadata.h b/src/storage/redis_metadata.h index 5590609be37..224b9bcff04 100644 --- a/src/storage/redis_metadata.h +++ b/src/storage/redis_metadata.h @@ -50,6 +50,7 @@ enum RedisType : uint8_t { kRedisBloomFilter = 9, kRedisJson = 10, kRedisHyperLogLog = 11, + kRedisCountMinSketch = 12, }; struct RedisTypes { @@ -91,9 +92,9 @@ enum RedisCommand { kRedisCmdLMove, }; -const std::vector RedisTypeNames = {"none", "string", "hash", "list", - "set", "zset", "bitmap", "sortedint", - "stream", "MBbloom--", "ReJSON-RL", "hyperloglog"}; +const std::vector RedisTypeNames = {"none", "string", "hash", "list", "set", + "zset", "bitmap", "sortedint", "stream", "MBbloom--", + "ReJSON-RL", "hyperloglog", "countminsketch"}; constexpr const char *kErrMsgWrongType = "WRONGTYPE Operation against a key holding the wrong kind of value"; constexpr const char *kErrMsgKeyExpired = "the key was expired"; @@ -335,3 +336,15 @@ class HyperLogLogMetadata : public Metadata { EncodeType encode_type = EncodeType::DENSE; }; + +class CountMinSketchMetadata : public Metadata { + public: + uint32_t width; + uint32_t depth; + uint64_t counter = 0; + std::vector array; + + explicit CountMinSketchMetadata(bool generate_version = true) : Metadata(kRedisCountMinSketch, generate_version) {} + void Encode(std::string *dst) const override; + rocksdb::Status Decode(Slice *input) override; +}; diff --git a/src/types/cms.cc b/src/types/cms.cc new file mode 100644 index 00000000000..ba56b17bf4c --- /dev/null +++ b/src/types/cms.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "cms.h" + +#include +#include +#include +#include + +#include "xxhash.h" + +uint64_t CMSketch::CountMinSketchHash(std::string_view item, uint64_t seed) { + return XXH64(item.data(), item.size(), seed); +} + +CMSketch::CMSketchDimensions CMSketch::CMSDimFromProb(double error, double delta) { + CMSketchDimensions dims; + dims.width = std::ceil(2 / error); + dims.depth = std::ceil(std::log10(delta) / std::log10(0.5)); + return dims; +} + +uint32_t CMSketch::IncrBy(std::string_view item, uint32_t value) { + uint32_t min_count = std::numeric_limits::max(); + + for (size_t i = 0; i < depth_; ++i) { + uint64_t hash = CountMinSketchHash(item, /*seed=*/i); + size_t loc = GetLocationForHash(hash, i); + // Do overflow check + if (array_[loc] > std::numeric_limits::max() - value) { + array_[loc] = std::numeric_limits::max(); + } else { + array_[loc] += value; + } + min_count = std::min(min_count, array_[loc]); + } + counter_ += value; + return min_count; +} + +uint32_t CMSketch::Query(std::string_view item) const { + uint32_t min_count = std::numeric_limits::max(); + + for (size_t i = 0; i < depth_; ++i) { + uint64_t hash = CountMinSketchHash(item, /*seed=*/i); + min_count = std::min(min_count, array_[GetLocationForHash(hash, i)]); + } + return min_count; +} + +Status CMSketch::Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, + std::vector weights) { + // Perform overflow check + if (CMSketch::CheckOverflow(dest, num_keys, cms_array, weights) != 0) { + return {Status::NotOK, "Overflow error."}; + } + + size_t dest_depth = dest->GetDepth(); + size_t dest_width = dest->GetWidth(); + + // Merge source CMSes into the destination CMS + for (size_t i = 0; i < dest_depth; ++i) { + for (size_t j = 0; j < dest_width; ++j) { + int64_t item_count = 0; + for (size_t k = 0; k < num_keys; ++k) { + item_count += static_cast(cms_array[k]->array_[(i * dest_width) + j]) * weights[k]; + } + dest->GetArray()[(i * dest_width) + j] += static_cast(item_count); + } + } + + for (size_t i = 0; i < num_keys; ++i) { + dest->GetCounter() += cms_array[i]->GetCounter() * weights[i]; + } + + return Status::OK(); +} + +int CMSketch::CheckOverflow(CMSketch* dest, size_t quantity, const std::vector& src, + const std::vector& weights) { + int64_t item_count = 0; + int64_t cms_count = 0; + size_t width = dest->GetWidth(); + size_t depth = dest->GetDepth(); + + for (size_t i = 0; i < depth; ++i) { + for (size_t j = 0; j < width; ++j) { + item_count = 0; + for (size_t k = 0; k < quantity; ++k) { + int64_t mul = 0; + + if (__builtin_mul_overflow(src[k]->GetArray()[(i * width) + j], weights[k], &mul) || + (__builtin_add_overflow(item_count, mul, &item_count))) { + return -1; + } + } + + if (item_count < 0 || item_count > UINT32_MAX) { + return -1; + } + } + } + + for (size_t i = 0; i < quantity; ++i) { + int64_t mul = 0; + + if (__builtin_mul_overflow(src[i]->GetCounter(), weights[i], &mul) || + (__builtin_add_overflow(cms_count, mul, &cms_count))) { + return -1; + } + } + + if (cms_count < 0) { + return -1; + } + + return 0; +} diff --git a/src/types/cms.h b/src/types/cms.h new file mode 100644 index 00000000000..4e3cd467af4 --- /dev/null +++ b/src/types/cms.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include +#include + +#include "server/redis_reply.h" + +class CMSketch { + public: + explicit CMSketch(uint32_t width, uint32_t depth, uint64_t counter, std::vector array) + : width_(width), + depth_(depth), + counter_(counter), + array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} + + struct CMSInfo { + uint32_t width; + uint32_t depth; + uint64_t count; + }; + + struct CMSketchDimensions { + uint32_t width; + uint32_t depth; + }; + + static CMSketchDimensions CMSDimFromProb(double error, double delta); + + /// Increment the counter of the given item by the specified increment. + /// + /// \param item The item to increment. Returns UINT32_MAX if the + /// counter overflows. + uint32_t IncrBy(std::string_view item, uint32_t value); + + uint32_t Query(std::string_view item) const; + + static Status Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, + std::vector weights); + + size_t GetLocationForHash(uint64_t hash, size_t i) const { return (hash % width_) + (i * width_); } + + uint64_t& GetCounter() { return counter_; } + std::vector& GetArray() { return array_; } + + const uint64_t& GetCounter() const { return counter_; } + const std::vector& GetArray() const { return array_; } + + uint32_t GetWidth() const { return width_; } + uint32_t GetDepth() const { return depth_; } + + static int CheckOverflow(CMSketch* dest, size_t quantity, const std::vector& src, + const std::vector& weights); + + private: + static uint64_t CountMinSketchHash(std::string_view item, uint64_t seed); + + private: + uint32_t width_; + uint32_t depth_; + uint64_t counter_; + std::vector array_; +}; \ No newline at end of file diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc new file mode 100644 index 00000000000..739d791a6d9 --- /dev/null +++ b/src/types/redis_cms.cc @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "redis_cms.h" + +#include + +#include "cms.h" +#include "rocksdb/status.h" + +namespace redis { + +rocksdb::Status CMS::GetMetadata(engine::Context &ctx, const Slice &ns_key, CountMinSketchMetadata *metadata) { + return Database::GetMetadata(ctx, {kRedisCountMinSketch}, ns_key, metadata); +} + +rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, const std::vector &elements, + std::vector *counters) { + std::string ns_key = AppendNamespacePrefix(user_key); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); + + if (!s.ok()) { + // If s.NotFound() or !ok(), we should return the error directly. + return s; + } + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCountMinSketch); + batch->PutLogData(log_data.Encode()); + + CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); + counters->reserve(elements.size()); + for (const auto &element : elements) { + // TODO(mwish): should we put the parsing result outside? + if (element.value > 0 && + metadata.counter > static_cast(std::numeric_limits::max() - element.value)) { + return rocksdb::Status::InvalidArgument("Overflow error: IncrBy would result in counter overflow"); + } + uint32_t local_counter = cms.IncrBy(element.key, element.value); + counters->push_back(local_counter); + metadata.counter += element.value; + } + + metadata.array = std::move(cms.GetArray()); + + std::string bytes; + metadata.Encode(&bytes); + batch->Put(metadata_cf_handle_, ns_key, bytes); + + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, CMSketch::CMSInfo *ret) { + std::string ns_key = AppendNamespacePrefix(user_key); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); + + if (!s.ok()) { + return s; + } + + ret->width = metadata.width; + ret->depth = metadata.depth; + ret->count = metadata.counter; + return rocksdb::Status::OK(); +} + +rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth) { + std::string ns_key = AppendNamespacePrefix(user_key); + + size_t memory_used = width * depth * sizeof(uint32_t); + // We firstly limit the memory usage to 1MB. + constexpr size_t kMaxMemory = 1 * 1024 * 1024; + + if (memory_used == 0) { + return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); + } + if (memory_used > kMaxMemory) { + return rocksdb::Status::InvalidArgument("Memory usage exceeds 1MB."); + } + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); + + if (s.ok()) { + return rocksdb::Status::InvalidArgument("Key already exists."); + } + + if (!s.IsNotFound()) { + return s; + } + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCountMinSketch); + batch->PutLogData(log_data.Encode()); + + metadata.width = width; + metadata.depth = depth; + metadata.counter = 0; + metadata.array = std::vector(width * depth, 0); + + std::string bytes; + metadata.Encode(&bytes); + batch->Put(metadata_cf_handle_, ns_key, bytes); + + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +}; + +rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta) { + if (error <= 0 || error >= 1) { + return rocksdb::Status::InvalidArgument("Error must be between 0 and 1 (exclusive)."); + } + if (delta <= 0 || delta >= 1) { + return rocksdb::Status::InvalidArgument("Delta must be between 0 and 1 (exclusive)."); + } + CMSketch::CMSketchDimensions dim = CMSketch::CMSDimFromProb(error, delta); + return InitByDim(ctx, user_key, dim.width, dim.depth); +}; + +rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, + const std::vector &src_weights) { + size_t num_sources = src_keys.size(); + if (num_sources == 0) { + return rocksdb::Status::InvalidArgument("No source keys provided for merge."); + } + if (src_weights.size() != num_sources) { + return rocksdb::Status::InvalidArgument("Number of weights must match number of source keys."); + } + + std::string dest_ns_key = AppendNamespacePrefix(user_key); + std::vector ns_keys{dest_ns_key}; + for (const auto &src_key : src_keys) { + ns_keys.emplace_back(AppendNamespacePrefix(src_key)); + } + MultiLockGuard guard(storage_->GetLockManager(), ns_keys); + + CountMinSketchMetadata dest_metadata{}; + rocksdb::Status dest_status = GetMetadata(ctx, dest_ns_key, &dest_metadata); + if (dest_status.IsNotFound()) { + return rocksdb::Status::InvalidArgument("Destination CMS does not exist."); + } + if (!dest_status.ok()) { + return dest_status; + } + + CMSketch dest_cms(dest_metadata.width, dest_metadata.depth, dest_metadata.counter, dest_metadata.array); + + std::vector src_cms_objects; + src_cms_objects.reserve(num_sources); + std::vector weights_long; + weights_long.reserve(num_sources); + + for (size_t i = 0; i < num_sources; ++i) { + const auto &src_ns_key = ns_keys[i + 1]; + CountMinSketchMetadata src_metadata{}; + rocksdb::Status src_status = GetMetadata(ctx, src_ns_key, &src_metadata); + if (!src_status.ok()) { + // TODO(mwish): check the not found syntax here. + if (src_status.IsNotFound()) { + return rocksdb::Status::InvalidArgument("Source CMS key not found."); + } + return src_status; + } + + if (src_metadata.width != dest_metadata.width || src_metadata.depth != dest_metadata.depth) { + return rocksdb::Status::InvalidArgument("Source CMS dimensions do not match destination CMS."); + } + + CMSketch src_cms(src_metadata.width, src_metadata.depth, src_metadata.counter, src_metadata.array); + src_cms_objects.emplace_back(std::move(src_cms)); + + weights_long.push_back(static_cast(src_weights[i])); + } + // Initialize the destination CMS with the source CMSes after initializations + // since vector might resize and reallocate memory. + std::vector src_cms_pointers(num_sources); + for (size_t i = 0; i < num_sources; ++i) { + src_cms_pointers[i] = &src_cms_objects[i]; + } + auto merge_result = CMSketch::Merge(&dest_cms, num_sources, src_cms_pointers, weights_long); + if (!merge_result.IsOK()) { + return rocksdb::Status::InvalidArgument("Merge operation failed due to overflow or invalid dimensions."); + } + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCountMinSketch); + batch->PutLogData(log_data.Encode()); + + dest_metadata.counter = dest_cms.GetCounter(); + dest_metadata.array = dest_cms.GetArray(); + + std::string encoded_metadata; + dest_metadata.Encode(&encoded_metadata); + batch->Put(metadata_cf_handle_, dest_ns_key, encoded_metadata); + + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, + std::vector &counters) { + std::string ns_key = AppendNamespacePrefix(user_key); + counters.resize(elements.size(), 0); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); + if (!s.ok()) { + // If s.NotFound() or !ok(), we should return the error directly. + return s; + } + + CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); + for (size_t i = 0; i < elements.size(); ++i) { + counters[i] = cms.Query(elements[i]); + } + + return rocksdb::Status::OK(); +} + +} // namespace redis diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h new file mode 100644 index 00000000000..f8fff5e21a0 --- /dev/null +++ b/src/types/redis_cms.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "cms.h" +#include "storage/redis_db.h" +#include "storage/redis_metadata.h" + +namespace redis { + +class CMS : public Database { + public: + explicit CMS(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} + + struct IncrByPair { + std::string_view key; + int64_t value; + }; + + /// Increment the counter of the given item(s) by the specified increment(s). + /// + /// \param[out] counters The counter values after the increment, if the value is UINT32_MAX, + /// it means the item does overflow. + rocksdb::Status IncrBy(engine::Context &ctx, const Slice &user_key, const std::vector &elements, + std::vector *counters); + rocksdb::Status Info(engine::Context &ctx, const Slice &user_key, CMSketch::CMSInfo *ret); + rocksdb::Status InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth); + rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta); + rocksdb::Status Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, + std::vector &counters); + rocksdb::Status MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, + const std::vector &src_weights); + + private: + [[nodiscard]] rocksdb::Status GetMetadata(engine::Context &ctx, const Slice &ns_key, + CountMinSketchMetadata *metadata); +}; + +} // namespace redis diff --git a/tests/cppunit/types/cms_test.cc b/tests/cppunit/types/cms_test.cc new file mode 100644 index 00000000000..cba7f6033b0 --- /dev/null +++ b/tests/cppunit/types/cms_test.cc @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "types/cms.h" + +#include + +#include + +#include "test_base.h" +#include "types/redis_cms.h" + +class RedisCMSketchTest : public TestBase { + protected: + using IncrByPair = redis::CMS::IncrByPair; + explicit RedisCMSketchTest() : TestBase() { cms_ = std::make_unique(storage_.get(), "cms_ns"); } + ~RedisCMSketchTest() override = default; + + void SetUp() override { + TestBase::SetUp(); + [[maybe_unused]] auto s = cms_->Del(*ctx_, "cms"); + for (int x = 1; x <= 3; x++) { + s = cms_->Del(*ctx_, "cms" + std::to_string(x)); + } + } + + void TearDown() override { + TestBase::TearDown(); + [[maybe_unused]] auto s = cms_->Del(*ctx_, "cms"); + for (int x = 1; x <= 3; x++) { + s = cms_->Del(*ctx_, "cms" + std::to_string(x)); + } + } + + std::unique_ptr cms_; +}; + +TEST_F(RedisCMSketchTest, CMSInitByDim) { + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 100, 5).ok()); + CMSketch::CMSInfo info; + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); + ASSERT_EQ(info.width, 100); + ASSERT_EQ(info.depth, 5); + ASSERT_EQ(info.count, 0); +} + +TEST_F(RedisCMSketchTest, CMSIncrBy) { + std::vector elements = {{"apple", 2}, {"banana", 3}, {"cherry", 1}}; + // TODO(mwish): check the responses. + std::vector responses; + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 100, 5).ok()); + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms", elements, &responses).ok()); + + std::vector counts; + ASSERT_TRUE(cms_->Query(*ctx_, "cms", {"apple", "banana", "cherry"}, counts).ok()); + + ASSERT_EQ(counts[0], 2); + ASSERT_EQ(counts[1], 3); + ASSERT_EQ(counts[2], 1); + + CMSketch::CMSInfo info; + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); + ASSERT_EQ(info.count, 6); +} + +TEST_F(RedisCMSketchTest, CMSQuery) { + std::vector elements = {{"orange", 5}, {"grape", 3}, {"melon", 2}}; + std::vector responses; + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 100, 5).ok()); + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms", elements, &responses).ok()); + + std::vector counts; + ASSERT_TRUE(cms_->Query(*ctx_, "cms", {"orange", "grape", "melon", "nonexistent"}, counts).ok()); + + ASSERT_EQ(counts[0], 5); + ASSERT_EQ(counts[1], 3); + ASSERT_EQ(counts[2], 2); + ASSERT_EQ(counts[3], 0); +} + +TEST_F(RedisCMSketchTest, CMSInfo) { + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 200, 10).ok()); + + CMSketch::CMSInfo info; + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); + + ASSERT_EQ(info.width, 200); + ASSERT_EQ(info.depth, 10); + ASSERT_EQ(info.count, 0); +} + +TEST_F(RedisCMSketchTest, CMSInitByProb) { + ASSERT_TRUE(cms_->InitByProb(*ctx_, "cms", 0.001, 0.1).ok()); + + CMSketch::CMSInfo info; + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); + + ASSERT_EQ(info.width, 2000); + ASSERT_EQ(info.depth, 4); + ASSERT_EQ(info.count, 0); +} + +TEST_F(RedisCMSketchTest, CMSMultipleKeys) { + std::vector elements1 = {{"apple", 2}, {"banana", 3}}; + std::vector elements2 = {{"cherry", 1}, {"date", 4}}; + std::vector responses; + + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms1", 100, 5).ok()); + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms2", 100, 5).ok()); + + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms1", elements1, &responses).ok()); + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms2", elements2, &responses).ok()); + + std::vector counts1, counts2; + ASSERT_TRUE(cms_->Query(*ctx_, "cms1", {"apple", "banana"}, counts1).ok()); + ASSERT_TRUE(cms_->Query(*ctx_, "cms2", {"cherry", "date"}, counts2).ok()); + + ASSERT_EQ(counts1[0], 2); + ASSERT_EQ(counts1[1], 3); + ASSERT_EQ(counts2[0], 1); + ASSERT_EQ(counts2[1], 4); +} diff --git a/tests/gocase/unit/cms/cms_test.go b/tests/gocase/unit/cms/cms_test.go new file mode 100644 index 00000000000..aabf5692d81 --- /dev/null +++ b/tests/gocase/unit/cms/cms_test.go @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package cms + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/stretchr/testify/require" +) + +func TestCountMinSketch(t *testing.T) { + // Define configuration options if needed. + // Adjust or add more configurations as per your CMS requirements. + configOptions := []util.ConfigOptions{ + { + Name: "txn-context-enabled", + Options: []string{"yes", "no"}, + ConfigType: util.YesNo, + }, + // Add more configuration options here if necessary + } + + // Generate all combinations of configurations + configsMatrix, err := util.GenerateConfigsMatrix(configOptions) + require.NoError(t, err) + + // Iterate over each configuration and run CMS tests + for _, configs := range configsMatrix { + testCMS(t, configs) + } +} + +// testCMS sets up the server with the given configurations and runs CMS tests +func testCMS(t *testing.T, configs util.KvrocksServerConfigs) { + srv := util.StartServer(t, configs) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + // Run individual CMS test cases + t.Run("basic add", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + + res := rdb.Do(ctx, "cms.initbydim", "cmsA", 100, 10) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + require.Equal(t, []interface{}{"width", int64(100), "depth", int64(10), "count", int64(0)}, rdb.Do(ctx, "cms.info", "cmsA").Val()) + + res = rdb.Do(ctx, "cms.incrby", "cmsA", "foo", 1) + require.NoError(t, res.Err()) + addCnt, err := res.Result() + + require.NoError(t, err) + require.Equal(t, string("OK"), addCnt) + + card, err := rdb.Do(ctx, "cms.query", "cmsA", "foo").Result() + require.NoError(t, err) + require.Equal(t, []interface{}([]interface{}{"1"}), card, "The queried count for 'foo' should be 1") + }) + + t.Run("cms.initbyprob - Initialization with Probability Parameters", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + res := rdb.Do(ctx, "cms.initbyprob", "cmsB", "0.001", "0.1") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + infoRes := rdb.Do(ctx, "cms.info", "cmsB") + require.NoError(t, infoRes.Err()) + infoSlice, ok := infoRes.Val().([]interface{}) + require.True(t, ok, "Expected cms.info to return a slice") + + infoMap := make(map[string]interface{}) + for i := 0; i < len(infoSlice); i += 2 { + key, ok1 := infoSlice[i].(string) + value, ok2 := infoSlice[i+1].(int64) + require.True(t, ok1 && ok2, "Expected cms.info keys to be strings and values to be int64") + infoMap[key] = value + } + + require.Equal(t, int64(2000), infoMap["width"]) + require.Equal(t, int64(4), infoMap["depth"]) + require.Equal(t, int64(0), infoMap["count"]) + }) + + t.Run("cms.incrby - Basic Increment Operations", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + elements := map[string]string{"apple": "7", "orange": "15", "mango": "3"} + for key, count := range elements { + res = rdb.Do(ctx, "cms.incrby", "cmsA", key, count) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + } + + for key, expected := range elements { + res = rdb.Do(ctx, "cms.query", "cmsA", key) + require.NoError(t, res.Err()) + countSlice, ok := res.Val().([]interface{}) + require.True(t, ok, "Expected cms.query to return a slice") + require.Len(t, countSlice, 1, "Expected cms.query to return a single count") + count, ok := countSlice[0].(string) + require.True(t, ok, "Expected count to be a string") + require.Equal(t, expected, count, fmt.Sprintf("Count for key '%s' mismatch", key)) + } + + // Verify total count + infoRes := rdb.Do(ctx, "cms.info", "cmsA") + require.NoError(t, infoRes.Err()) + infoSlice, ok := infoRes.Val().([]interface{}) + require.True(t, ok, "Expected cms.info to return a slice") + + // Convert the slice to a map for easier access + infoMap := make(map[string]interface{}) + for i := 0; i < len(infoSlice); i += 2 { + key, ok1 := infoSlice[i].(string) + value, ok2 := infoSlice[i+1].(int64) + require.True(t, ok1 && ok2, "Expected cms.info keys to be strings and values to be int64") + infoMap[key] = value + } + + total := int64(0) + for _, cntStr := range elements { + cnt, err := strconv.ParseInt(cntStr, 10, 64) + require.NoError(t, err, "Failed to parse count string to int64") + total += cnt + } + require.Equal(t, total, infoMap["count"], "Total count mismatch") + }) + + // Increment operation on a non-existent CMS + t.Run("cms.incrby - Increment Non-Existent CMS", func(t *testing.T) { + res := rdb.Do(ctx, "cms.incrby", "nonexistent_cms", "apple", "5") + require.Error(t, res.Err()) + }) + + t.Run("cms.query - Query Non-Existent CMS", func(t *testing.T) { + // Attempt to query a CMS that doesn't exist + res := rdb.Do(ctx, "cms.query", "nonexistent_cms", "foo") + require.Error(t, res.Err()) + require.Contains(t, res.Err().Error(), "ERR NotFound:") + }) + + // Query for non-existent element + t.Run("cms.query - Query Non-Existent Element", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Query a non-existent element + res = rdb.Do(ctx, "cms.query", "cmsA", "nonexistent") + require.NoError(t, res.Err()) + countSlice, ok := res.Val().([]interface{}) + require.True(t, ok, "Expected cms.query to return a slice") + require.Len(t, countSlice, 1, "Expected cms.query to return a single count") + count, ok := countSlice[0].(string) + require.True(t, ok, "Expected count to be a string") + require.Equal(t, "0", count, "Non-existent element should return count '0'") + }) + + // Merging CMS structures + t.Run("cms.merge - Basic Merge Operation", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + res = rdb.Do(ctx, "cms.initbydim", "cmsB", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Increment elements in cmsA + elementsA := map[string]string{"apple": "7", "orange": "15", "mango": "3"} + for key, count := range elementsA { + res = rdb.Do(ctx, "cms.incrby", "cmsA", key, count) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + } + + // Increment elements in cmsB + elementsB := map[string]string{"banana": "5", "apple": "4", "grape": "6"} + for key, count := range elementsB { + res = rdb.Do(ctx, "cms.incrby", "cmsB", key, count) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + } + + // Merge cmsB into cmsA with weights + res = rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Query counts after merge + expectedCounts := map[string]string{"apple": "11", "orange": "15", "mango": "3", "banana": "5", "grape": "6"} + for key, expected := range expectedCounts { + res = rdb.Do(ctx, "cms.query", "cmsA", key) + require.NoError(t, res.Err()) + countSlice, ok := res.Val().([]interface{}) + require.True(t, ok, "Expected cms.query to return a slice") + require.Len(t, countSlice, 1, "Expected cms.query to return a single count") + count, ok := countSlice[0].(string) + require.True(t, ok, "Expected count to be a string") + require.Equal(t, expected, count, fmt.Sprintf("Count for key '%s' mismatch after merge", key)) + } + + infoRes := rdb.Do(ctx, "cms.info", "cmsA") + require.NoError(t, infoRes.Err()) + infoSlice, ok := infoRes.Val().([]interface{}) + require.True(t, ok, "Expected cms.info to return a slice") + + infoMap := make(map[string]interface{}) + for i := 0; i < len(infoSlice); i += 2 { + key, ok1 := infoSlice[i].(string) + value, ok2 := infoSlice[i+1].(int64) + require.True(t, ok1 && ok2, "Expected cms.info keys to be strings and values to be int64") + infoMap[key] = value + } + + expectedTotal := int64(40) + require.Equal(t, expectedTotal, infoMap["count"], "Total count mismatch after merge") + }) + + t.Run("cms.merge - Merge with Uninitialized Destination CMS", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + // Initialize only the source CMS + res := rdb.Do(ctx, "cms.initbydim", "cmsB", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Attempt to merge cmsB into cmsA without initializing cmsA + res = rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.Error(t, res.Err(), "Merging into an uninitialized destination CMS should return an error") + require.Contains(t, res.Err().Error(), "Destination CMS does not exist.", "Expected error message to contain 'Destination CMS does not exist.'") + }) + + t.Run("cms.merge - Merge with Uninitialized Source CMS", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + // Initialize only the destination CMS + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Attempt to merge a non-initialized cmsB into cmsA + res = rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.Error(t, res.Err(), "Merging from an uninitialized source CMS should return an error") + require.Contains(t, res.Err().Error(), "Source CMS key not found.", "Expected error message to contain 'Source CMS key not found.'") + }) + + t.Run("cms.merge - Merge with Both Destination and Source CMS Uninitialized", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + // Attempt to merge two non-initialized CMSes + res := rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.Error(t, res.Err(), "Merging with both destination and source CMS uninitialized should return an error") + errMsg := res.Err().Error() + require.Contains(t, errMsg, "Destination CMS does not exist.") + }) +}