From 0f7c0384c626d3cf0f2f3a9f030c483aed86d255 Mon Sep 17 00:00:00 2001 From: Louis Gandelin Date: Thu, 18 May 2023 14:07:03 +0200 Subject: [PATCH] Added ZMScore command Added tests in integration --- cmd_sorted_set.go | 44 ++++++++++++++++++++ cmd_sorted_set_test.go | 73 ++++++++++++++++++++++++++++++++++ db.go | 10 +++++ direct.go | 18 +++++++++ integration/sorted_set_test.go | 21 ++++++++++ 5 files changed, 166 insertions(+) diff --git a/cmd_sorted_set.go b/cmd_sorted_set.go index ab25c257..8e7e4ee8 100644 --- a/cmd_sorted_set.go +++ b/cmd_sorted_set.go @@ -33,6 +33,7 @@ func commandsSortedSet(m *Miniredis) { m.srv.Register("ZREVRANGEBYSCORE", m.makeCmdZrangebyscore(true)) m.srv.Register("ZREVRANK", m.makeCmdZrank(true)) m.srv.Register("ZSCORE", m.cmdZscore) + m.srv.Register("ZMSCORE", m.cmdZMscore) m.srv.Register("ZUNION", m.cmdZunion) m.srv.Register("ZUNIONSTORE", m.cmdZunionstore) m.srv.Register("ZSCAN", m.cmdZscan) @@ -1044,6 +1045,49 @@ func (m *Miniredis) cmdZscore(c *server.Peer, cmd string, args []string) { }) } +// ZMSCORE +func (m *Miniredis) cmdZMscore(c *server.Peer, cmd string, args []string) { + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } + + key, members := args[0], args[1:] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if !db.exists(key) { + c.WriteLen(len(members)) + for range members { + c.WriteNull() + } + return + } + + if db.t(key) != "zset" { + c.WriteError(ErrWrongType.Error()) + return + } + + c.WriteLen(len(args) - 1) + for _, member := range members { + if !db.ssetExists(key, member) { + c.WriteNull() + continue + } + c.WriteFloat(db.ssetScore(key, member)) + } + }) +} + // parseFloatRange handles ZRANGEBYSCORE floats. They are inclusive unless the // string starts with '(' func parseFloatRange(s string) (float64, bool, error) { diff --git a/cmd_sorted_set_test.go b/cmd_sorted_set_test.go index f6404715..7cb73c37 100644 --- a/cmd_sorted_set_test.go +++ b/cmd_sorted_set_test.go @@ -1109,6 +1109,79 @@ func TestSortedSetScore(t *testing.T) { }) } +// Test ZMSCORE +func TestSortedSetMultiScore(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + s.ZAdd("z", 1, "one") + s.ZAdd("z", 2, "two") + s.ZAdd("z", 2, "zwei") + + // One member only + mustDo(t, c, + "ZMSCORE", "z", "two", + proto.Strings("2"), + ) + + // Two members + mustDo(t, c, + "ZMSCORE", "z", "one", "two", + proto.Strings("1", "2"), + ) + + // Three members + mustDo(t, c, + "ZMSCORE", "z", "one", "two", "zwei", + proto.Strings("1", "2", "2"), + ) + + // No such member + mustDo(t, c, "ZMSCORE", "z", "nosuch", + proto.Array(proto.Nil), + ) + + // One member exists, one doesn't + mustDo(t, c, "ZMSCORE", "z", "nosuch", "two", + proto.Array(proto.Nil, proto.String("2")), + ) + + // No such key + mustNil(t, c, + "ZSCORE", "nosuch", "nosuch", + ) + + // Direct + { + s.ZAdd("z2", 1, "one") + s.ZAdd("z2", 2, "two") + scores, err := s.ZMScore("z2", "one", "two") + ok(t, err) + equals(t, []float64{1.0, 2.0}, scores) + } + + t.Run("errors", func(t *testing.T) { + mustDo(t, c, + "ZMSCORE", + proto.Error(errWrongNumber("zmscore")), + ) + mustDo(t, c, + "ZMSCORE", "key", + proto.Error(errWrongNumber("zmscore")), + ) + + s.Set("str", "value") + mustDo(t, c, + "ZMSCORE", "str", "aap", + proto.Error(msgWrongType), + ) + }) +} + // Test ZRANGEBYLEX, ZREVRANGEBYLEX, ZLEXCOUNT func TestSortedSetRangeByLex(t *testing.T) { s, err := Run() diff --git a/db.go b/db.go index d6df96a2..388d5fbb 100644 --- a/db.go +++ b/db.go @@ -468,6 +468,16 @@ func (db *RedisDB) ssetScore(key, member string) float64 { return ss[member] } +// ssetMScore returns multiple scores of a list of members in a sorted set. +func (db *RedisDB) ssetMScore(key string, members []string) []float64 { + scores := make([]float64, 0, len(members)) + ss := db.sortedsetKeys[key] + for _, member := range members { + scores = append(scores, ss[member]) + } + return scores +} + // ssetRem is sorted set key delete. func (db *RedisDB) ssetRem(key, member string) bool { ss := db.sortedsetKeys[key] diff --git a/direct.go b/direct.go index cd2323c6..607c4e08 100644 --- a/direct.go +++ b/direct.go @@ -666,6 +666,24 @@ func (db *RedisDB) ZScore(k, member string) (float64, error) { return db.ssetScore(k, member), nil } +// ZScore gives scores of a list of members in a sorted set. +func (m *Miniredis) ZMScore(k string, members ...string) ([]float64, error) { + return m.DB(m.selectedDB).ZMScore(k, members) +} + +func (db *RedisDB) ZMScore(k string, members []string) ([]float64, error) { + db.master.Lock() + defer db.master.Unlock() + + if !db.exists(k) { + return nil, ErrKeyNotFound + } + if db.t(k) != "zset" { + return nil, ErrWrongType + } + return db.ssetMScore(k, members), nil +} + // XAdd adds an entry to a stream. `id` can be left empty or be '*'. // If a value is given normal XADD rules apply. Values should be an even // length. diff --git a/integration/sorted_set_test.go b/integration/sorted_set_test.go index 6633f90a..e6312d56 100644 --- a/integration/sorted_set_test.go +++ b/integration/sorted_set_test.go @@ -891,3 +891,24 @@ func TestZrandmember(t *testing.T) { c.Error("not an integer", "ZRANDMEMBER", "q", "two") }) } + +func TestZMScore(t *testing.T) { + testRaw(t, func(c *client) { + c.Do("ZADD", "q", "1.0", "key1") + c.Do("ZADD", "q", "2.0", "key2") + c.Do("ZADD", "q", "3.0", "key3") + c.Do("ZADD", "q", "4.0", "key4") + c.Do("ZADD", "q", "5.0", "key5") + + c.Do("ZMSCORE", "q", "key1") + c.Do("ZMSCORE", "q", "key1 key2 key3") + c.Do("ZMSCORE", "q", "nosuch") + c.Do("ZMSCORE", "nosuch", "key1") + c.Do("ZMSCORE", "nosuch", "key1", "key2") + + // failure cases + c.Error("wrong number", "ZMSCORE", "q") + c.Do("SET", "str", "I am a string") + c.Error("wrong kind", "ZMSCORE", "str", "key1") + }) +}