Skip to content

Commit

Permalink
Fix concurrent map read and map write in short page lookups
Browse files Browse the repository at this point in the history
Regression introduced in Hugo `v0.137.0`.

Fixes gohugoio#13019
  • Loading branch information
bep committed Nov 6, 2024
1 parent 2c3efc8 commit e90ce1c
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 49 deletions.
54 changes: 50 additions & 4 deletions common/maps/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@

package maps

import "sync"
import (
"sync"
)

// Cache is a simple thread safe cache backed by a map.
type Cache[K comparable, T any] struct {
m map[K]T
m map[K]T
hasBeenInitialized bool
sync.RWMutex
}

Expand All @@ -34,11 +37,16 @@ func (c *Cache[K, T]) Get(key K) (T, bool) {
return zero, false
}
c.RLock()
v, found := c.m[key]
v, found := c.get(key)
c.RUnlock()
return v, found
}

func (c *Cache[K, T]) get(key K) (T, bool) {
v, found := c.m[key]
return v, found
}

// GetOrCreate gets the value for the given key if it exists, or creates it if not.
func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) {
c.RLock()
Expand All @@ -61,13 +69,49 @@ func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) {
return v, nil
}

// InitAndGet initializes the cache if not already done and returns the value for the given key.
// The init state will be reset on Reset or Drain.
func (c *Cache[K, T]) InitAndGet(key K, init func(get func(key K) (T, bool), set func(key K, value T)) error) (T, error) {
var v T
c.RLock()
if !c.hasBeenInitialized {
c.RUnlock()
if err := func() error {
c.Lock()
defer c.Unlock()
// Double check in case another goroutine has initialized it in the meantime.
if !c.hasBeenInitialized {
err := init(c.get, c.set)
if err != nil {
return err
}
c.hasBeenInitialized = true
}
return nil
}(); err != nil {
return v, err
}
// Reacquire the read lock.
c.RLock()
}

v = c.m[key]
c.RUnlock()

return v, nil
}

// Set sets the given key to the given value.
func (c *Cache[K, T]) Set(key K, value T) {
c.Lock()
c.m[key] = value
c.set(key, value)
c.Unlock()
}

func (c *Cache[K, T]) set(key K, value T) {
c.m[key] = value
}

// ForEeach calls the given function for each key/value pair in the cache.
func (c *Cache[K, T]) ForEeach(f func(K, T)) {
c.RLock()
Expand All @@ -81,6 +125,7 @@ func (c *Cache[K, T]) Drain() map[K]T {
c.Lock()
m := c.m
c.m = make(map[K]T)
c.hasBeenInitialized = false
c.Unlock()
return m
}
Expand All @@ -94,6 +139,7 @@ func (c *Cache[K, T]) Len() int {
func (c *Cache[K, T]) Reset() {
c.Lock()
c.m = make(map[K]T)
c.hasBeenInitialized = false
c.Unlock()
}

Expand Down
76 changes: 37 additions & 39 deletions hugolib/content_map_page.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"github.com/gohugoio/hugo/hugolib/doctree"
"github.com/gohugoio/hugo/hugolib/pagesfromdata"
"github.com/gohugoio/hugo/identity"
"github.com/gohugoio/hugo/lazy"
"github.com/gohugoio/hugo/media"
"github.com/gohugoio/hugo/output"
"github.com/gohugoio/hugo/resources"
Expand Down Expand Up @@ -925,59 +924,58 @@ func newPageMap(i int, s *Site, mcache *dynacache.Cache, pageTrees *pageTrees) *
s: s,
}

m.pageReverseIndex = &contentTreeReverseIndex{
initFn: func(rm map[any]contentNodeI) {
add := func(k string, n contentNodeI) {
existing, found := rm[k]
if found && existing != ambiguousContentNode {
rm[k] = ambiguousContentNode
} else if !found {
rm[k] = n
}
m.pageReverseIndex = newContentTreeTreverseIndex(func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) {
add := func(k string, n contentNodeI) {
existing, found := get(k)
if found && existing != ambiguousContentNode {
set(k, ambiguousContentNode)
} else if !found {
set(k, n)
}
}

w := &doctree.NodeShiftTreeWalker[contentNodeI]{
Tree: m.treePages,
LockType: doctree.LockTypeRead,
Handle: func(s string, n contentNodeI, match doctree.DimensionFlag) (bool, error) {
p := n.(*pageState)
if p.PathInfo() != nil {
add(p.PathInfo().BaseNameNoIdentifier(), p)
}
return false, nil
},
}
w := &doctree.NodeShiftTreeWalker[contentNodeI]{
Tree: m.treePages,
LockType: doctree.LockTypeRead,
Handle: func(s string, n contentNodeI, match doctree.DimensionFlag) (bool, error) {
p := n.(*pageState)
if p.PathInfo() != nil {
add(p.PathInfo().BaseNameNoIdentifier(), p)
}
return false, nil
},
}

if err := w.Walk(context.Background()); err != nil {
panic(err)
}
},
contentTreeReverseIndexMap: &contentTreeReverseIndexMap{},
}
if err := w.Walk(context.Background()); err != nil {
panic(err)
}
})

return m
}

func newContentTreeTreverseIndex(init func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI))) *contentTreeReverseIndex {
return &contentTreeReverseIndex{
initFn: init,
mm: maps.NewCache[any, contentNodeI](),
}
}

type contentTreeReverseIndex struct {
initFn func(rm map[any]contentNodeI)
*contentTreeReverseIndexMap
initFn func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI))
mm *maps.Cache[any, contentNodeI]
}

func (c *contentTreeReverseIndex) Reset() {
c.init.ResetWithLock().Unlock()
c.mm.Reset()
}

func (c *contentTreeReverseIndex) Get(key any) contentNodeI {
c.init.Do(func() {
c.m = make(map[any]contentNodeI)
c.initFn(c.contentTreeReverseIndexMap.m)
v, _ := c.mm.InitAndGet(key, func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) error {
c.initFn(get, set)
return nil
})
return c.m[key]
}

type contentTreeReverseIndexMap struct {
init lazy.OnceMore
m map[any]contentNodeI
return v
}

type sitePagesAssembler struct {
Expand Down
76 changes: 76 additions & 0 deletions hugolib/content_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ import (
"fmt"
"path/filepath"
"strings"
"sync"
"testing"

qt "github.com/frankban/quicktest"
"github.com/gohugoio/hugo/identity"
)

func TestContentMapSite(t *testing.T) {
Expand Down Expand Up @@ -396,3 +398,77 @@ irrelevant
"<loc>https://example.org/en/sitemap.xml</loc>",
)
}

func TestContentTreeReverseIndex(t *testing.T) {
t.Parallel()

c := qt.New(t)

pageReverseIndex := newContentTreeTreverseIndex(
func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) {
for i := 0; i < 10; i++ {
key := fmt.Sprint(i)
set(key, &testContentNode{key: key})
}
},
)

for i := 0; i < 10; i++ {
key := fmt.Sprint(i)
v := pageReverseIndex.Get(key)
c.Assert(v, qt.Not(qt.IsNil))
c.Assert(v.Path(), qt.Equals, key)
}
}

// Issue 13019.
func TestContentTreeReverseIndexPara(t *testing.T) {
t.Parallel()

var wg sync.WaitGroup

for i := 0; i < 10; i++ {
pageReverseIndex := newContentTreeTreverseIndex(
func(get func(key any) (contentNodeI, bool), set func(key any, val contentNodeI)) {
for i := 0; i < 10; i++ {
key := fmt.Sprint(i)
set(key, &testContentNode{key: key})
}
},
)

for j := 0; j < 10; j++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
pageReverseIndex.Get(fmt.Sprint(i))
}(j)
}
}
}

type testContentNode struct {
key string
}

func (n *testContentNode) GetIdentity() identity.Identity {
return identity.StringIdentity(n.key)
}

func (n *testContentNode) ForEeachIdentity(cb func(id identity.Identity) bool) bool {
panic("not supported")
}

func (n *testContentNode) Path() string {
return n.key
}

func (n *testContentNode) isContentNodeBranch() bool {
return false
}

func (n *testContentNode) resetBuildState() {
}

func (n *testContentNode) MarkStale() {
}
2 changes: 1 addition & 1 deletion lazy/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Init struct {
prev *Init
children []*Init

init OnceMore
init onceMore
out any
err error
f func(context.Context) (any, error)
Expand Down
10 changes: 5 additions & 5 deletions lazy/once.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ import (
// * it can be reset, so the action can be repeated if needed
// * it has methods to check if it's done or in progress

type OnceMore struct {
type onceMore struct {
mu sync.Mutex
lock uint32
done uint32
}

func (t *OnceMore) Do(f func()) {
func (t *onceMore) Do(f func()) {
if atomic.LoadUint32(&t.done) == 1 {
return
}
Expand All @@ -53,15 +53,15 @@ func (t *OnceMore) Do(f func()) {
f()
}

func (t *OnceMore) InProgress() bool {
func (t *onceMore) InProgress() bool {
return atomic.LoadUint32(&t.lock) == 1
}

func (t *OnceMore) Done() bool {
func (t *onceMore) Done() bool {
return atomic.LoadUint32(&t.done) == 1
}

func (t *OnceMore) ResetWithLock() *sync.Mutex {
func (t *onceMore) ResetWithLock() *sync.Mutex {
t.mu.Lock()
defer atomic.StoreUint32(&t.done, 0)
return &t.mu
Expand Down

0 comments on commit e90ce1c

Please sign in to comment.