Skip to content

Commit

Permalink
refactor(dao): improve QueryDBWithCache func and concurrent unit te…
Browse files Browse the repository at this point in the history
…sts (#884)
  • Loading branch information
qwqcode authored May 26, 2024
1 parent abbddfe commit d2161cc
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 131 deletions.
54 changes: 27 additions & 27 deletions internal/cache/action.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cache

import (
"fmt"
"reflect"

"github.com/ArtalkJS/Artalk/internal/log"
Expand All @@ -9,50 +10,49 @@ import (
)

var (
cacheFindStoreGroup = new(singleflight.Group)
cacheSingleflightGroup = new(singleflight.Group)
)

func (c *Cache) QueryDBWithCache(name string, dest any, queryDB func()) error {
if reflect.TypeOf(dest).Kind() != reflect.Ptr {
panic("The 'dest' param in 'QueryDBWithCache' func is expected to pointer type to update its data.")
}
func QueryDBWithCache[T any](c *Cache, name string, queryDB func() (T, error)) (T, error) {
// Use SingleFlight to prevent Cache Breakdown
v, err, _ := cacheSingleflightGroup.Do(name, func() (any, error) {
var val T

// use SingleFlight to prevent Cache Breakdown
v, err, _ := cacheFindStoreGroup.Do(name, func() (any, error) {
// query cache
err := c.FindCache(name, dest)
// Query from cache
err := c.FindCache(name, &val)

if err != nil { // cache miss
if err != nil {
// Miss cache

// call queryDB() the dest value will be updated
queryDB()
// Query from db
val, err := queryDB()
if err != nil {
return nil, err
}

if err := c.StoreCache(dest, name); err != nil {
// Store cache
if err := c.StoreCache(val, name); err != nil {
return nil, err
}

// because queryDB() had update dest value,
// no need to update it again, so return nil
return nil, nil
return val, nil
} else {
// Hit cache
return val, nil
}

return dest, nil
})

if err != nil {
return err
return *new(T), err
}

// update dest value only if cache hit,
// if cache miss, `dest` has been updated in `queryDB()` so no need to set again
if v != nil {
reflect.ValueOf(dest).Elem().Set(reflect.ValueOf(v).Elem()) // similar to `*dest = &v`
}

return nil
return v.(T), err
}

func (c *Cache) FindCache(name string, dest any) error {
if reflect.ValueOf(dest).Kind() != reflect.Ptr {
return fmt.Errorf("[FindCache] dest must be a pointer")
}

// `Get()` is Thread Safe, so no need to add Mutex
// @see https://github.com/go-redis/redis/issues/23
_, err := c.marshal.Get(c.ctx, name, dest)
Expand Down
108 changes: 68 additions & 40 deletions internal/cache/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package cache_test
import (
"fmt"
"sync"
"sync/atomic"
"testing"

"github.com/ArtalkJS/Artalk/internal/cache"
"github.com/stretchr/testify/assert"
)

func TestAction(t *testing.T) {
cache := newCache(t)
cache := newTestCache(t)
defer cache.Close()

doCrudTest := func(t *testing.T, testKey string, testData any) {
Expand Down Expand Up @@ -85,63 +87,89 @@ func TestAction(t *testing.T) {
}

func TestQueryDBWithCache(t *testing.T) {
cache := newCache(t)
defer cache.Close()

type user struct {
Name string
Email string
}
t.Run("Simple", func(t *testing.T) {
cacheInstance := newTestCache(t)
defer cacheInstance.Close()

const key = "data_key_233"
var value = user{
Name: "qwqcode",
Email: "[email protected]",
}
type user struct {
Name string
Email string
}

doCachedFind := func() bool {
var data user
dbQueried := false
const key = "data_key_233"
var value = user{
Name: "qwqcode",
Email: "[email protected]",
}

err := cache.QueryDBWithCache(key, &data, func() {
// simulate db query result
data = user{
Name: value.Name,
Email: value.Email,
doCachedFind := func() bool {
dbQueried := false

data, err := cache.QueryDBWithCache(cacheInstance, key, func() (data user, err error) {
// simulate db query result
data = user{
Name: value.Name,
Email: value.Email,
}
dbQueried = true
return data, nil
})

if assert.NoError(t, err) {
assert.Equal(t, value, data)
}
dbQueried = true
})

if assert.NoError(t, err) {
assert.Equal(t, value, data)
return dbQueried
}

return dbQueried
}

if dbQueried := doCachedFind(); dbQueried {
assert.True(t, dbQueried, "first call `QueryDBWithCache`, db should be queried")
}
if dbQueried := doCachedFind(); dbQueried {
assert.True(t, dbQueried, "first call `QueryDBWithCache`, db should be queried")
}

if dbQueried := doCachedFind(); dbQueried {
assert.False(t, dbQueried, "second call `QueryDBWithCache`, db should not be queried")
}
if dbQueried := doCachedFind(); dbQueried {
assert.False(t, dbQueried, "second call `QueryDBWithCache`, db should not be queried")
}
})

t.Run("Concurrency", func(t *testing.T) {
const numRoutines = 100
cacheInstance := newTestCache(t)
defer cacheInstance.Close()

var wg sync.WaitGroup
wg.Add(numRoutines)
type mockStruct struct {
Value string
}

for i := 0; i < numRoutines; i++ {
wg := sync.WaitGroup{}
ready := make(chan struct{})
findCallTimes := int32(0)

const n = 1000
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-ready // make sure all goroutines start at the same time

dbQueried := doCachedFind()
assert.False(t, dbQueried)
data, err := cache.QueryDBWithCache(cacheInstance, "key", func() (data mockStruct, err error) {
atomic.AddInt32(&findCallTimes, 1)

return mockStruct{
Value: "concurrency_value",
}, nil
})

if assert.NoError(t, err) {
assert.Equal(t, "concurrency_value", data.Value, "data should be equal to the value returned by the query function")
}
}()
}

close(ready) // start all goroutines at the same time

wg.Wait()

if got := atomic.LoadInt32(&findCallTimes); got != 1 {
t.Errorf("expected findCallTimes to be 1, got %d", got)
}
})
}
4 changes: 2 additions & 2 deletions internal/cache/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
)

func newCache(t *testing.T) *cache.Cache {
func newTestCache(t *testing.T) *cache.Cache {
cache, err := cache.New(config.CacheConf{
Enabled: true,
Type: config.CacheTypeBuiltin,
Expand All @@ -20,7 +20,7 @@ func newCache(t *testing.T) *cache.Cache {
}

func TestNew(t *testing.T) {
cache := newCache(t)
cache := newTestCache(t)
defer cache.Close()

assert.NotNil(t, cache)
Expand Down
71 changes: 30 additions & 41 deletions internal/dao/query_find.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,26 @@ import (
"fmt"
"strings"

"github.com/ArtalkJS/Artalk/internal/cache"
"github.com/ArtalkJS/Artalk/internal/entity"
)

func (dao *Dao) QueryDBWithCache(name string, dest any, queryDB func()) error {
func QueryDBWithCache[T any](dao *Dao, name string, queryDB func() (T, error)) (T, error) {
if dao.cache == nil {
// directly call queryDB while cache is disabled
queryDB()
return nil

// gorm db query is thread-safe almost
// see "Method Chain Safety/Goroutine Safety" in https://gorm.io/docs/v2_release_note.html#Method-Chain-Safety-x2F-Goroutine-Safety
return queryDB()
}

return dao.cache.QueryDBWithCache(name, dest, queryDB)
return cache.QueryDBWithCache(dao.cache.Cache, name, queryDB)
}

func (dao *Dao) FindComment(id uint, checkers ...func(*entity.Comment) bool) entity.Comment {
var comment entity.Comment

dao.QueryDBWithCache(fmt.Sprintf(CommentByIDKey, id), &comment, func() {
comment, _ := QueryDBWithCache(dao, fmt.Sprintf(CommentByIDKey, id), func() (comment entity.Comment, err error) {
dao.DB().Where("id = ?", id).First(&comment)
return comment, nil
})

// the case with checkers
Expand Down Expand Up @@ -54,10 +56,11 @@ func (dao *Dao) FindCommentRootID(rid uint) uint {
// (Cached:parent-comments)
func (dao *Dao) FindCommentChildrenShallow(parentID uint, checkers ...func(*entity.Comment) bool) []entity.Comment {
var children []entity.Comment
var childIDs []uint

dao.QueryDBWithCache(fmt.Sprintf(CommentChildIDsByIDKey, parentID), &childIDs, func() {
childIDs, _ := QueryDBWithCache(dao, fmt.Sprintf(CommentChildIDsByIDKey, parentID), func() ([]uint, error) {
childIDs := []uint{}
dao.DB().Model(&entity.Comment{}).Where(&entity.Comment{Rid: parentID}).Select("id").Find(&childIDs)
return childIDs, nil
})

for _, childID := range childIDs {
Expand Down Expand Up @@ -88,27 +91,23 @@ func (dao *Dao) _findCommentChildrenOnce(source *[]entity.Comment, parentID uint

// 查找用户 (精确查找 name & email)
func (dao *Dao) FindUser(name string, email string) entity.User {
var user entity.User

// 查询缓存
dao.QueryDBWithCache(fmt.Sprintf(UserByNameEmailKey, strings.ToLower(name), strings.ToLower(email)), &user, func() {
// 不区分大小写
dao.DB().Where("LOWER(name) = LOWER(?) AND LOWER(email) = LOWER(?)", name, email).First(&user)
user, _ := QueryDBWithCache(dao, fmt.Sprintf(UserByNameEmailKey, strings.ToLower(name), strings.ToLower(email)), func() (user entity.User, err error) {
dao.DB().Where("LOWER(name) = LOWER(?) AND LOWER(email) = LOWER(?)", name, email).First(&user) // 不区分大小写
return user, nil
})

return user
}

// 查找用户 ID (仅根据 email)
func (dao *Dao) FindUserIdsByEmail(email string) []uint {
var userIds = []uint{}

// 查询缓存
dao.QueryDBWithCache(fmt.Sprintf(UserIDByEmailKey, strings.ToLower(email)), &userIds, func() {
dao.DB().Model(&entity.User{}).Where("LOWER(email) = LOWER(?)", email).Pluck("id", &userIds)
userIDs, _ := QueryDBWithCache(dao, fmt.Sprintf(UserIDByEmailKey, strings.ToLower(email)), func() ([]uint, error) {
userIDs := []uint{}
dao.DB().Model(&entity.User{}).Where("LOWER(email) = LOWER(?)", email).Pluck("id", &userIDs)
return userIDs, nil
})

return userIds
return userIDs
}

// 查找用户 (仅根据 email)
Expand All @@ -125,54 +124,44 @@ func (dao *Dao) FindUsersByEmail(email string) []entity.User {

// 查找用户 (通过 ID)
func (dao *Dao) FindUserByID(id uint) entity.User {
var user entity.User

// 查询缓存
dao.QueryDBWithCache(fmt.Sprintf("user#id=%d", id), &user, func() {
user, _ := QueryDBWithCache(dao, fmt.Sprintf("user#id=%d", id), func() (user entity.User, err error) {
dao.DB().Where("id = ?", id).First(&user)
return user, nil
})

return user
}

func (dao *Dao) FindPage(key string, siteName string) entity.Page {
var page entity.Page

dao.QueryDBWithCache(fmt.Sprintf(PageByKeySiteNameKey, key, siteName), &page, func() {
page, _ := QueryDBWithCache(dao, fmt.Sprintf(PageByKeySiteNameKey, key, siteName), func() (page entity.Page, err error) {
dao.DB().Where(&entity.Page{Key: key, SiteName: siteName}).First(&page)
return page, nil
})

return page
}

func (dao *Dao) FindPageByID(id uint) entity.Page {
var page entity.Page

dao.QueryDBWithCache(fmt.Sprintf(PageByIDKey, id), &page, func() {
page, _ := QueryDBWithCache(dao, fmt.Sprintf(PageByIDKey, id), func() (page entity.Page, err error) {
dao.DB().Where("id = ?", id).First(&page)
return page, nil
})

return page
}

func (dao *Dao) FindSite(name string) entity.Site {
var site entity.Site

// 查询缓存
dao.QueryDBWithCache(fmt.Sprintf(SiteByNameKey, name), &site, func() {
site, _ := QueryDBWithCache(dao, fmt.Sprintf(SiteByNameKey, name), func() (site entity.Site, err error) {
dao.DB().Where("name = ?", name).First(&site)
return site, nil
})

return site
}

func (dao *Dao) FindSiteByID(id uint) entity.Site {
var site entity.Site

dao.QueryDBWithCache(fmt.Sprintf(SiteByIDKey, id), &site, func() {
site, _ := QueryDBWithCache(dao, fmt.Sprintf(SiteByIDKey, id), func() (site entity.Site, err error) {
dao.DB().Where("id = ?", id).First(&site)
return site, nil
})

return site
}

Expand Down
Loading

0 comments on commit d2161cc

Please sign in to comment.