diff --git a/consensus/model/block_chain.go b/consensus/model/block_chain.go index 418f8d8d..f46724ed 100644 --- a/consensus/model/block_chain.go +++ b/consensus/model/block_chain.go @@ -3,6 +3,7 @@ package model import ( "github.com/Qitmeer/qng/common/hash" "github.com/Qitmeer/qng/core/types" + "github.com/Qitmeer/qng/core/types/pow" ) type BlockChain interface { @@ -20,6 +21,9 @@ type BlockChain interface { Stop() error GetBlockByOrder(order uint64) Block GetBlockById(id uint) Block + GetMainChainTip() Block FetchBlockByHash(hash *hash.Hash) (*types.SerializedBlock, error) GetBlockOrderByHash(hash *hash.Hash) (uint, error) + GetBlockHeader(block Block) *types.BlockHeader + ForeachBlueBlocks(start Block, depth uint, powType pow.PowType, fn func(block Block, header *types.BlockHeader) error) error } diff --git a/core/blockchain/blockindex.go b/core/blockchain/blockindex.go index 0d967e34..4799a69b 100644 --- a/core/blockchain/blockindex.go +++ b/core/blockchain/blockindex.go @@ -7,6 +7,7 @@ import ( "github.com/Qitmeer/qng/consensus/forks" "github.com/Qitmeer/qng/consensus/model" "github.com/Qitmeer/qng/core/types" + "github.com/Qitmeer/qng/core/types/pow" "github.com/Qitmeer/qng/meerdag" ) @@ -59,6 +60,10 @@ func (b *BlockChain) GetBlockById(id uint) model.Block { return b.bd.GetBlockById(id) } +func (b *BlockChain) GetMainChainTip() model.Block { + return b.bd.GetMainChainTip() +} + // BlockOrderByHash returns the order of the block with the given hash in the // chain. // @@ -338,8 +343,9 @@ func (b *BlockChain) fetchHeaderByHash(hash *hash.Hash) (*types.BlockHeader, err return nil, fmt.Errorf("unable to find block header %v db %v", hash, err) } -func (b *BlockChain) GetBlockHeader(ib meerdag.IBlock) *types.BlockHeader { - if ib == nil { +func (b *BlockChain) GetBlockHeader(block model.Block) *types.BlockHeader { + ib, ok := block.(meerdag.IBlock) + if ib == nil || !ok { return nil } if ib.GetData() != nil { @@ -357,3 +363,16 @@ func (b *BlockChain) GetBlockHeader(ib meerdag.IBlock) *types.BlockHeader { } return header } + +func (b *BlockChain) ForeachBlueBlocks(start model.Block, depth uint, powType pow.PowType, fn func(block model.Block, header *types.BlockHeader) error) error { + return b.bd.Foreach(start.(meerdag.IBlock), depth, meerdag.Blue, func(block meerdag.IBlock) (bool, error) { + blockHeader := b.GetBlockHeader(block) + if blockHeader == nil { + return false, fmt.Errorf("No blockHeader:%s", block.GetHash().String()) + } + if blockHeader.Pow.GetPowType() != powType { + return false, nil + } + return true, fn(block, blockHeader) + }) +} diff --git a/meerdag/blocktransf.go b/meerdag/blocktransf.go index 33ef9761..17c33f9d 100644 --- a/meerdag/blocktransf.go +++ b/meerdag/blocktransf.go @@ -433,7 +433,7 @@ func (bd *MeerDAG) GetLayer(id uint) uint { return bd.GetBlockById(id).GetLayer() } -func (bd *MeerDAG) Foreach(start IBlock, depth uint, filter FilterType, fn func(block IBlock) error) error { +func (bd *MeerDAG) Foreach(start IBlock, depth uint, filter FilterType, fn func(block IBlock) (bool, error)) error { if _, ok := bd.instance.(*Phantom); !ok { return fmt.Errorf("Not support Foreach:%s", bd.instance.GetName()) } @@ -442,13 +442,15 @@ func (bd *MeerDAG) Foreach(start IBlock, depth uint, filter FilterType, fn func( cur := start for count < depth && cur != nil { if cur.GetID() != start.GetID() { - err := fn(cur) + ret, err := fn(cur) if err != nil { return err } - count++ - if count >= depth { - break + if ret { + count++ + if count >= depth { + break + } } } if cur.GetID() == 0 { @@ -465,11 +467,13 @@ func (bd *MeerDAG) Foreach(start IBlock, depth uint, filter FilterType, fn func( return fmt.Errorf("No block:id=%s", das[i]) } pBlock := block.(*PhantomBlock) - err := fn(pBlock) + ret, err := fn(pBlock) if err != nil { return err } - count++ + if ret { + count++ + } } } diff --git a/meerdag/test/phantom_test.go b/meerdag/test/phantom_test.go index dd113b3e..89348a9d 100644 --- a/meerdag/test/phantom_test.go +++ b/meerdag/test/phantom_test.go @@ -467,10 +467,10 @@ func Test_ForeachFig1(t *testing.T) { ph.UpdateVirtualBlockOrder() - err := bd.Foreach(bd.GetMainChainTip(), meerdag.MaxId, meerdag.All, func(block meerdag.IBlock) error { + err := bd.Foreach(bd.GetMainChainTip(), meerdag.MaxId, meerdag.All, func(block meerdag.IBlock) (bool, error) { t.Logf("block id:%d hash:%s order:%d", block.GetID(), block.GetHash().String(), block.GetOrder()) order = append(order, block.GetID()) - return nil + return true, nil }) if err != nil { t.Fatal(err) @@ -496,10 +496,10 @@ func Test_ForeachFig2(t *testing.T) { ph.UpdateVirtualBlockOrder() - err := bd.Foreach(bd.GetMainChainTip(), meerdag.MaxId, meerdag.All, func(block meerdag.IBlock) error { + err := bd.Foreach(bd.GetMainChainTip(), meerdag.MaxId, meerdag.All, func(block meerdag.IBlock) (bool, error) { t.Logf("block id:%d hash:%s order:%d", block.GetID(), block.GetHash().String(), block.GetOrder()) order = append(order, block.GetID()) - return nil + return true, nil }) if err != nil { t.Fatal(err) @@ -525,10 +525,10 @@ func Test_ForeachFig4(t *testing.T) { ph.UpdateVirtualBlockOrder() - err := bd.Foreach(bd.GetMainChainTip(), meerdag.MaxId, meerdag.All, func(block meerdag.IBlock) error { + err := bd.Foreach(bd.GetMainChainTip(), meerdag.MaxId, meerdag.All, func(block meerdag.IBlock) (bool, error) { t.Logf("block id:%d hash:%s order:%d", block.GetID(), block.GetHash().String(), block.GetOrder()) order = append(order, block.GetID()) - return nil + return true, nil }) if err != nil { t.Fatal(err)