Skip to content

Commit

Permalink
Integration: allow headers --reset (#3972)
Browse files Browse the repository at this point in the history
  • Loading branch information
AskAlexSharov authored Apr 26, 2022
1 parent 3906d7e commit e04f7fc
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 30 deletions.
91 changes: 62 additions & 29 deletions cmd/integration/commands/stages.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ func init() {

withDataDir(cmdStageHeaders)
withUnwind(cmdStageHeaders)
withReset(cmdStageHeaders)
withChain(cmdStageHeaders)
withHeimdall(cmdStageHeaders)

Expand Down Expand Up @@ -441,48 +442,80 @@ func init() {
rootCmd.AddCommand(cmdSetPrune)
}

// max is a helper function which returns the larger of the two given integers.
func max(a, b uint64) uint64 { //nolint:unparam
if a > b {
return a
}
return b
}

func stageHeaders(db kv.RwDB, ctx context.Context) error {
return db.Update(ctx, func(tx kv.RwTx) error {
if unwind > 0 {
if !(unwind > 0 || reset) {
log.Info("This command only works with --unwind or --reset options")
}

if reset {
progress, err := stages.GetStageProgress(tx, stages.Headers)
if err != nil {
return fmt.Errorf("read Bodies progress: %w", err)
}
if unwind > progress {
return fmt.Errorf("cannot unwind past 0")
}
if err = stages.SaveStageProgress(tx, stages.Headers, progress-unwind); err != nil {
return fmt.Errorf("saving Bodies progress failed: %w", err)
}
progress, err = stages.GetStageProgress(tx, stages.Headers)
if err != nil {
return fmt.Errorf("re-read Bodies progress: %w", err)
}
unwind = progress
}

progress, err := stages.GetStageProgress(tx, stages.Headers)
if err != nil {
return fmt.Errorf("read Bodies progress: %w", err)
}
var unwindTo uint64
if unwind > progress {
unwindTo = 1 // keep genesis
} else {
unwindTo = max(1, progress-unwind)
}

if err = stages.SaveStageProgress(tx, stages.Headers, unwindTo); err != nil {
return fmt.Errorf("saving Bodies progress failed: %w", err)
}
progress, err = stages.GetStageProgress(tx, stages.Headers)
if err != nil {
return fmt.Errorf("re-read Bodies progress: %w", err)
}
{ // hard-unwind stage_body also
if err := rawdb.DeleteNewBlocks(tx, progress+1); err != nil {
return err
}
// remove all canonical markers from this point
if err := tx.ForEach(kv.HeaderCanonical, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
return tx.Delete(kv.HeaderCanonical, k, nil)
}); err != nil {
return err
}
if err := tx.ForEach(kv.HeaderTD, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
return tx.Delete(kv.HeaderTD, k, nil)
}); err != nil {
return err
}
hash, err := rawdb.ReadCanonicalHash(tx, progress-1)
progressBodies, err := stages.GetStageProgress(tx, stages.Bodies)
if err != nil {
return err
return fmt.Errorf("read Bodies progress: %w", err)
}
if err = tx.Put(kv.HeadHeaderKey, []byte(kv.HeadHeaderKey), hash[:]); err != nil {
return err
if progress < progressBodies {
if err = stages.SaveStageProgress(tx, stages.Bodies, progress); err != nil {
return fmt.Errorf("saving Bodies progress failed: %w", err)
}
}
log.Info("Progress", "headers", progress)
return nil
}
log.Info("This command only works with --unwind option")
// remove all canonical markers from this point
if err := tx.ForEach(kv.HeaderCanonical, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
return tx.Delete(kv.HeaderCanonical, k, nil)
}); err != nil {
return err
}
if err := tx.ForEach(kv.HeaderTD, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
return tx.Delete(kv.HeaderTD, k, nil)
}); err != nil {
return err
}
hash, err := rawdb.ReadCanonicalHash(tx, progress-1)
if err != nil {
return err
}
if err = tx.Put(kv.HeadHeaderKey, []byte(kv.HeadHeaderKey), hash[:]); err != nil {
return err
}

log.Info("Progress", "headers", progress)
return nil
})
}
Expand Down
2 changes: 1 addition & 1 deletion eth/stagedsync/stage_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ func DownloadAndIndexSnapshotsIfNeed(s *StageState, ctx context.Context, tx kv.R
}
}

if s.BlockNumber == 0 {
if s.BlockNumber < 2 { // allow genesis
logEvery := time.NewTicker(logInterval)
defer logEvery.Stop()

Expand Down

0 comments on commit e04f7fc

Please sign in to comment.