From c81592bccf4278933f83ff2c5156d2653981b175 Mon Sep 17 00:00:00 2001
From: Vlad <13818348+walldiss@users.noreply.github.com>
Date: Tue, 29 Oct 2024 22:29:31 +0100
Subject: [PATCH] simplify share availability

---
 share/availability/full/availability.go       |   8 --
 share/availability/light/availability.go      |  41 +++----
 share/availability/light/availability_test.go |  56 +++++-----
 share/availability/light/sample.go            | 100 ++++--------------
 share/availability/light/sample_test.go       |   7 +-
 5 files changed, 69 insertions(+), 143 deletions(-)

diff --git a/share/availability/full/availability.go b/share/availability/full/availability.go
index 91550849c8..4347d35b36 100644
--- a/share/availability/full/availability.go
+++ b/share/availability/full/availability.go
@@ -50,14 +50,6 @@ func (fa *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 		return nil
 	}
 
-	// we assume the caller of this method has already performed basic validation on the
-	// given roots. If for some reason this has not happened, the node should panic.
-	if err := dah.ValidateBasic(); err != nil {
-		log.Errorw("Availability validation cannot be performed on a malformed DataAvailabilityHeader",
-			"err", err)
-		panic(err)
-	}
-
 	// a hack to avoid loading the whole EDS in mem if we store it already.
 	if ok, _ := fa.store.HasByHeight(ctx, header.Height()); ok {
 		return nil
diff --git a/share/availability/light/availability.go b/share/availability/light/availability.go
index def56d346b..94d1df3d7b 100644
--- a/share/availability/light/availability.go
+++ b/share/availability/light/availability.go
@@ -2,6 +2,7 @@ package light
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"sync"
@@ -17,9 +18,9 @@ import (
 )
 
 var (
-	log                     = logging.Logger("share/light")
-	cacheAvailabilityPrefix = datastore.NewKey("sampling_result")
-	writeBatchSize          = 2048
+	log                   = logging.Logger("share/light")
+	samplingResultsPrefix = datastore.NewKey("sampling_result")
+	writeBatchSize        = 2048
 )
 
 // ShareAvailability implements share.Availability using Data Availability Sampling technique.
@@ -30,9 +31,6 @@ type ShareAvailability struct {
 	getter shwap.Getter
 	params Parameters
 
-	// TODO(@Wondertan): Once we come to parallelized DASer, this lock becomes a contention point
-	// Related to #483
-	// TODO: Striped locks? :D
 	dsLk sync.RWMutex
 	ds   *autobatch.Datastore
 }
@@ -44,7 +42,7 @@ func NewShareAvailability(
 	opts ...Option,
 ) *ShareAvailability {
 	params := *DefaultParameters()
-	ds = namespace.Wrap(ds, cacheAvailabilityPrefix)
+	ds = namespace.Wrap(ds, samplingResultsPrefix)
 	autoDS := autobatch.NewAutoBatching(ds, writeBatchSize)
 
 	for _, opt := range opts {
@@ -68,7 +66,7 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 	}
 
 	// load snapshot of the last sampling errors from disk
-	key := rootKey(dah)
+	key := datastoreKeyForRoot(dah)
 	la.dsLk.RLock()
 	last, err := la.ds.Get(ctx, key)
 	la.dsLk.RUnlock()
@@ -84,38 +82,30 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 		return err
 	case errors.Is(err, datastore.ErrNotFound):
 		// No sampling result found, select new samples
-		samples, err = SampleSquare(len(dah.RowRoots), int(la.params.SampleAmount))
-		if err != nil {
-			return err
-		}
+		samples = selectRandomSamples(len(dah.RowRoots), int(la.params.SampleAmount))
 	default:
 		// Sampling result found, unmarshal it
-		samples, err = decodeSamples(last)
+		err = json.Unmarshal(last, &samples)
 		if err != nil {
 			return err
 		}
 	}
 
-	if err := dah.ValidateBasic(); err != nil {
-		log.Errorw("DAH validation failed", "error", err)
-		return err
-	}
-
 	var (
 		failedSamplesLock sync.Mutex
 		failedSamples     []Sample
 	)
 
-	log.Debugw("starting sampling session", "root", dah.String())
+	log.Debugw("starting sampling session", "height", header.Height())
 	var wg sync.WaitGroup
 	for _, s := range samples {
 		wg.Add(1)
 		go func(s Sample) {
 			defer wg.Done()
 			// check if the sample is available
-			_, err := la.getter.GetShare(ctx, header, int(s.Row), int(s.Col))
+			_, err := la.getter.GetShare(ctx, header, s.Row, s.Col)
 			if err != nil {
-				log.Debugw("error fetching share", "root", dah.String(), "row", s.Row, "col", s.Col)
+				log.Debugw("error fetching share", "height", header.Height(), "row", s.Row, "col", s.Col)
 				failedSamplesLock.Lock()
 				failedSamples = append(failedSamples, s)
 				failedSamplesLock.Unlock()
@@ -125,7 +115,10 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 	wg.Wait()
 
 	// store the result of the sampling session
-	bs := encodeSamples(failedSamples)
+	bs, err := json.Marshal(failedSamples)
+	if err != nil {
+		return fmt.Errorf("failed to marshal sampling result: %w", err)
+	}
 	la.dsLk.Lock()
 	err = la.ds.Put(ctx, key, bs)
 	la.dsLk.Unlock()
@@ -142,7 +135,7 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 	// if any of the samples failed, return an error
 	if len(failedSamples) > 0 {
 		log.Errorw("availability validation failed",
-			"root", dah.String(),
+			"height", header.Height(),
 			"failed_samples", failedSamples,
 		)
 		return share.ErrNotAvailable
@@ -150,7 +143,7 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 	return nil
 }
 
-func rootKey(root *share.AxisRoots) datastore.Key {
+func datastoreKeyForRoot(root *share.AxisRoots) datastore.Key {
 	return datastore.NewKey(root.String())
 }
 
diff --git a/share/availability/light/availability_test.go b/share/availability/light/availability_test.go
index 3da6cea83d..4c19eca169 100644
--- a/share/availability/light/availability_test.go
+++ b/share/availability/light/availability_test.go
@@ -3,6 +3,7 @@ package light
 import (
 	"context"
 	_ "embed"
+	"encoding/json"
 	"sync"
 	"testing"
 
@@ -22,7 +23,7 @@ import (
 	"github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex"
 )
 
-func TestSharesAvailableCaches(t *testing.T) {
+func TestSharesAvailableSuccess(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
@@ -48,27 +49,29 @@ func TestSharesAvailableSuccess(t *testing.T) {
 	ds := datastore.NewMapDatastore()
 	avail := NewShareAvailability(getter, ds)
 
-	// cache doesn't have eds yet
-	has, err := avail.ds.Has(ctx, rootKey(roots))
+	// Ensure the datastore doesn't have the sampling result yet
+	has, err := avail.ds.Has(ctx, datastoreKeyForRoot(roots))
 	require.NoError(t, err)
 	require.False(t, has)
 
 	err = avail.SharesAvailable(ctx, eh)
 	require.NoError(t, err)
 
-	// is now stored success result
-	result, err := avail.ds.Get(ctx, rootKey(roots))
+	// Verify that the sampling result is stored with all samples marked as available
+	result, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots))
 	require.NoError(t, err)
-	failed, err := decodeSamples(result)
+
+	var failed []Sample
+	err = json.Unmarshal(result, &failed)
 	require.NoError(t, err)
 	require.Empty(t, failed)
 }
 
-func TestSharesAvailableHitsCache(t *testing.T) {
+func TestSharesAvailableSkipSampled(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	// create getter that always return ErrNotFound
+	// Create a getter that always returns ErrNotFound
 	getter := mock.NewMockGetter(gomock.NewController(t))
 	getter.EXPECT().
 		GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
@@ -86,16 +89,19 @@ func TestSharesAvailableHitsCache(t *testing.T) {
 	err := avail.SharesAvailable(ctx, eh)
 	require.ErrorIs(t, err, share.ErrNotAvailable)
 
-	// put success result in cache
-	err = avail.ds.Put(ctx, rootKey(roots), []byte{})
+	// Store a successful sampling result in the datastore
+	failed := []Sample{}
+	data, err := json.Marshal(failed)
+	require.NoError(t, err)
+	err = avail.ds.Put(ctx, datastoreKeyForRoot(roots), data)
 	require.NoError(t, err)
 
-	// should hit cache after putting
+	// SharesAvailable should now return no error since the success sampling result is stored
 	err = avail.SharesAvailable(ctx, eh)
 	require.NoError(t, err)
 }
 
-func TestSharesAvailableEmptyRoot(t *testing.T) {
+func TestSharesAvailableEmptyEDS(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
@@ -117,13 +123,13 @@ func TestSharesAvailableFailed(t *testing.T) {
 	ds := datastore.NewMapDatastore()
 	avail := NewShareAvailability(getter, ds)
 
-	// create new eds, that is not available by getter
+	// Create new eds, that is not available by getter
 	eds := edstest.RandEDS(t, 16)
 	roots, err := share.NewAxisRoots(eds)
 	require.NoError(t, err)
 	eh := headertest.RandExtendedHeaderWithRoot(t, roots)
 
-	// getter doesn't have the eds, so it should fail
+	// Getter doesn't have the eds, so it should fail for all samples
 	getter.EXPECT().
 		GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
 		Return(libshare.Share{}, shrex.ErrNotFound).
@@ -131,28 +137,26 @@ func TestSharesAvailableFailed(t *testing.T) {
 	err = avail.SharesAvailable(ctx, eh)
 	require.ErrorIs(t, err, share.ErrNotAvailable)
 
-	// cache should have failed results now
-	result, err := avail.ds.Get(ctx, rootKey(roots))
+	// The datastore should now contain the sampling result with all samples in Remaining
+	result, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots))
 	require.NoError(t, err)
 
-	failed, err := decodeSamples(result)
+	var failed []Sample
+	err = json.Unmarshal(result, &failed)
 	require.NoError(t, err)
 	require.Len(t, failed, int(avail.params.SampleAmount))
 
-	// ensure that retry persists the failed samples selection
-	// create new getter with only the failed samples available, and add them to the onceGetter
-	onceGetter := newOnceGetter()
-	onceGetter.AddSamples(failed)
-
-	// replace getter with the new one
-	avail.getter = onceGetter
+	// Simulate a getter that now returns shares successfully
+	successfulGetter := newOnceGetter()
+	successfulGetter.AddSamples(failed)
+	avail.getter = successfulGetter
 
 	// should be able to retrieve all the failed samples now
 	err = avail.SharesAvailable(ctx, eh)
 	require.NoError(t, err)
 
 	// onceGetter should have no more samples stored after the call
-	require.Empty(t, onceGetter.available)
+	require.Empty(t, successfulGetter.available)
 }
 
 type onceGetter struct {
@@ -178,7 +182,7 @@ func (m onceGetter) AddSamples(samples []Sample) {
 func (m onceGetter) GetShare(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) {
 	m.Lock()
 	defer m.Unlock()
-	s := Sample{Row: uint16(row), Col: uint16(col)}
+	s := Sample{Row: row, Col: col}
 	if _, ok := m.available[s]; ok {
 		delete(m.available, s)
 		return libshare.Share{}, nil
diff --git a/share/availability/light/sample.go b/share/availability/light/sample.go
index c8061cdb1e..17f6b2a59f 100644
--- a/share/availability/light/sample.go
+++ b/share/availability/light/sample.go
@@ -1,104 +1,42 @@
-// TODO(@Wondertan): Instead of doing sampling over the coordinates do a random walk over NMT trees.
 package light
 
 import (
 	crand "crypto/rand"
-	"encoding/binary"
-	"errors"
 	"math/big"
+
+	"golang.org/x/exp/maps"
 )
 
-// Sample is a point in 2D space over square.
+// Sample represents a coordinate in a 2D data square.
 type Sample struct {
-	Row, Col uint16
+	Row int `json:"row"`
+	Col int `json:"col"`
 }
 
-// SampleSquare randomly picks *num* unique points from the given *width* square
-// and returns them as samples.
-func SampleSquare(squareWidth, num int) ([]Sample, error) {
-	ss := newSquareSampler(squareWidth, num)
-	err := ss.generateSample(num)
-	if err != nil {
-		return nil, err
+// selectRandomSamples randomly picks unique coordinates from a square of given size.
+func selectRandomSamples(squareSize, sampleCount int) []Sample {
+	total := squareSize * squareSize
+	if sampleCount > total {
+		sampleCount = total
 	}
-	return ss.samples(), nil
-}
-
-type squareSampler struct {
-	squareWidth int
-	smpls       map[Sample]struct{}
-}
-
-func newSquareSampler(squareWidth, expectedSamples int) *squareSampler {
-	return &squareSampler{
-		squareWidth: squareWidth,
-		smpls:       make(map[Sample]struct{}, expectedSamples),
-	}
-}
-
-// generateSample randomly picks unique point on a 2D spaces.
-func (ss *squareSampler) generateSample(num int) error {
-	if num > ss.squareWidth*ss.squareWidth {
-		num = ss.squareWidth
-	}
-
-	done := 0
-	for done < num {
+	samples := make(map[Sample]struct{}, sampleCount)
+	for len(samples) < sampleCount {
 		s := Sample{
-			Row: randInt(ss.squareWidth),
-			Col: randInt(ss.squareWidth),
+			Row: randInt(squareSize),
+			Col: randInt(squareSize),
 		}
-
-		if _, ok := ss.smpls[s]; ok {
-			continue
-		}
-
-		done++
-		ss.smpls[s] = struct{}{}
+		samples[s] = struct{}{}
 	}
-
-	return nil
-}
-
-func (ss *squareSampler) samples() []Sample {
-	samples := make([]Sample, 0, len(ss.smpls))
-	for s := range ss.smpls {
-		samples = append(samples, s)
-	}
-	return samples
+	return maps.Keys(samples)
 }
 
-func randInt(max int) uint16 {
+func randInt(max int) int {
 	n, err := crand.Int(crand.Reader, big.NewInt(int64(max)))
 	if err != nil {
 		panic(err) // won't panic as rand.Reader is endless
 	}
-	return uint16(n.Uint64())
-}
-
-// encodeSamples encodes a slice of samples into a byte slice using little endian encoding.
-func encodeSamples(samples []Sample) []byte {
-	bs := make([]byte, 0, len(samples)*4)
-	for _, s := range samples {
-		bs = binary.LittleEndian.AppendUint16(bs, s.Row)
-		bs = binary.LittleEndian.AppendUint16(bs, s.Col)
-	}
-	return bs
-}
-
-// decodeSamples decodes a byte slice into a slice of samples.
-func decodeSamples(bs []byte) ([]Sample, error) {
-	if len(bs)%4 != 0 {
-		return nil, errors.New("invalid byte slice length")
-	}
-
-	samples := make([]Sample, 0, len(bs)/4)
-	for i := 0; i < len(bs); i += 4 {
-		samples = append(samples, Sample{
-			Row: binary.LittleEndian.Uint16(bs[i : i+2]),
-			Col: binary.LittleEndian.Uint16(bs[i+2 : i+4]),
-		})
-	}
-	return samples, nil
+	// n.Uint64() is safe as max is int
+	return int(n.Uint64())
 }
diff --git a/share/availability/light/sample_test.go b/share/availability/light/sample_test.go
index 8d7656e688..2bdbb223b6 100644
--- a/share/availability/light/sample_test.go
+++ b/share/availability/light/sample_test.go
@@ -16,13 +16,12 @@ func TestSampleSquare(t *testing.T) {
 	}
 
 	for _, tt := range tests {
-		ss, err := SampleSquare(tt.width, tt.samples)
-		assert.NoError(t, err)
+		ss := selectRandomSamples(tt.width, tt.samples)
 		assert.Len(t, ss, tt.samples)
 		// check points are within width
 		for _, s := range ss {
-			assert.Less(t, int(s.Row), tt.width)
-			assert.Less(t, int(s.Col), tt.width)
+			assert.Less(t, s.Row, tt.width)
+			assert.Less(t, s.Col, tt.width)
 		}
 		// checks samples are not equal
 		for i, s1 := range ss {
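
For reference, below is a minimal, self-contained sketch of the concurrent sampling session that the reworked SharesAvailable performs: every selected sample is fetched in its own goroutine and failures are collected under a mutex before being persisted and reported. The fetchShare helper and the fixed sample list are hypothetical stand-ins for getter.GetShare and the random selection; only the synchronization shape mirrors the patch.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// Sample mirrors the coordinate type from share/availability/light/sample.go.
type Sample struct {
	Row int `json:"row"`
	Col int `json:"col"`
}

var errNotFound = errors.New("share not found")

// fetchShare is a hypothetical stand-in for getter.GetShare; it pretends
// that shares on the diagonal are missing.
func fetchShare(s Sample) error {
	if s.Row == s.Col {
		return errNotFound
	}
	return nil
}

func main() {
	samples := []Sample{{Row: 0, Col: 1}, {Row: 1, Col: 1}, {Row: 2, Col: 3}, {Row: 3, Col: 3}}

	var (
		wg                sync.WaitGroup
		failedSamplesLock sync.Mutex
		failedSamples     []Sample
	)
	for _, s := range samples {
		wg.Add(1)
		go func(s Sample) {
			defer wg.Done()
			if err := fetchShare(s); err != nil {
				failedSamplesLock.Lock()
				failedSamples = append(failedSamples, s)
				failedSamplesLock.Unlock()
			}
		}(s)
	}
	wg.Wait()

	// As in the patch: any failed sample makes the availability check fail.
	if len(failedSamples) > 0 {
		fmt.Println("availability validation failed for:", failedSamples)
		return
	}
	fmt.Println("all samples available")
}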
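
The other behavioural change is persistence: failed samples are now stored as plain JSON of the []Sample slice instead of the removed little-endian encodeSamples/decodeSamples helpers. Below is a sketch of that round-trip under two stated assumptions: an in-memory map stands in for the namespaced autobatch datastore, and the key string is illustrative only.

package main

import (
	crand "crypto/rand"
	"encoding/json"
	"fmt"
	"math/big"
)

// Sample mirrors the new JSON-tagged type from sample.go.
type Sample struct {
	Row int `json:"row"`
	Col int `json:"col"`
}

func randInt(max int) int {
	n, err := crand.Int(crand.Reader, big.NewInt(int64(max)))
	if err != nil {
		panic(err) // rand.Reader does not fail in practice
	}
	return int(n.Int64())
}

// selectRandomSamples follows the same approach as the patch: keep drawing
// coordinates until sampleCount unique ones have been collected.
func selectRandomSamples(squareSize, sampleCount int) []Sample {
	if total := squareSize * squareSize; sampleCount > total {
		sampleCount = total
	}
	set := make(map[Sample]struct{}, sampleCount)
	for len(set) < sampleCount {
		set[Sample{Row: randInt(squareSize), Col: randInt(squareSize)}] = struct{}{}
	}
	out := make([]Sample, 0, len(set))
	for s := range set {
		out = append(out, s)
	}
	return out
}

func main() {
	store := map[string][]byte{} // stand-in for the namespaced datastore

	// Pretend these samples failed and persist them as the sampling result.
	failed := selectRandomSamples(16, 4)
	bs, err := json.Marshal(failed)
	if err != nil {
		panic(err)
	}
	store["sampling_result/<root>"] = bs // key shape is illustrative only

	// A later call reloads the snapshot and retries only these samples.
	var restored []Sample
	if err := json.Unmarshal(store["sampling_result/<root>"], &restored); err != nil {
		panic(err)
	}
	fmt.Println("previously failed samples:", restored)
}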