Skip to content

Commit

Permalink
refactor(share/availability): simplify light availability (#3895)
Browse files Browse the repository at this point in the history
  • Loading branch information
walldiss authored Oct 31, 2024
1 parent 7a70dd5 commit 816f46e
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 141 deletions.
8 changes: 0 additions & 8 deletions share/availability/full/availability.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,6 @@ func (fa *ShareAvailability) SharesAvailable(ctx context.Context, header *header
return nil
}

// we assume the caller of this method has already performed basic validation on the
// given roots. If for some reason this has not happened, the node should panic.
if err := dah.ValidateBasic(); err != nil {
log.Errorw("Availability validation cannot be performed on a malformed DataAvailabilityHeader",
"err", err)
panic(err)
}

// a hack to avoid loading the whole EDS in mem if we store it already.
if ok, _ := fa.store.HasByHeight(ctx, header.Height()); ok {
return nil
Expand Down
38 changes: 16 additions & 22 deletions share/availability/light/availability.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package light

import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
Expand All @@ -17,9 +18,9 @@ import (
)

var (
log = logging.Logger("share/light")
cacheAvailabilityPrefix = datastore.NewKey("sampling_result")
writeBatchSize = 2048
log = logging.Logger("share/light")
samplingResultsPrefix = datastore.NewKey("sampling_result")
writeBatchSize = 2048
)

// ShareAvailability implements share.Availability using Data Availability Sampling technique.
Expand All @@ -30,9 +31,6 @@ type ShareAvailability struct {
getter shwap.Getter
params Parameters

// TODO(@Wondertan): Once we come to parallelized DASer, this lock becomes a contention point
// Related to #483
// TODO: Striped locks? :D
dsLk sync.RWMutex
ds *autobatch.Datastore
}
Expand All @@ -44,7 +42,7 @@ func NewShareAvailability(
opts ...Option,
) *ShareAvailability {
params := *DefaultParameters()
ds = namespace.Wrap(ds, cacheAvailabilityPrefix)
ds = namespace.Wrap(ds, samplingResultsPrefix)
autoDS := autobatch.NewAutoBatching(ds, writeBatchSize)

for _, opt := range opts {
Expand All @@ -68,7 +66,7 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
}

// load snapshot of the last sampling errors from disk
key := rootKey(dah)
key := datastoreKeyForRoot(dah)
la.dsLk.RLock()
last, err := la.ds.Get(ctx, key)
la.dsLk.RUnlock()
Expand All @@ -84,37 +82,30 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
return err
case errors.Is(err, datastore.ErrNotFound):
// No sampling result found, select new samples
samples, err = SampleSquare(len(dah.RowRoots), int(la.params.SampleAmount))
if err != nil {
return err
}
samples = selectRandomSamples(len(dah.RowRoots), int(la.params.SampleAmount))
default:
// Sampling result found, unmarshal it
samples, err = decodeSamples(last)
err = json.Unmarshal(last, &samples)
if err != nil {
return err
}
}

if err := dah.ValidateBasic(); err != nil {
return err
}

var (
failedSamplesLock sync.Mutex
failedSamples []Sample
)

log.Debugw("starting sampling session", "root", dah.String())
log.Debugw("starting sampling session", "height", header.Height())
var wg sync.WaitGroup
for _, s := range samples {
wg.Add(1)
go func(s Sample) {
defer wg.Done()
// check if the sample is available
_, err := la.getter.GetShare(ctx, header, int(s.Row), int(s.Col))
_, err := la.getter.GetShare(ctx, header, s.Row, s.Col)
if err != nil {
log.Debugw("error fetching share", "root", dah.String(), "row", s.Row, "col", s.Col)
log.Debugw("error fetching share", "height", header.Height(), "row", s.Row, "col", s.Col)
failedSamplesLock.Lock()
failedSamples = append(failedSamples, s)
failedSamplesLock.Unlock()
Expand All @@ -124,7 +115,10 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
wg.Wait()

// store the result of the sampling session
bs := encodeSamples(failedSamples)
bs, err := json.Marshal(failedSamples)
if err != nil {
return fmt.Errorf("failed to marshal sampling result: %w", err)
}
la.dsLk.Lock()
err = la.ds.Put(ctx, key, bs)
la.dsLk.Unlock()
Expand All @@ -145,7 +139,7 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
return nil
}

func rootKey(root *share.AxisRoots) datastore.Key {
func datastoreKeyForRoot(root *share.AxisRoots) datastore.Key {
return datastore.NewKey(root.String())
}

Expand Down
56 changes: 30 additions & 26 deletions share/availability/light/availability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package light
import (
"context"
_ "embed"
"encoding/json"
"sync"
"testing"

Expand All @@ -22,7 +23,7 @@ import (
"github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex"
)

func TestSharesAvailableCaches(t *testing.T) {
func TestSharesAvailableSuccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -48,27 +49,29 @@ func TestSharesAvailableCaches(t *testing.T) {
ds := datastore.NewMapDatastore()
avail := NewShareAvailability(getter, ds)

// cache doesn't have eds yet
has, err := avail.ds.Has(ctx, rootKey(roots))
// Ensure the datastore doesn't have the sampling result yet
has, err := avail.ds.Has(ctx, datastoreKeyForRoot(roots))
require.NoError(t, err)
require.False(t, has)

err = avail.SharesAvailable(ctx, eh)
require.NoError(t, err)

// is now stored success result
result, err := avail.ds.Get(ctx, rootKey(roots))
// Verify that the sampling result is stored with all samples marked as available
result, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots))
require.NoError(t, err)
failed, err := decodeSamples(result)

var failed []Sample
err = json.Unmarshal(result, &failed)
require.NoError(t, err)
require.Empty(t, failed)
}

func TestSharesAvailableHitsCache(t *testing.T) {
func TestSharesAvailableSkipSampled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// create getter that always return ErrNotFound
// Create a getter that always returns ErrNotFound
getter := mock.NewMockGetter(gomock.NewController(t))
getter.EXPECT().
GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Expand All @@ -86,16 +89,19 @@ func TestSharesAvailableHitsCache(t *testing.T) {
err := avail.SharesAvailable(ctx, eh)
require.ErrorIs(t, err, share.ErrNotAvailable)

// put success result in cache
err = avail.ds.Put(ctx, rootKey(roots), []byte{})
// Store a successful sampling result in the datastore
failed := []Sample{}
data, err := json.Marshal(failed)
require.NoError(t, err)
err = avail.ds.Put(ctx, datastoreKeyForRoot(roots), data)
require.NoError(t, err)

// should hit cache after putting
// SharesAvailable should now return no error since the success sampling result is stored
err = avail.SharesAvailable(ctx, eh)
require.NoError(t, err)
}

func TestSharesAvailableEmptyRoot(t *testing.T) {
func TestSharesAvailableEmptyEDS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -117,42 +123,40 @@ func TestSharesAvailableFailed(t *testing.T) {
ds := datastore.NewMapDatastore()
avail := NewShareAvailability(getter, ds)

// create new eds, that is not available by getter
// Create new eds, that is not available by getter
eds := edstest.RandEDS(t, 16)
roots, err := share.NewAxisRoots(eds)
require.NoError(t, err)
eh := headertest.RandExtendedHeaderWithRoot(t, roots)

// getter doesn't have the eds, so it should fail
// Getter doesn't have the eds, so it should fail for all samples
getter.EXPECT().
GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(libshare.Share{}, shrex.ErrNotFound).
AnyTimes()
err = avail.SharesAvailable(ctx, eh)
require.ErrorIs(t, err, share.ErrNotAvailable)

// cache should have failed results now
result, err := avail.ds.Get(ctx, rootKey(roots))
// The datastore should now contain the sampling result with all samples in Remaining
result, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots))
require.NoError(t, err)

failed, err := decodeSamples(result)
var failed []Sample
err = json.Unmarshal(result, &failed)
require.NoError(t, err)
require.Len(t, failed, int(avail.params.SampleAmount))

// ensure that retry persists the failed samples selection
// create new getter with only the failed samples available, and add them to the onceGetter
onceGetter := newOnceGetter()
onceGetter.AddSamples(failed)

// replace getter with the new one
avail.getter = onceGetter
// Simulate a getter that now returns shares successfully
successfulGetter := newOnceGetter()
successfulGetter.AddSamples(failed)
avail.getter = successfulGetter

// should be able to retrieve all the failed samples now
err = avail.SharesAvailable(ctx, eh)
require.NoError(t, err)

// onceGetter should have no more samples stored after the call
require.Empty(t, onceGetter.available)
require.Empty(t, successfulGetter.available)
}

type onceGetter struct {
Expand All @@ -178,7 +182,7 @@ func (m onceGetter) AddSamples(samples []Sample) {
func (m onceGetter) GetShare(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) {
m.Lock()
defer m.Unlock()
s := Sample{Row: uint16(row), Col: uint16(col)}
s := Sample{Row: row, Col: col}
if _, ok := m.available[s]; ok {
delete(m.available, s)
return libshare.Share{}, nil
Expand Down
100 changes: 19 additions & 81 deletions share/availability/light/sample.go
Original file line number Diff line number Diff line change
@@ -1,104 +1,42 @@
// TODO(@Wondertan): Instead of doing sampling over the coordinates do a random walk over NMT trees.
package light

import (
crand "crypto/rand"
"encoding/binary"
"errors"
"math/big"

"golang.org/x/exp/maps"
)

// Sample is a point in 2D space over square.
// Sample represents a coordinate in a 2D data square.
type Sample struct {
Row, Col uint16
Row int `json:"row"`
Col int `json:"col"`
}

// SampleSquare randomly picks *num* unique points from the given *width* square
// and returns them as samples.
func SampleSquare(squareWidth, num int) ([]Sample, error) {
ss := newSquareSampler(squareWidth, num)
err := ss.generateSample(num)
if err != nil {
return nil, err
// selectRandomSamples randomly picks unique coordinates from a square of given size.
func selectRandomSamples(squareSize, sampleCount int) []Sample {
total := squareSize * squareSize
if sampleCount > total {
sampleCount = total
}
return ss.samples(), nil
}

type squareSampler struct {
squareWidth int
smpls map[Sample]struct{}
}

func newSquareSampler(squareWidth, expectedSamples int) *squareSampler {
return &squareSampler{
squareWidth: squareWidth,
smpls: make(map[Sample]struct{}, expectedSamples),
}
}

// generateSample randomly picks unique point on a 2D spaces.
func (ss *squareSampler) generateSample(num int) error {
if num > ss.squareWidth*ss.squareWidth {
num = ss.squareWidth
}

done := 0
for done < num {
samples := make(map[Sample]struct{}, sampleCount)
for len(samples) < sampleCount {
s := Sample{
Row: randInt(ss.squareWidth),
Col: randInt(ss.squareWidth),
Row: randInt(squareSize),
Col: randInt(squareSize),
}

if _, ok := ss.smpls[s]; ok {
continue
}

done++
ss.smpls[s] = struct{}{}
samples[s] = struct{}{}
}

return nil
}

func (ss *squareSampler) samples() []Sample {
samples := make([]Sample, 0, len(ss.smpls))
for s := range ss.smpls {
samples = append(samples, s)
}
return samples
return maps.Keys(samples)
}

// randInt returns a cryptographically-secure uniform random integer in
// [0, max). It is used to pick sample coordinates so that an adversary
// cannot predict which shares a light node will request.
func randInt(max int) int {
	n, err := crand.Int(crand.Reader, big.NewInt(int64(max)))
	if err != nil {
		panic(err) // won't panic as rand.Reader is endless
	}

	// n.Uint64() is safe as max is int
	return int(n.Uint64())
}
Loading

0 comments on commit 816f46e

Please sign in to comment.